Loading src/tgreft/utils/visualization.py +11 −5 Original line number Diff line number Diff line Loading @@ -43,10 +43,16 @@ def params_cmp_heatmap( for j in range(num_cols): param_index = i * num_cols + j if param_index < num_params: pred = params_pred[:, param_index] ref = params_ref[:, param_index] # use reference min and max to make sure the color scale is the same vmin = ref.min() vmax = ref.max() hist, xedges, yedges = np.histogram2d( params_pred[:, param_index], params_ref[:, param_index], pred, ref, bins=bins, range=[[vmin, vmax], [vmin, vmax]], ) ax[i, j].imshow( hist.T, Loading @@ -55,9 +61,9 @@ def params_cmp_heatmap( aspect="auto", origin="lower", ) # make sure the x and y axis have the same scale ax[i, j].set_xlim(yedges[0], yedges[-1]) ax[i, j].set_ylim(yedges[0], yedges[-1]) # set the ticks to [] ax[i, j].set_xticks([]) ax[i, j].set_yticks([]) if len(labels) > 0: ax[i, j].set_title(labels[param_index]) fig.tight_layout() Loading Loading
src/tgreft/utils/visualization.py +11 −5 Original line number Diff line number Diff line Loading @@ -43,10 +43,16 @@ def params_cmp_heatmap( for j in range(num_cols): param_index = i * num_cols + j if param_index < num_params: pred = params_pred[:, param_index] ref = params_ref[:, param_index] # use reference min and max to make sure the color scale is the same vmin = ref.min() vmax = ref.max() hist, xedges, yedges = np.histogram2d( params_pred[:, param_index], params_ref[:, param_index], pred, ref, bins=bins, range=[[vmin, vmax], [vmin, vmax]], ) ax[i, j].imshow( hist.T, Loading @@ -55,9 +61,9 @@ def params_cmp_heatmap( aspect="auto", origin="lower", ) # make sure the x and y axis have the same scale ax[i, j].set_xlim(yedges[0], yedges[-1]) ax[i, j].set_ylim(yedges[0], yedges[-1]) # set the ticks to [] ax[i, j].set_xticks([]) ax[i, j].set_yticks([]) if len(labels) > 0: ax[i, j].set_title(labels[param_index]) fig.tight_layout() Loading