Coverage for PyFHD/plotting/image.py: 3%

304 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-01 10:58 +0800

1import numpy as np 

2from numpy.typing import NDArray 

3import matplotlib.pyplot as plt 

4from pathlib import Path 

5from astropy.wcs import WCS 

6from astropy.io import fits 

7from astropy import units as u 

8from logging import Logger 

9import os 

10 

11 

12def quick_image( 

13 image: NDArray[np.integer | np.floating | np.complexfloating], 

14 xvals: NDArray[np.integer | np.floating] = None, 

15 yvals: NDArray[np.integer | np.floating] = None, 

16 data_range: NDArray[np.integer | np.floating] = None, 

17 data_min_abs: float = None, 

18 xrange: NDArray[np.integer | np.floating] = None, 

19 yrange: NDArray[np.integer | np.floating] = None, 

20 data_aspect: float = None, 

21 log: bool = False, 

22 color_profile: str = "log_cut", 

23 xtitle: str = None, 

24 ytitle: str = None, 

25 title: str = None, 

26 cb_title: str = None, 

27 note: str = None, 

28 charsize: int = None, 

29 xlog: bool = False, 

30 ylog: bool = False, 

31 multi_pos: list = None, 

32 start_multi_params: dict = None, 

33 alpha: float = None, 

34 missing_value: int | float | complex = None, 

35 savefile: str = None, 

36 png: bool = False, 

37 eps: bool = False, 

38 pdf: bool = False, 

39) -> None: 

40 """ 

41 General function to display and/or save a 2D data array as an image with an appropriately 

42 scaled color bar. 

43 

44 

45 Parameters 

46 ---------- 

47 image : NDArray[np.integer | np.floating | np.complexfloating] 

48 A 2D array of data to be displayed as an image. 

49 The data can be of type int, float, or complex. 

50 xvals : NDArray[np.integer | np.floating], optional 

51 An array of x-axis values, by default None 

52 yvals : NDArray[np.integer | np.floating], optional 

53 An array of y-axis values, by default None 

54 data_range : NDArray[np.integer | np.floating], optional 

55 Min/max color bar range, by default [np.nanmin(image), np.nanmax(image)] 

56 data_min_abs : float, optional 

57 The minimum absolute value for the color bar, by default None 

58 xrange : NDArray[np.integer | np.floating], optional 

59 The indices (or xvals, if provided) to zoom the image, by default None 

60 yrange : NDArray[np.integer | np.floating], optional 

61 The indices (or yvals, if provided) to zoom the image, by default None 

62 data_aspect : int | float, optional 

63 The aspect ratio of y to x, by default None 

64 log : bool, optional 

65 Color bar on logarithmic scale, by default False 

66 color_profile : str, optional 

67 Color bar profiles for logarithmic scaling. 

68 "log_cut", "sym_log", "abs", by default "log_cut" 

69 xtitle : str, optional 

70 The title of the x-axis, by default None 

71 ytitle : str, optional 

72 The title of the x-axis, by default None 

73 title : str, optional 

74 The title of the image, by default None 

75 cb_title : str, optional 

76 The title of the color bar, by default None 

77 note : str, optional 

78 A small note to place on the bottom right of the image, by default None 

79 charsize : int, optional 

80 The size of the font, by default None 

81 xlog : bool, optional 

82 Use logarithmic scale for the x-axis, by default False 

83 ylog : bool, optional 

84 Use logarithmic scale for the y-axis, by default False 

85 multi_pos : list, optional 

86 A list of 4 elements defining the position of the plot in a multi-panel layout, by default None 

87 start_multi_params : dict, optional 

88 Parameters for starting a multi-panel layout, by default None 

89 alpha : float, optional 

90 Transparency for the image, by default None 

91 missing_value : int | float | complex, optional 

92 Exclude value from the color bar, by default None 

93 savefile : str, optional 

94 The save file name, by default None 

95 png : bool, optional 

96 Create a png of the image, by default False 

97 eps : bool, optional 

98 Create an eps of the image, by default False 

99 pdf : bool, optional 

100 Create a pdf of the image, by default False 

101 

102 Returns 

103 ------- 

104 None 

105 Displays the image and/or saves it to disk. 

106 """ 

