import numpy as np
from numpy.typing import NDArray
import matplotlib.pyplot as plt
from pathlib import Path
from astropy.wcs import WCS
from astropy.io import fits
from astropy import units as u
from logging import Logger
import os
[docs]
def quick_image(
    image: NDArray[np.integer | np.floating | np.complexfloating],
    xvals: NDArray[np.integer | np.floating] | None = None,
    yvals: NDArray[np.integer | np.floating] | None = None,
    data_range: NDArray[np.integer | np.floating] | None = None,
    data_min_abs: float | None = None,
    xrange: NDArray[np.integer | np.floating] | None = None,
    yrange: NDArray[np.integer | np.floating] | None = None,
    data_aspect: float | None = None,
    log: bool = False,
    color_profile: str = "log_cut",
    xtitle: str | None = None,
    ytitle: str | None = None,
    title: str | None = None,
    cb_title: str | None = None,
    note: str | None = None,
    charsize: int | None = None,
    xlog: bool = False,
    ylog: bool = False,
    multi_pos: list | None = None,
    start_multi_params: dict | None = None,
    alpha: float | None = None,
    missing_value: int | float | complex | None = None,
    savefile: str | Path | None = None,
    png: bool = False,
    eps: bool = False,
    pdf: bool = False,
) -> None:
    """
    General function to display and/or save a 2D data array as an image with an appropriately
    scaled color bar.

    The input array is never modified; all scaling is done on a floating-point copy.

    Parameters
    ----------
    image : NDArray[np.integer | np.floating | np.complexfloating]
        A 2D array of data to be displayed as an image.
        The data can be of type int, float, or complex (only the real part is shown).
    xvals : NDArray[np.integer | np.floating], optional
        An array of x-axis values, by default None
    yvals : NDArray[np.integer | np.floating], optional
        An array of y-axis values, by default None
    data_range : NDArray[np.integer | np.floating], optional
        Min/max color bar range, by default [np.nanmin(image), np.nanmax(image)]
    data_min_abs : float, optional
        The minimum absolute value for the color bar, by default None
    xrange : NDArray[np.integer | np.floating], optional
        The indices (or xvals, if provided) to zoom the image, by default None.
        Both ends are inclusive.
    yrange : NDArray[np.integer | np.floating], optional
        The indices (or yvals, if provided) to zoom the image, by default None.
        Both ends are inclusive.
    data_aspect : int | float, optional
        The aspect ratio of y to x, by default None
    log : bool, optional
        Color bar on logarithmic scale, by default False
    color_profile : str, optional
        Color bar profiles for logarithmic scaling.
        "log_cut", "sym_log", "abs", by default "log_cut"
    xtitle : str, optional
        The title of the x-axis, by default None
    ytitle : str, optional
        The title of the y-axis, by default None
    title : str, optional
        The title of the image, by default None
    cb_title : str, optional
        The title of the color bar, by default None
    note : str, optional
        A small note to place on the bottom right of the image, by default None
    charsize : int, optional
        The size of the font, by default None
    xlog : bool, optional
        Use logarithmic scale for the x-axis, by default False
    ylog : bool, optional
        Use logarithmic scale for the y-axis, by default False
    multi_pos : list, optional
        A list of 4 elements defining the position of the plot in a multi-panel layout, by default None
    start_multi_params : dict, optional
        Parameters for starting a multi-panel layout ("nrow", "ncol", "index"), by default None
    alpha : float, optional
        Transparency for the image, by default None
    missing_value : int | float | complex, optional
        Exclude value from the color bar, by default None
    savefile : str | Path, optional
        The save file name, by default None
    png : bool, optional
        Create a png of the image, by default False
    eps : bool, optional
        Create an eps of the image, by default False
    pdf : bool, optional
        Create a pdf of the image, by default False

    Returns
    -------
    None
        Displays the image and/or saves it to disk.

    Raises
    ------
    ValueError
        If data_range, xrange or yrange do not hold exactly two values, or if
        multi_pos does not hold exactly four values.
    """
    # Determine if the output is to be saved to disk
    pub = bool(savefile or png or eps or pdf)

    # Handle file extension and output format
    if pub:
        if not (png or eps or pdf):
            # Only a file name was given: infer the format from its extension.
            savefile = Path(savefile) if isinstance(savefile, str) else savefile
            extension = savefile.suffix.lower()
            if extension == ".eps":
                eps = True
            elif extension == ".png":
                png = True
            elif extension == ".pdf":
                pdf = True
            else:
                print("Unrecognized extension, using PNG")
                png = True

        # Set default savefile if not provided
        if not savefile:
            savefile = "idl_quick_image"
            print(
                f"No filename specified for quick_image output. Using {os.getcwd()}/{savefile}"
            )

        # Ensure only one output format is set
        if sum([png, eps, pdf]) > 1:
            print("Only one of eps, png, pdf can be set. Defaulting to PNG.")
            eps = pdf = False
            png = True

        # Append the appropriate file extension
        suffix = ".png" if png else (".pdf" if pdf else ".eps")
        if isinstance(savefile, Path):
            savefile = savefile.with_suffix(suffix)
        elif isinstance(savefile, str) and not savefile.lower().endswith(suffix):
            # Do not double the extension when the name already carries it.
            savefile += suffix

    # Validate the image input
    if image is None or not isinstance(image, np.ndarray):
        print("Image is undefined or not a valid numpy array.")
        return

    # Ensure the image is 2D
    if image.ndim != 2:
        print("Image must be 2-dimensional.")
        return

    # Handle complex images. Default is to show the real part.
    if np.iscomplexobj(image):
        print("Image is complex, showing real part.")
        image = np.real(image)

    # Work on a float copy: NaN cannot be stored in integer arrays, and the
    # caller's array (or the real-part view of it) must not be modified.
    image = np.array(image, dtype=float)

    # Handle missing values by setting them to NaN
    if missing_value is not None:
        wh_missing = np.where(image == missing_value)
        count_missing = len(wh_missing[0])
        if count_missing > 0:
            image[wh_missing] = np.nan
    else:
        count_missing = 0
        wh_missing = None
    # Color index 0 is used for missing pixels when there are any.
    missing_color = 0 if count_missing > 0 else None

    # Validate that 2-value inputs are only 2 values
    data_range = _validate_pair("data_range", data_range)
    xrange = _validate_pair("xrange", xrange)
    yrange = _validate_pair("yrange", yrange)

    if log:
        # Logarithmic scaling of the image into the color bar range.
        image, cb_ticks, cb_ticknames = log_color_calc(
            data=image,
            data_range=data_range,
            color_profile=color_profile,
            log_cut_val=None,
            min_abs=data_min_abs,
            count_missing=count_missing,
            wh_missing=wh_missing,
            missing_color=missing_color,
            invert_colorbar=False,
        )
    else:
        # Linear scaling of the image into the color bar range.
        if data_range is None:
            data_range = [np.nanmin(image), np.nanmax(image)]
        data_color_range, data_n_colors = color_range(count_missing=count_missing)
        low, high = data_range[0], data_range[1]
        span = high - low
        # A constant image has zero span; map it onto the lowest data color
        # instead of dividing by zero.
        scale = (data_n_colors - 1) / span if span != 0 else 0.0
        image = (image - low) * scale + data_color_range[0]

        # Out-of-range values saturate at the ends of the color range. The
        # comparison has to happen in color units because the image has
        # already been rescaled.
        image = np.clip(image, data_color_range[0], data_color_range[1])

        # Handle missing values
        if count_missing > 0:
            image[wh_missing] = missing_color

        # Tick labels invert the scaling above, including the color offset.
        cb_ticks = np.linspace(data_color_range[0], data_color_range[1], num=5)
        cb_ticknames = [
            f"{(tick - data_color_range[0]) * span / (data_n_colors - 1) + low:.2g}"
            for tick in cb_ticks
        ]

    # Set up the x and y ranges
    extent = None
    if xvals is not None and yvals is not None:
        xvals = np.asarray(xvals)
        yvals = np.asarray(yvals)
        # Default extent based on full xvals and yvals
        extent = [xvals[0], xvals[-1], yvals[0], yvals[-1]]

        # Apply xrange to crop the image and adjust extent
        if xrange is not None:
            x_keep = np.logical_and(xvals >= xrange[0], xvals <= xrange[1])
            image = image[:, x_keep]
            xvals = xvals[x_keep]
            extent[0], extent[1] = xrange[0], xrange[1]

        # Apply yrange to crop the image and adjust extent
        if yrange is not None:
            y_keep = np.logical_and(yvals >= yrange[0], yvals <= yrange[1])
            image = image[y_keep, :]
            yvals = yvals[y_keep]
            extent[2], extent[3] = yrange[0], yrange[1]
    elif xrange is not None and yrange is not None:
        # Without xvals/yvals the ranges are pixel indices. Take the whole
        # window between them (inclusive, like the xvals/yvals branch) rather
        # than only the two rows and columns named by the end points.
        x_start, x_stop = int(xrange[0]), int(xrange[1])
        y_start, y_stop = int(yrange[0]), int(yrange[1])
        extent = [xrange[0], xrange[1], yrange[0], yrange[1]]
        image = image[y_start : y_stop + 1, x_start : x_stop + 1]

    imshow_kwargs = {}
    if extent is not None:
        # The extent lists the y value of row 0 first, so row 0 must be drawn
        # at the bottom for the axis labels to line up with the data rows.
        imshow_kwargs["origin"] = "lower"

    # Set up the plot
    fig, ax = plt.subplots()
    try:
        cmap = plt.get_cmap("viridis")
        im = ax.imshow(
            image,
            extent=extent,
            aspect=data_aspect or "auto",
            cmap=cmap,
            vmin=0,
            vmax=255,
            alpha=alpha,
            **imshow_kwargs,
        )

        # Add titles and labels
        if title:
            ax.set_title(title, fontsize=charsize or 12)
        if xtitle:
            ax.set_xlabel(xtitle, fontsize=charsize or 10)
        if ytitle:
            ax.set_ylabel(ytitle, fontsize=charsize or 10)

        # Handle logarithmic axes
        if xlog:
            ax.set_xscale("log")
        if ylog:
            ax.set_yscale("log")

        # Add colorbar, labelled in data units rather than color indices
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_ticks(cb_ticks)
        cbar.set_ticklabels(cb_ticknames)
        if cb_title:
            cbar.set_label(cb_title, fontsize=charsize or 10)

        # Add note if provided
        if note:
            fig.text(
                0.99, 0.02, note, horizontalalignment="right", fontsize=charsize or 8
            )

        # Multi-panel plotting
        if multi_pos is not None:
            if len(multi_pos) != 4:
                raise ValueError(
                    "multi_pos must be a 4-element list defining the plot position."
                )
            ax.set_position(multi_pos)

        # Handle start_multi_params for multi-panel layout
        if start_multi_params is not None:
            nrows = start_multi_params.get("nrow", 1)
            ncols = start_multi_params.get("ncol", 1)
            index = start_multi_params.get("index", 1) - 1  # Convert to 0-based index
            ax.set_position(
                [
                    (index % ncols) / ncols,
                    1 - (index // ncols + 1) / nrows,
                    1 / ncols,
                    1 / nrows,
                ]
            )

        # Save or show the plot
        if pub:
            fig.savefig(savefile, dpi=300, bbox_inches="tight")
        else:
            plt.show()
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)


