Loading src/tgreft/utils/visualization.py +3 −0 Original line number Diff line number Diff line Loading @@ -45,6 +45,9 @@ def params_cmp_heatmap( if param_index < num_params: pred = params_pred[:, param_index] ref = params_ref[:, param_index] # convert nan to 0 pred = np.nan_to_num(pred) ref = np.nan_to_num(ref) # use reference min and max to make sure the color scale is the same vmin = ref.min() vmax = ref.max() Loading Loading
src/tgreft/utils/visualization.py +3 −0 Original line number Diff line number Diff line Loading @@ -45,6 +45,9 @@ def params_cmp_heatmap( if param_index < num_params: pred = params_pred[:, param_index] ref = params_ref[:, param_index] # convert nan to 0 pred = np.nan_to_num(pred) ref = np.nan_to_num(ref) # use reference min and max to make sure the color scale is the same vmin = ref.min() vmax = ref.max() Loading