107 

108 # Determine if the output is to be saved to disk 

109 pub = bool(savefile or png or eps or pdf) 

110 

111 # Handle file extension and output format 

112 if pub: 

113 if not (png or eps or pdf): 

114 if savefile: 

115 # Convert savefile to a Path object if it's a string 

116 savefile = Path(savefile) if isinstance(savefile, str) else savefile 

117 extension = savefile.suffix.lower() 

118 if extension == ".eps": 

119 eps = True 

120 elif extension == ".png": 

121 png = True 

122 elif extension == ".pdf": 

123 pdf = True 

124 else: 

125 print("Unrecognized extension, using PNG") 

126 png = True 

127 

128 # Set default savefile if not provided 

129 if not savefile: 

130 savefile = "idl_quick_image" 

131 print( 

132 f"No filename specified for quick_image output. Using {os.getcwd()}/{savefile}" 

133 ) 

134 

135 # Ensure only one output format is set 

136 formats_set = sum([png, eps, pdf]) 

137 if formats_set > 1: 

138 print("Only one of eps, png, pdf can be set. Defaulting to PNG.") 

139 eps = pdf = False 

140 png = True 

141 

142 # Append the appropriate file extension 

143 if isinstance(savefile, Path): 

144 if png: 

145 savefile = savefile.with_suffix(".png") 

146 elif pdf: 

147 savefile = savefile.with_suffix(".pdf") 

148 elif eps: 

149 savefile = savefile.with_suffix(".eps") 

150 elif isinstance(savefile, str): 

151 if png: 

152 savefile += ".png" 

153 elif pdf: 

154 savefile += ".pdf" 

155 elif eps: 

156 savefile += ".eps" 

157 

158 # Validate the image input 

159 if image is None or not isinstance(image, np.ndarray): 

160 print("Image is undefined or not a valid numpy array.") 

161 return 

162 

163 # Ensure the image is 2D 

164 if image.ndim != 2: 

165 print("Image must be 2-dimensional.") 

166 return 

167 

168 # Handle complex images. Default is to show the real part. 

169 if np.iscomplexobj(image): 

170 print("Image is complex, showing real part.") 

171 image = np.real(image) 

172 

173 # Handle missing values by setting them to NaN 

174 if missing_value is not None: 

175 wh_missing = np.where(image == missing_value) 

176 count_missing = len(wh_missing[0]) 

177 if count_missing > 0: 

178 image[wh_missing] = np.nan 

179 missing_color = 0 

180 else: 

181 count_missing = 0 

182 wh_missing = None 

183 missing_color = None 

184 

185 # Validate that 2-value inputs are only 2 values 

186 if data_range is not None: 

187 if not isinstance(data_range, np.ndarray) or len(data_range) != 2: 

188 raise ValueError("data_range must be an array with exactly two values.") 

189 if xrange is not None: 

190 if not isinstance(xrange, np.ndarray) or len(xrange) != 2: 

191 raise ValueError("xrange must be an array with exactly two values.") 

192 if yrange is not None: 

193 if not isinstance(yrange, np.ndarray) or len(yrange) != 2: 

194 raise ValueError("yrange must be an array with exactly two values.") 

195 

196 # Apply logarithmic scaling if set. This modifies the image input directly 

197 # to be logarithmically scaled in the color bar range. 

198 if log: 

199 image, cb_ticks, cb_ticknames = log_color_calc( 

200 data=image, 

201 data_range=data_range, 

202 color_profile=color_profile, 

203 log_cut_val=None, 

204 min_abs=data_min_abs, 

205 count_missing=count_missing, 

206 wh_missing=wh_missing, 

207 missing_color=missing_color, 

208 invert_colorbar=False, 

209 ) 

210 else: 

211 # Apply linear scaling by default. This modifies the image input directly 

212 # to be linearly scaled in the color bar range. 

213 if data_range is None: 

214 data_range = [np.nanmin(image), np.nanmax(image)] 

215 

216 data_color_range, data_n_colors = color_range(count_missing=count_missing) 

