Commit 78c8a6af authored by Zhang, Chen's avatar Zhang, Chen
Browse files

fix heatmap issue

parent 525ba2dd
Loading
Loading
Loading
Loading
+11 −5
Original line number Diff line number Diff line
@@ -43,10 +43,16 @@ def params_cmp_heatmap(
        for j in range(num_cols):
            param_index = i * num_cols + j
            if param_index < num_params:
                pred = params_pred[:, param_index]
                ref = params_ref[:, param_index]
                # use reference min and max to make sure the color scale is the same
                vmin = ref.min()
                vmax = ref.max()
                hist, xedges, yedges = np.histogram2d(
                    params_pred[:, param_index],
                    params_ref[:, param_index],
                    pred,
                    ref,
                    bins=bins,
                    range=[[vmin, vmax], [vmin, vmax]],
                )
                ax[i, j].imshow(
                    hist.T,
@@ -55,9 +61,9 @@ def params_cmp_heatmap(
                    aspect="auto",
                    origin="lower",
                )
                # make sure the x and y axis have the same scale
                ax[i, j].set_xlim(yedges[0], yedges[-1])
                ax[i, j].set_ylim(yedges[0], yedges[-1])
                # set the ticks to []
                ax[i, j].set_xticks([])
                ax[i, j].set_yticks([])
                if len(labels) > 0:
                    ax[i, j].set_title(labels[param_index])
    fig.tight_layout()