Coverage for PyFHD/pyfhd_tools/test_utils.py: 57%

94 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-01 10:58 +0800

1from PyFHD.io.pyfhd_io import recarray_to_dict 

2import numpy as np 

3from scipy.io import readsav 

4from pathlib import Path 

5import numpy.testing as npt 

6from colorama import Fore 

7from colorama import Style 

8from PyFHD.io.pyfhd_io import save 

9from numpy.typing import NDArray 

10 

11 

def _load_data_file(path: Path):
    """Load a single data file, choosing the reader from the file suffix."""
    if path.suffix == ".sav":
        return readsav(path, python_dict=True)
    # NOTE(review): allow_pickle=True unpickles arbitrary objects, so this must
    # only ever be pointed at trusted test fixtures.
    return np.load(path, allow_pickle=True)


def get_data(data_dir: Path, data_filename: str, *args: str) -> list:
    """
    Read one or more data files from a data directory.

    Files whose suffix is ``.sav`` are read with ``scipy.io.readsav``
    (``python_dict=True``); every other file is read with ``numpy.load``
    (``allow_pickle=True``). Each filename is joined onto ``data_dir``.

    Parameters
    ----------
    data_dir : Path
        The directory containing the files, e.g. the dir passed through
        from pytest-datadir
    data_filename : str
        The name of the first file to read
    *args : str
        If given, more filenames to read from the same directory

    Returns
    -------
    return_list : list
        When extra filenames are given, a list with the contents of
        ``data_filename`` followed by the contents of each extra file, in
        the order given. When only ``data_filename`` is given, the loaded
        contents are returned directly (not wrapped in a list).
    """
    first = _load_data_file(Path(data_dir, data_filename))
    if not args:
        return first
    return [first, *(_load_data_file(Path(data_dir, name)) for name in args)]

54 

55 

def get_data_items(data_dir: Path, data_with_item_path: Path, *args: list[str]) -> list:
    """
    Load files through ``get_data`` and pull the first stored value out of each.

    Each loaded object is unwrapped with ``.item()`` and the value under the
    first key of the result is extracted.

    Parameters
    ----------
    data_dir : Path
        Path to the data directory
    data_with_item_path : Path
        Path (relative to ``data_dir``) of the first file to extract from
    *args : Paths
        More paths to files whose first stored value should be extracted

    Returns
    -------
    return_list : list
        A list of the extracted values (first file first) when extra paths
        are given, otherwise just the value extracted from
        ``data_with_item_path``.
    """

    def _first_value(relative_path):
        # .item() unwraps the object held by the loaded array
        # (presumably a dict saved via np.save - confirm against the fixtures).
        stored = get_data(data_dir, relative_path).item()
        first_key = list(stored.keys())[0]
        return stored.get(first_key)

    extracted = [_first_value(entry) for entry in (data_with_item_path, *args)]
    if args:
        return extracted
    return extracted[0]

91 

92 

def get_data_sav(data_dir: Path, sav_file: Path, *args: list[Path]) -> list:
    """
    Load files through ``get_data`` and return the value under the first key
    of each loaded mapping. Intended for sav files.

    Parameters
    ----------
    data_dir : Path
        Path to the data directory
    sav_file : Path
        Path to the sav file, which is loaded as a python dictionary
    args : list[Path]
        If given, more filenames to treat the same way

    Returns
    -------
    return_list : list
        A list of the extracted values (``sav_file`` first) when extra files
        are given, otherwise just the value extracted from ``sav_file``.
    """
    extracted = []
    for filename in (sav_file, *args):
        contents = get_data(data_dir, filename)
        # Only the value stored under the first key is kept
        first_key = list(contents.keys())[0]
        extracted.append(contents[first_key])
    if args:
        return extracted
    return extracted[0]

125 

126 

def get_savs(data_dir: Path, sav_file: Path, *args: list[Path]) -> dict | list[dict]:
    """
    Read one or more sav files in full, without selecting any key.

    Every file is read with ``scipy.io.readsav`` using ``python_dict=True``.
    With a single file this is a thin wrapper around ``readsav``.

    Parameters
    ----------
    data_dir : Path
        Path to the data directory
    sav_file : Path
        Path to the sav file, which is loaded as a python dictionary
    args : Paths
        If given, more sav filenames to read

    Returns
    -------
    data : dict | list[dict]
        The dict for ``sav_file`` when it is the only file, otherwise a list
        of dicts with ``sav_file`` first followed by the extra files in order.
    """
    loaded = [
        readsav(Path(data_dir, name), python_dict=True)
        for name in (sav_file, *args)
    ]
    if not args:
        return loaded[0]
    return loaded

154 

155 

def try_assert_all_close(
    actual: NDArray, target: NDArray, name: str, tolerance=1e-8
) -> None:
    """
    Compare two arrays with ``numpy.testing.assert_allclose`` and print the
    outcome instead of letting an ``AssertionError`` propagate.

    Useful when precision errors are expected and you want the run to
    continue past a failed comparison.

    Parameters
    ----------
    actual : NDArray
        The array that was calculated
    target : NDArray
        The array we want to match
    name : str
        The name of the variable being tested, used in the printed message
    tolerance : float, optional
        Passed to ``assert_allclose`` as ``atol``, by default 1e-8
    """
    try:
        npt.assert_allclose(actual, target, atol=tolerance)
    except AssertionError as mismatch:
        # Red header followed by the assertion message
        print(
            f"{Fore.RED}{Style.BRIGHT}Test Failed for {name}:"
            f"{Style.RESET_ALL}{mismatch}{Style.RESET_ALL}"
        )
    else:
        print(f"{Fore.GREEN}{Style.BRIGHT}Test Passed for {name}{Style.RESET_ALL}")