217 

218 # Scale image data to be in the color range 

219 image = (image - data_range[0]) * (data_n_colors - 1) / ( 

220 data_range[1] - data_range[0] 

221 ) + data_color_range[0] 

222 print(data_range, data_color_range, data_n_colors) 

223 

224 # Handle out-of-bounds values 

225 wh_low = np.where(image < data_range[0]) 

226 if len(wh_low[0]) > 0: 

227 image[wh_low] = data_color_range[0] 

228 wh_high = np.where(image > data_range[1]) 

229 if len(wh_high[0]) > 0: 

230 image[wh_high] = data_color_range[1] 

231 

232 # Handle missing values 

233 if missing_value is not None and count_missing > 0: 

234 image[wh_missing] = missing_color 

235 

236 cb_ticks = np.linspace(data_color_range[0], data_color_range[1], num=5) 

237 cb_ticknames = [ 

238 f"{tick * (data_range[1] - data_range[0]) / (data_n_colors - 1) + data_range[0]:.2g}" 

239 for tick in cb_ticks 

240 ] 

241 print(cb_ticks, cb_ticknames) 

242 

243 # Set up the plot 

244 fig, ax = plt.subplots() 

245 cmap = plt.get_cmap("viridis") 

246 

247 # Set up the x and y ranges 

248 extent = None 

249 if xvals is not None and yvals is not None: 

250 # Default extent based on full xvals and yvals 

251 extent = [xvals[0], xvals[-1], yvals[0], yvals[-1]] 

252 # Apply xrange to crop the image and adjust extent 

253 if xrange is not None: 

254 x_indices = np.logical_and(xvals >= xrange[0], xvals <= xrange[1]) 

255 image = image[:, x_indices] 

256 xvals = xvals[x_indices] # Update xvals to match cropped image 

257 extent[0], extent[1] = xrange[0], xrange[1] 

258 # Apply yrange to crop the image and adjust extent 

259 if yrange is not None: 

260 y_indices = np.logical_and(yvals >= yrange[0], yvals <= yrange[1]) 

261 image = image[y_indices, :] 

262 yvals = yvals[y_indices] # Update yvals to match cropped image 

263 extent[2], extent[3] = yrange[0], yrange[1] 

264 elif xrange is not None and yrange is not None: 

265 # If xvals and yvals are not provided, use xrange and yrange directly 

266 extent = [xrange[0], xrange[1], yrange[0], yrange[1]] 

267 image = image[np.ix_(yrange, xrange)] 

268 

269 im = ax.imshow( 

270 image, 

271 extent=extent, 

272 aspect=data_aspect or "auto", 

273 cmap=cmap, 

274 vmin=0, 

275 vmax=255, 

276 alpha=alpha, 

277 ) 

278 

279 # Add titles and labels 

280 if title: 

281 ax.set_title(title, fontsize=charsize or 12) 

282 if xtitle: 

283 ax.set_xlabel(xtitle, fontsize=charsize or 10) 

284 if ytitle: 

285 ax.set_ylabel(ytitle, fontsize=charsize or 10) 

286 

287 # Handle logarithmic axes 

288 if xlog: 

289 ax.set_xscale("log") 

290 if ylog: 

291 ax.set_yscale("log") 

292 

293 # Add colorbar 

294 cbar = plt.colorbar(im, ax=ax) 

295 # if log: 

296 cbar.set_ticks(cb_ticks) 

297 cbar.set_ticklabels(cb_ticknames) 

298 if cb_title: 

299 cbar.set_label(cb_title, fontsize=charsize or 10) 

300 

301 # Add note if provided 

302 if note: 

303 plt.figtext( 

304 0.99, 0.02, note, horizontalalignment="right", fontsize=charsize or 8 

305 ) 

306 

307 # Multi-panel plotting 

308 if multi_pos is not None: 

309 if len(multi_pos) != 4: 

310 raise ValueError( 

311 "multi_pos must be a 4-element list defining the plot position." 

312 ) 

313 ax.set_position(multi_pos) 

314 

315 # Handle start_multi_params for multi-panel layout 

316 if start_multi_params is not None: 