def _validate_pair(name: str, values) -> np.ndarray | None:
    """Return values as a 1-D array of exactly two elements; None passes through unchanged."""
    if values is None:
        return None
    pair = np.asarray(values)
    if pair.ndim != 1 or pair.size != 2:
        raise ValueError(f"{name} must be an array with exactly two values.")
    return pair
[docs]
def log_color_calc(
    data: NDArray[np.integer | np.floating | np.complexfloating],
    data_range: NDArray[np.integer | np.floating] | None = None,
    color_profile: str = "log_cut",
    log_cut_val: float | None = None,
    min_abs: float | None = None,
    count_missing: int | None = None,
    wh_missing: NDArray[np.integer] | None = None,
    missing_color: int | None = None,
    invert_colorbar: bool = False,
) -> tuple:
    """
    Translated version of log_color_calc from IDL to Python.

    Maps the data onto color indices for one of three profiles and builds the
    matching color bar ticks. The input array is not modified.

    Parameters
    ----------
    data : NDArray[np.integer | np.floating | np.complexfloating]
        A 2D array of data to be displayed as an image.
    data_range : NDArray[np.integer | np.floating], optional
        Min/max color bar range, by default [np.nanmin(data), np.nanmax(data)]
    color_profile : str, optional
        Color bar profiles for logarithmic scaling.
        "log_cut", "sym_log", "abs", by default "log_cut"
    log_cut_val : int | float, optional
        Minimum log value to cut at, by default None
    min_abs : int | float, optional
        The minimum absolute value for the color bar ("sym_log" only), by default None
    count_missing : int, optional
        The number of missing values, by default None (treated as 0)
    wh_missing : NDArray[np.integer], optional
        The location of the missing values (an index usable on `data`), by default None
    missing_color : int, optional
        The index of the color bar for missing values, by default None (treated as 0)
    invert_colorbar : bool, optional
        Invert the color bar, by default False

    Returns
    -------
    data_log_norm : NDArray[np.floating]
        The data mapped onto color indices.
    cb_ticks : NDArray[np.floating]
        The color bar tick positions, in color indices.
    cb_ticknames : list[str]
        The color bar tick labels, in data units.

    Raises
    ------
    ValueError
        If color_profile is unknown, data_range is not an ordered pair, "log_cut"
        is used with an entirely negative range, or "sym_log" is used with a
        range that does not span zero.
    """
    # Define valid color profiles
    color_profile_enum = ["log_cut", "sym_log", "abs"]
    if color_profile not in color_profile_enum:
        raise ValueError(
            f"Color profile must be one of: {', '.join(color_profile_enum)}"
        )

    # The documented default is None, which cannot be compared with 0.
    if count_missing is None:
        count_missing = 0

    # Handle data_range
    if data_range is None:
        data_range = [np.nanmin(data), np.nanmax(data)]
    else:
        if len(data_range) != 2:
            raise ValueError("data_range must be a 2-element vector")
        if data_range[1] < data_range[0]:
            raise ValueError("data_range[0] must be less than data_range[1]")

    # Handle sym_log profile constraints
    if color_profile == "sym_log" and data_range[0] > 0:
        print(
            "sym_log profile cannot be selected with an entirely positive data range. Switching to log_cut"
        )
        color_profile = "log_cut"

    data_color_range, data_n_colors = color_range(count_missing=count_missing)

    # Smallest positive value (or a stand-in derived from the range)
    wh_pos = np.where(data > 0)
    count_pos = len(wh_pos[0])
    if count_pos > 0:
        min_pos = np.nanmin(data[wh_pos])
    elif data_range[0] > 0:
        min_pos = data_range[0]
    elif data_range[1] > 0:
        min_pos = data_range[1] / 10
    else:
        min_pos = 0.01

    # Negative value closest to zero (or a stand-in derived from the range)
    wh_neg = np.where(data < 0)
    count_neg = len(wh_neg[0])
    if count_neg > 0:
        max_neg = np.nanmax(data[wh_neg])
    elif data_range[1] < 0:
        max_neg = data_range[1]
    else:
        max_neg = data_range[0] / 10

    # Handle zero values
    wh_zero = np.where(data == 0)
    count_zero = len(wh_zero[0])

    if color_profile == "log_cut":
        if data_range[1] < 0:
            raise ValueError(
                "log_cut color profile will not work for entirely negative arrays."
            )
        if log_cut_val is None:
            if data_range[0] > 0:
                log_cut_val = np.log10(data_range[0])
            else:
                log_cut_val = np.log10(min_pos)

        # log10 of zero/negative/NaN entries is expected here; those entries
        # are overwritten below, so the numpy warnings are only noise.
        with np.errstate(divide="ignore", invalid="ignore"):
            log_data_range = [log_cut_val, np.log10(data_range[1])]
            log_span = log_data_range[1] - log_data_range[0]
            if log_span == 0:
                # Single positive value: avoid 0/0 and map it to the lowest
                # positive color.
                log_span = 1.0

            # Negatives always use color 0; zeros, when present, take color 1
            # and push the start of the positive colors up by one.
            neg_color = 0
            zero_color = 1
            min_pos_color = 2 if count_zero > 0 else 1

            data_log = np.log10(data)
            data_log[data < 10**log_cut_val] = log_data_range[0]
            data_log[data_log > log_data_range[1]] = log_data_range[1]

            # Normalize data
            data_log_norm = (
                (data_log - log_data_range[0])
                * (data_n_colors - min_pos_color - 1)
                / log_span
                + data_color_range[0]
                + min_pos_color
            )
        if count_neg > 0:
            data_log_norm[wh_neg] = neg_color
        if count_zero > 0:
            data_log_norm[wh_zero] = zero_color

    elif color_profile == "sym_log":
        if data_range[0] >= 0 or data_range[1] <= 0:
            raise ValueError(
                "sym_log color profile requires both negative and positive values in data_range."
            )
        # Calculate the minimum absolute value
        if min_abs is None:
            if count_pos > 0 and count_neg > 0:
                min_abs = min(min_pos, abs(max_neg))
            elif count_pos > 0:
                min_abs = min_pos
            elif count_neg > 0:
                min_abs = abs(max_neg)
            else:
                min_abs = 1.0
        log_data_range = [np.log10(min_abs), np.log10(data_range[1])]
        log_span = log_data_range[1] - log_data_range[0]

        # Normalize data: zero sits on the center color, positives go up
        # from it and negatives go down from it.
        data_log_norm = np.zeros_like(data, dtype=float)
        midpoint = (data_color_range[1] - data_color_range[0]) // 2
        center = data_color_range[0] + midpoint
        with np.errstate(divide="ignore", invalid="ignore"):
            if count_pos > 0:
                # Magnitudes below min_abs are held at min_abs so they cannot
                # cross the center and show up with the wrong sign.
                log_mag = np.maximum(np.log10(data[wh_pos]), log_data_range[0])
                data_log_norm[wh_pos] = (
                    (log_mag - log_data_range[0]) * midpoint / log_span + center
                )
            if count_neg > 0:
                log_mag = np.maximum(
                    np.log10(np.abs(data[wh_neg])), log_data_range[0]
                )
                data_log_norm[wh_neg] = (
                    center - (log_mag - log_data_range[0]) * midpoint / log_span
                )
        if count_zero > 0:
            data_log_norm[wh_zero] = center

        # Handle out-of-bounds values
        data_log_norm = np.clip(
            data_log_norm, data_color_range[0], data_color_range[1]
        )

    elif color_profile == "abs":
        # Linear mapping of the absolute value onto the data color range.
        with np.errstate(divide="ignore", invalid="ignore"):
            data_log_norm = (np.abs(data) - data_range[0]) * (data_n_colors - 1) / (
                data_range[1] - data_range[0]
            ) + data_color_range[0]
        # Handle out-of-bounds values
        data_log_norm = np.clip(
            data_log_norm, data_color_range[0], data_color_range[1]
        )

    # Handle missing values
    if count_missing > 0 and wh_missing is not None:
        data_log_norm[wh_missing] = 0 if missing_color is None else missing_color

    # Handle invert_colorbar option
    if invert_colorbar:
        data_log_norm = data_color_range[1] - (data_log_norm - data_color_range[0])

    # Generate colorbar ticks and tick names
    if color_profile == "log_cut":
        cb_ticks = np.linspace(data_color_range[0], data_color_range[1], num=5)
        # NOTE(review): these labels span the color range without the
        # min_pos_color offset used in the normalization above, so they are
        # approximate near the low end - confirm against the IDL original.
        cb_ticknames = [
            f"{10**(tick * (log_data_range[1] - log_data_range[0]) / (data_n_colors - 1) + log_data_range[0]):.2g}"
            for tick in cb_ticks
        ]
    elif color_profile == "sym_log":
        pos_ticks = np.linspace(midpoint, data_color_range[1], num=5)
        neg_ticks = np.linspace(data_color_range[0], midpoint, num=5)
        cb_ticks = np.concatenate([neg_ticks, [midpoint], pos_ticks])
        cb_ticknames = (
            [
                f"-{10**(log_data_range[1] - (tick - data_color_range[0]) * (log_data_range[1] - log_data_range[0]) / midpoint):.2g}"
                for tick in neg_ticks
            ]
            + ["0"]
            + [
                f"{10**((tick - midpoint) * (log_data_range[1] - log_data_range[0]) / midpoint + log_data_range[0]):.2g}"
                for tick in pos_ticks
            ]
        )
    elif color_profile == "abs":
        cb_ticks = np.linspace(data_color_range[0], data_color_range[1], num=5)
        # Invert the linear mapping above, including the color offset.
        cb_ticknames = [
            f"{(tick - data_color_range[0]) * (data_range[1] - data_range[0]) / (data_n_colors - 1) + data_range[0]:.2g}"
            for tick in cb_ticks
        ]

    return data_log_norm, cb_ticks, cb_ticknames