193 

194 

def convert_to_h5(test_path: Path, save_path: Path, *args: str | Path) -> None:
    """
    Read every file given in ``args`` from ``test_path``, merge their
    top-level keys into one dictionary and write it out with ``save``.

    ``.sav`` files are read with ``scipy.io.readsav`` and passed through
    ``recarray_to_dict``; ``.npy`` files are read with ``numpy.load`` and
    unwrapped with ``.item()``. If two files share a key, the later file wins.

    This function was made to convert many of the .npy and .sav files into
    something that can be read and written more easily by packages other
    than numpy or scipy.

    Parameters
    ----------
    test_path : Path
        The path to a directory with all the files inside it
    save_path : Path
        The path to the file for saving the HDF5
    *args : str | Path
        File names to be read in, can be .npy or .sav files. Files with any
        other suffix are skipped with a warning.
    """
    # Local import keeps this block self-contained; the rest of the module
    # does not otherwise need the warnings module.
    import warnings

    to_save = {}
    # Process the file differently depending on whether it's an IDL or numpy file
    for file in args:
        # Building a Path first means both str and Path names are accepted
        file_path = Path(test_path, file)
        if file_path.suffix == ".sav":
            # Convert to nested dictionaries
            var = recarray_to_dict(readsav(file_path, python_dict=True))
        elif file_path.suffix == ".npy":
            # NOTE(review): allow_pickle=True unpickles arbitrary objects;
            # only use on trusted test fixtures.
            var = np.load(file_path, allow_pickle=True).item()
        else:
            # Previously this either crashed on an unbound variable or silently
            # re-used the previous file's contents; skip explicitly instead.
            warnings.warn(
                f"Skipping {file_path}: only .sav and .npy files are supported",
                stacklevel=2,
            )
            continue
        to_save.update(var)
    save(save_path, to_save, "to_save")

226 

227 

def sav_file_vis_arr_swap_axes(sav_file_vis_arr: NDArray) -> NDArray:
    """Re-pack a visibility array loaded from an IDL .sav file for PyFHD.

    Arrays such as `vis_arr` and `vis_model_arr`, once saved from IDL and
    loaded via scipy.io.readsav, arrive as an outer array with one entry per
    polarisation. Each entry is transposed and copied into a single array
    of shape `(n_pol, n_freq, n_baselines)`.

    Parameters
    ----------
    sav_file_vis_arr : NDArray
        Array as read in by scipy.io.readsav, if `n_pol = 2` should have
        `shape=(2,)`

    Returns
    -------
    NDArray
        A new array with `shape=(n_pol, n_freq, n_baselines)` and the dtype
        of the first polarisation's array
    """
    n_pol = sav_file_vis_arr.shape[0]
    first_pol = sav_file_vis_arr[0]

    # The output swaps the two axes of each per-polarisation array
    swapped_shape = (n_pol, first_pol.shape[1], first_pol.shape[0])
    swapped = np.empty(swapped_shape, dtype=first_pol.dtype)

    for pol_index, pol_vis in enumerate(sav_file_vis_arr):
        swapped[pol_index, :, :] = pol_vis.T

    return swapped

256 

257 

def print_types(dictionary: dict, dict_name: str, indent_level: int = 1) -> None:
    """
    Print the type of every value in a dictionary, recursing into sub-dicts.

    Handy when generating tests or experimenting in a notebook to see what
    is stored under each key. NumPy arrays additionally show their dtype,
    shape and the type of their first element; empty and zero-dimensional
    arrays are handled without raising.

    Parameters
    ----------
    dictionary : dict
        The dictionary to print the types of
    dict_name : str
        The name of the dict, used as the prefix of each printed line
    indent_level : int
        Sets the indent levels for printing as it's a recursive function,
        by default 1
    """
    indent = indent_level * 2 * " "
    for key, value in dictionary.items():
        label = f"{dict_name}[{key}]"
        # NumPy arrays (including subclasses): show dtype, shape and element type
        if isinstance(value, np.ndarray):
            print(f"{label} : {value.dtype} {value.shape}")
            if value.ndim == 0:
                # A 0-d array holds one element but cannot be indexed with [0]
                first = value[()]
            elif value.shape[0] == 0:
                # Nothing to inspect along the first axis
                print(f"{indent}Inside Type: n/a (empty array)")
                continue
            else:
                first = value[0]
            print(f"{indent}Inside Type: {type(first)}")
            if isinstance(first, np.ndarray):
                print(f"{indent}NumPy Array Dtype: {first.dtype}")
        # Recursively call the function on another sub dict
        elif isinstance(value, dict):
            print(f"{label} : {type(value)}")
            print_types(value, dict_name=f" {key}", indent_level=indent_level + 2)
        # Only a bare object() instance matches here; print its value too
        elif type(value) is object:
            print(f"{label} : {type(value)}")
            print(value)
        # Otherwise just print the type
        else:
            print(f"{label} : {type(value)}")