317 nrows = start_multi_params.get("nrow", 1) 

318 ncols = start_multi_params.get("ncol", 1) 

319 index = start_multi_params.get("index", 1) - 1 # Convert to 0-based index 

320 ax.set_position( 

321 [ 

322 (index % ncols) / ncols, 

323 1 - (index // ncols + 1) / nrows, 

324 1 / ncols, 

325 1 / nrows, 

326 ] 

327 ) 

328 

329 # Save or show the plot 

330 if pub: 

331 plt.savefig(savefile, dpi=300, bbox_inches="tight") 

332 else: 

333 plt.show() 

334 

335 plt.close(fig) 

336 

337 

338def log_color_calc( 

339 data: NDArray[np.integer | np.floating | np.complexfloating], 

340 data_range: NDArray[np.integer | np.floating] = None, 

341 color_profile: str = "log_cut", 

342 log_cut_val: float = None, 

343 min_abs: float = None, 

344 count_missing: int = None, 

345 wh_missing: NDArray[np.integer] = None, 

346 missing_color: int = None, 

347 invert_colorbar: bool = False, 

348) -> tuple: 

349 """ 

350 Translated version of log_color_calc from IDL to Python. 

351 

352 Parameters 

353 ---------- 

354 data : NDArray[np.integer | np.floating | np.complexfloating] 

355 A 2D array of data to be displayed as an image. 

356 The data can be of type int, float, or complex. 

357 data_range : NDArray[np.integer | np.floating], optional 

358 Min/max color bar range, by default [np.nanmin(image), np.nanmax(image)] 

359 color_profile : str, optional 

360 Color bar profiles for logarithmic scaling. 

361 "log_cut", "sym_log", "abs", by default "log_cut" 

362 log_cut_val : int | float, optional 

363 Minimum log value to cut at, by default None 

364 data_min_abs : int | float, optional 

365 The minimum absolute value for the color bar, by default None 

366 count_missing : int, optional 

367 The number of missing values, by default None 

368 wh_missing : NDArray[np.integer], optional 

369 The location of the missing values, by default None 

370 missing_color : int, optional 

371 The index of the color bar for missing values, by default None 

372 invert_colorbar : bool, optional 

373 Invert the color bar, by default False 

374 

375 Returns 

376 ------- 

377 data_log_norm : NDArray[np.int | np.float64] 

378 The normalized data array. 

379 cb_ticks : NDArray[np.int | np.float64] 

380 The color bar ticks. 

381 cb_ticknames : NDArray[np.int | np.float64] 

382 The color bar tick names. 

383 """ 

384 # Define valid color profiles 

385 color_profile_enum = ["log_cut", "sym_log", "abs"] 

386 if color_profile not in color_profile_enum: 

387 raise ValueError( 

388 f"Color profile must be one of: {', '.join(color_profile_enum)}" 

389 ) 

390 

391 # Handle data_range 

392 if data_range is None: 

393 data_range = [np.nanmin(data), np.nanmax(data)] 

394 else: 

395 if len(data_range) != 2: 

396 raise ValueError("data_range must be a 2-element vector") 

397 

398 if data_range[1] < data_range[0]: 

399 raise ValueError("data_range[0] must be less than data_range[1]") 

400 

401 # Handle sym_log profile constraints 

402 if color_profile == "sym_log" and data_range[0] > 0: 

403 print( 

404 "sym_log profile cannot be selected with an entirely positive data range. Switching to log_cut" 

405 ) 

406 color_profile = "log_cut" 

407 

408 data_color_range, data_n_colors = color_range(count_missing=count_missing) 

409 

410 # Handle positive values 

411 wh_pos = np.where(data > 0) 

412 count_pos = len(wh_pos[0]) 

413 if count_pos > 0: 

414 min_pos = np.nanmin(data[wh_pos]) 

415 elif data_range[0] > 0: 

416 min_pos = data_range[0] 

417 elif data_range[1] > 0: 

418 min_pos = data_range[1] / 10 

419 else: 

420 min_pos = 0.01 

421 

422 # Handle negative values 

423 wh_neg = np.where(data < 0) 

424 count_neg = len(wh_neg[0]) 

425 if count_neg > 0: 

426 max_neg = np.nanmax(data[wh_neg]) 

427 elif data_range[1] < 0: 

428 max_neg = data_range[1] 

429 else: 

430 max_neg = data_range[0] / 10 

431 

432 # Handle zero values 

433 wh_zero = np.where(data == 0) 

434 count_zero = len(wh_zero[0]) 

435 

436 # Handle log_cut color profile 

437 if color_profile == "log_cut": 

438 if data_range[1] < 0: 

439 raise ValueError( 

440 "log_cut color profile will not work for entirely negative arrays." 

441 ) 

442 

443 if log_cut_val is None: 

444 if data_range[0] > 0: 

445 log_cut_val = np.log10(data_range[0]) 

446 else: 

447 log_cut_val = np.log10(min_pos) 

448 

449 log_data_range = [log_cut_val, np.log10(data_range[1])] 

450 

451 # Handle zero values 

452 if count_zero > 0: 

453 min_pos_color = 2 

454 zero_color = 1 

455 zero_val = log_data_range[0] 

456 else: 

457 min_pos_color = 1 

458 

459 neg_color = 0 

460 neg_val = log_data_range[0] 

461 

462 data_log = np.log10(data) 

463 wh_under = np.where(data < 10**log_cut_val) 

464 if len(wh_under[0]) > 0: 

465 data_log[wh_under] = log_data_range[0] 

466 

467 wh_over = np.where(data_log > log_data_range[1]) 

468 if len(wh_over[0]) > 0: 

469 data_log[wh_over] = log_data_range[1] 

470 

471 # Normalize data 

472 data_log_norm = ( 

473 (data_log - log_data_range[0]) 

474 * (data_n_colors - min_pos_color - 1) 

475 / (log_data_range[1] - log_data_range[0]) 

476 + data_color_range[0] 

477 + min_pos_color 

478 ) 

479 

480 if count_neg > 0: 

481 data_log_norm[wh_neg] = neg_color 

482 if count_zero > 0: 

483 data_log_norm[wh_zero] = zero_color 

484 

485 elif color_profile == "sym_log": 

486 if data_range[0] >= 0 or data_range[1] <= 0: 

487 raise ValueError( 

488 "sym_log color profile requires both negative and positive values in data_range." 

489 ) 

490 

491 # Calculate the minimum absolute value 

492 if min_abs is None: 

493 if count_pos > 0 and count_neg > 0: 

494 min_abs = min(min_pos, abs(max_neg)) 

495 elif count_pos > 0: 

496 min_abs = min_pos 

497 elif count_neg > 0: 

498 min_abs = abs(max_neg) 

499 else: 

500 min_abs = 1.0 

501 

502 log_data_range = [np.log10(min_abs), np.log10(data_range[1])] 

503 

504 # Normalize data 

505 data_log_norm = np.zeros_like(data, dtype=float) 

506 wh_pos = np.where(data > 0) 

507 wh_neg = np.where(data < 0) 

508 wh_zero = np.where(data == 0) 

509 

510 midpoint = (data_color_range[1] - data_color_range[0]) // 2 

511 

512 if len(wh_pos[0]) > 0: 

513 data_log_norm[wh_pos] = ( 

514 (np.log10(data[wh_pos]) - log_data_range[0]) 

515 * (midpoint) 

516 / (log_data_range[1] - log_data_range[0]) 

517 + data_color_range[0] 

518 + midpoint 

519 ) 

520 

521 if len(wh_neg[0]) > 0: 

522 # Reverse the mapping for negative values 

523 data_log_norm[wh_neg] = ( 

524 data_color_range[0] 

525 + midpoint 

526 - ( 

527 (np.log10(abs(data[wh_neg])) - log_data_range[0]) 

528 * midpoint 

529 / (log_data_range[1] - log_data_range[0]) 

530 ) 

531 ) 

532 

533 if len(wh_zero[0]) > 0: 

534 data_log_norm[wh_zero] = data_color_range[0] + midpoint 

535 

536 # Handle out-of-bounds values 

537 wh_under = np.where(data_log_norm < data_color_range[0]) 

538 if len(wh_under[0]) > 0: 

539 data_log_norm[wh_under] = data_color_range[0] 

540 

541 wh_over = np.where(data_log_norm > data_color_range[1]) 

542 if len(wh_over[0]) > 0: 

543 data_log_norm[wh_over] = data_color_range[1] 

544 

545 # Handle abs color profile 

546 elif color_profile == "abs": 

547 data_abs = np.abs(data) 

548 data_log_norm = (data_abs - data_range[0]) * (data_n_colors - 1) / ( 

549 data_range[1] - data_range[0] 

550 ) + data_color_range[0] 

551 

552 # Handle out-of-bounds values 

553 wh_under = np.where(data_log_norm < data_color_range[0]) 

554 if len(wh_under[0]) > 0: 

555 data_log_norm[wh_under] = data_color_range[0] 

556 

557 wh_over = np.where(data_log_norm > data_color_range[1]) 

558 if len(wh_over[0]) > 0: 

559 data_log_norm[wh_over] = data_color_range[1] 

560 

561 # Handle missing values 

562 if count_missing > 0: 

563 data_log_norm[wh_missing] = missing_color 

564 

565 # Handle invert_colorbar option 

566 if invert_colorbar: 

567 data_log_norm = data_color_range[1] - (data_log_norm - data_color_range[0]) 

568 

569 # Generate colorbar ticks and tick names 

570 if color_profile == "log_cut": 

571 cb_ticks = np.linspace(data_color_range[0], data_color_range[1], num=5) 

572 cb_ticknames = [ 

573 f"{10**(tick * (log_data_range[1] - log_data_range[0]) / (data_n_colors - 1) + log_data_range[0]):.2g}" 

574 for tick in cb_ticks 

575 ] 

576 elif color_profile == "sym_log": 

577 pos_ticks = np.linspace(midpoint, data_color_range[1], num=5) 

578 neg_ticks = np.linspace(data_color_range[0], midpoint, num=5) 

579 cb_ticks = np.concatenate([neg_ticks, [midpoint], pos_ticks]) 

580 cb_ticknames = ( 

581 [ 

582 f"-{10**(log_data_range[1] - (tick - data_color_range[0]) * (log_data_range[1] - log_data_range[0]) / midpoint):.2g}" 

583 for tick in neg_ticks 

584 ] 

585 + ["0"] 

586 + [ 

587 f"{10**((tick - midpoint) * (log_data_range[1] - log_data_range[0]) / midpoint + log_data_range[0]):.2g}" 

588 for tick in pos_ticks 

589 ] 

590 ) 

591 elif color_profile == "abs": 

592 cb_ticks = np.linspace(data_color_range[0], data_color_range[1], num=5) 

593 cb_ticknames = [ 

594 f"{tick * (data_range[1] - data_range[0]) / (data_n_colors - 1) + data_range[0]:.2g}" 

595 for tick in cb_ticks 

596 ] 

597 

598 return data_log_norm, cb_ticks, cb_ticknames 

599 

600 

def color_range(count_missing: int = None) -> tuple:
    """
    Define the color range for the image data.

    Colour indices run from 0 to 255. When any values are missing, index 0 is
    reserved for them and the data use indices 1-255.

    Parameters
    ----------
    count_missing : int, optional
        Count of missing values; None is treated as 0, by default None

    Returns
    -------
    tuple
        A tuple containing the data color range ([min, max] colour index) and the
        number of colors in that range.
    """
    full_color_range = [0, 255]

    # `None > 0` raises TypeError in Python 3, so treat the default as "none missing"
    if count_missing is not None and count_missing > 0:
        data_color_range = [1, 255]
    else:
        data_color_range = full_color_range

    data_n_colors = data_color_range[1] - data_color_range[0] + 1

    return data_color_range, data_n_colors

626 

627 

def plot_fits_image(
    fits_file: str,
    output_path: str,
    logger: Logger,
    title: str = "FITS Image",
) -> None:
    """
    Plot a FITS image using Astropy and save it to the specified output path.

    The primary HDU's data are shown in grayscale with a WCS grid, clipped to the
    1st-99th percentile of the finite pixel values.

    Parameters
    ----------
    fits_file : str
        Path to the FITS file.
    output_path : str
        Path to output image file.
    logger : Logger
        PyFHD's logger for displaying errors and info to the log files
    title : str, optional
        Title of the plot, by default "FITS Image". An empty string gives no title;
        None uses "FITS Image".

    Returns
    -------
    None
        The function saves the plot to the specified output path. Nothing is written
        (a warning is logged instead) when the data are not 2D, are all zero, or have
        no finite values.
    """
    with fits.open(fits_file) as hdul:
        # Get the data from the primary HDU
        data = hdul[0].data

        # Check that the data is 2D and non-zero
        if data is None or data.ndim != 2:
            logger.warning(
                f"FITS data must be a 2D array, no image made for {fits_file}."
            )
            return
        if not np.any(data):
            logger.warning(
                f"FITS data array contains only zeros, no image made for {fits_file}."
            )
            return

        # NaN is truthy, so an all-NaN image passes np.any; np.percentile would
        # then fail on an empty array
        finite_data = data[np.isfinite(data)]
        if finite_data.size == 0:
            logger.warning(
                f"FITS data array contains no finite values, no image made for {fits_file}."
            )
            return

        # Get the header from the primary HDU
        header = hdul[0].header

        # Force a TAN projection on both axes (overrides whatever the header had).
        # NOTE(review): presumably the writer's CTYPEs are not parseable by
        # astropy; confirm this is valid for all inputs.
        header["CTYPE1"] = "RA---TAN"
        header["CTYPE2"] = "DEC--TAN"

        # Get units from header
        unit = header["BUNIT"] if "BUNIT" in header else "Jy/str"

        # Create a WCS object for the image
        wcs = WCS(header, relax=True)

        # Approximate the image extent in degrees from its corner pixels.
        # NOTE(review): an RA range that wraps through 0/360 is not handled.
        ny, nx = data.shape
        x_min, x_max = wcs.wcs_pix2world([0, nx - 1], [0, 0], 0)[0]
        y_min, y_max = wcs.wcs_pix2world([0, 0], [0, ny - 1], 0)[1]
        x_extent = abs(x_max - x_min)
        y_extent = abs(y_max - y_min)

        # Grid spacing is a quarter of the extent, but never below 2 degrees
        min_spacing = 2 * u.deg
        spacing_x = max(x_extent / 4, min_spacing.value) * u.deg
        spacing_y = max(y_extent / 4, min_spacing.value) * u.deg

        # Percentile-based color bar range
        vmin, vmax = np.percentile(finite_data, (1, 99))

        # Create a figure and axis with WCS projection
        fig, ax = plt.subplots(subplot_kw={"projection": wcs})
        try:
            im = ax.imshow(
                data, origin="lower", cmap="gray", aspect="auto", vmin=vmin, vmax=vmax
            )

            # Add a WCS-based grid
            ax.grid(color="white", ls="--", alpha=0.5)
            ax.coords.grid(True, color="white", linestyle="--", alpha=0.5)
            ax.coords[0].set_axislabel("Right Ascension (J2000)")
            ax.coords[1].set_axislabel("Declination (J2000)")

            # Customize tick labels for grid lines with dynamic spacing
            ax.coords[0].set_ticks(spacing=spacing_x, color="white", size=8, width=1)
            ax.coords[0].set_ticklabel(size=10, exclude_overlapping=True)
            ax.coords[1].set_ticks(spacing=spacing_y, color="white", size=8, width=1)
            ax.coords[1].set_ticklabel(size=10, exclude_overlapping=True)

            # Add colorbar
            cbar = fig.colorbar(im, ax=ax, orientation="vertical")
            cbar.set_label("Flux density (" + unit + ")")

            # Set title
            if title:
                ax.set_title(title)
            elif title is None:
                ax.set_title("FITS Image")

            # Save the plot to the output path
            fig.savefig(output_path, dpi=300)
        finally:
            # Release the figure even if plotting or saving raised
            plt.close(fig)