[docs]
def color_range(count_missing: int = None) -> tuple:
"""
Define the color range for the image data.
Parameters
----------
count_missing : int, optional
Count of missing values, by default None
Returns
-------
tuple
A tuple containing the color range and the number of colors.
"""
# Initialize color range
color_range = [0, 255]
if count_missing > 0:
data_color_range = [1, 255]
else:
data_color_range = color_range
data_n_colors = data_color_range[1] - data_color_range[0] + 1
return data_color_range, data_n_colors
[docs]
def plot_fits_image(
    fits_file: str,
    output_path: str,
    logger: Logger,
    title: str = "FITS Image",
) -> None:
    """
    Plot a FITS image using Astropy and save it to the specified output path.

    The primary HDU is drawn on WCS axes with a gray color map whose limits are
    the 1st and 99th percentiles of the finite pixel values. No image is made
    (a warning is logged instead) when the data is not 2D, is all zeros, or
    has no finite values.

    Parameters
    ----------
    fits_file : str
        Path to the FITS file.
    output_path : str
        Path to output image file.
    logger : Logger
        pyfhd's logger for displaying errors and info to the log files
    title : str, optional
        Title of the plot, by default "FITS Image". An empty string gives no title.

    Returns
    -------
    None
        The function saves the plot to the specified output path.
    """
    # Open the FITS file
    with fits.open(fits_file) as hdul:
        # Get the data from the primary HDU
        data = hdul[0].data

        # Check that the data is 2D and non-zero
        if data is None or data.ndim != 2:
            logger.warning(
                f"FITS data must be a 2D array, no image made for {fits_file}."
            )
            return
        if not np.any(data):
            logger.warning(
                f"FITS data array contains only zeros, no image made for {fits_file}."
            )
            return

        # Copy what is needed so nothing depends on the file staying open and
        # the header edits below do not touch the HDU itself.
        data = np.array(data)
        header = hdul[0].header.copy()

    # The percentile limits need at least one finite pixel.
    finite_values = data[np.isfinite(data)]
    if finite_values.size == 0:
        logger.warning(
            f"FITS data array contains no finite values, no image made for {fits_file}."
        )
        return

    # Force a tangent-plane RA/Dec projection for plotting
    header["CTYPE1"] = "RA---TAN"
    header["CTYPE2"] = "DEC--TAN"

    # Get units from header
    unit = header["BUNIT"] if "BUNIT" in header else "Jy/str"

    # Create a WCS object for the image
    wcs = WCS(header, relax=True)

    # Calculate the extent of the image in degrees along its first row/column
    ny, nx = data.shape
    x_min, x_max = wcs.wcs_pix2world([0, nx - 1], [0, 0], 0)[0]
    y_min, y_max = wcs.wcs_pix2world([0, 0], [0, ny - 1], 0)[1]
    x_extent = abs(x_max - x_min)
    # An image straddling RA = 0 has end points near 0 and 360 degrees; the
    # true width is the short way around.
    x_extent = min(x_extent, 360.0 - x_extent)
    y_extent = abs(y_max - y_min)

    # Set grid spacing to a quarter of the extent, but never below the minimum
    min_spacing = 2 * u.deg

    def _spacing(extent_deg):
        quarter = extent_deg / 4
        # Fall back to the minimum when the extent could not be computed (NaN).
        if not np.isfinite(quarter) or quarter < min_spacing.value:
            return min_spacing
        return quarter * u.deg

    spacing_x = _spacing(x_extent)
    spacing_y = _spacing(y_extent)

    # Calculate the percentile-based color bar range
    percentile_range = (1, 99)
    vmin, vmax = np.percentile(finite_values, percentile_range)

    # Create a figure and axis with WCS projection
    fig, ax = plt.subplots(subplot_kw={"projection": wcs})
    try:
        # Plot the image data
        im = ax.imshow(
            data, origin="lower", cmap="gray", aspect="auto", vmin=vmin, vmax=vmax
        )

        # Add a WCS-based grid
        ax.grid(color="white", ls="--", alpha=0.5)
        ax.coords.grid(True, color="white", linestyle="--", alpha=0.5)
        ax.coords[0].set_axislabel("Right Ascension (J2000)")
        ax.coords[1].set_axislabel("Declination (J2000)")

        # Customize tick labels for grid lines with dynamic spacing
        ax.coords[0].set_ticks(spacing=spacing_x, color="white", size=8, width=1)
        ax.coords[0].set_ticklabel(size=10, exclude_overlapping=True)
        ax.coords[1].set_ticks(spacing=spacing_y, color="white", size=8, width=1)
        ax.coords[1].set_ticklabel(size=10, exclude_overlapping=True)

        # Add colorbar
        cbar = fig.colorbar(im, ax=ax, orientation="vertical")
        cbar.set_label(f"Flux density ({unit})")

        # Set title
        if title:
            ax.set_title(title)
        elif title is None:
            ax.set_title("FITS Image")

        # Save the plot to the output path
        fig.savefig(output_path, dpi=300)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)