Commit 53adc8d2 authored by Zhang, Chen's avatar Zhang, Chen
Browse files

add sld cmp plot during chk pt

parent 21a49028
Loading
Loading
Loading
Loading
+18 −10
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ import numpy as np
import torch
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from tgreft.utils.visualization import params_cmp_heatmap, plot_rcurve_cmp
from tgreft.utils.visualization import params_cmp_heatmap, plot_rcurve_cmp, plot_sld_cmp
from tgreft.nn.loss import CompositeLoss


@@ -125,15 +125,15 @@ def visualize_single_epoch(
        "sei_sld",
        "sei_thickness",
        "sei_roughness",
        "material_sld",
        "material_thickness",
        "material_roughness",
        "cu_sld",
        "cu_thickness",
        "cu_roughness",
        "ti_sld",
        "ti_thickness",
        "ti_roughness",
        "bulk1_sld",
        "bulk1_thickness",
        "bulk1_roughness",
        "bulk2_sld",
        "bulk2_thickness",
        "bulk2_roughness",
        "bulk3_sld",
        "bulk3_thickness",
        "bulk3_roughness",
        "oxide_sld",
        "oxide_thickness",
        "oxide_roughness",
@@ -155,6 +155,14 @@ def visualize_single_epoch(
        epoch=epoch,
        device=str(device),
    )
    # sld
    plot_sld_cmp(
        params_pred=preds,
        params_ref=refs,
        mlflow_log=True,
        epoch=epoch,
        device=str(device),
    )


def get_optimizer(
+33 −3
Original line number Diff line number Diff line
@@ -138,15 +138,45 @@ def param_to_rcurve(param: np.ndarray) -> np.ndarray:
    config = [
        {"name": "electrolyte", "sld": param[0], "isld": 0, "thickness": 0, "roughness": param[1]},
        {"name": "SEI", "sld": param[2], "isld": 0, "thickness": param[3], "roughness": param[4]},
        {"name": "material", "sld": param[5], "isld": 0, "thickness": param[6], "roughness": param[7]},
        {"name": "Cu", "sld": param[8], "isld": 0, "thickness": param[9], "roughness": param[10]},
        {"name": "Ti", "sld": param[11], "isld": 0, "thickness": param[12], "roughness": param[13]},
        {"name": "bulk3", "sld": param[5], "isld": 0, "thickness": param[6], "roughness": param[7]},
        {"name": "bulk2(Cu)", "sld": param[8], "isld": 0, "thickness": param[9], "roughness": param[10]},
        {"name": "bulk1", "sld": param[11], "isld": 0, "thickness": param[12], "roughness": param[13]},
        {"name": "oxide", "sld": param[14], "isld": 0, "thickness": param[15], "roughness": param[16]},
        {"name": "substrate", "sld": 2.07, "isld": 0, "thickness": 0, "roughness": 0},
    ]
    return r_generator(config)


def param_to_sld(param: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Convert the parameters to the SLD profile.
    
    Parameters
    ----------
    param : np.ndarray
        The parameters.
    
    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The SLD profile.
    """
    r_generator = RCurveGenerator()
    config = [
        {"name": "electrolyte", "sld": param[0], "isld": 0, "thickness": 0, "roughness": param[1]},
        {"name": "SEI", "sld": param[2], "isld": 0, "thickness": param[3], "roughness": param[4]},
        {"name": "bulk3", "sld": param[5], "isld": 0, "thickness": param[6], "roughness": param[7]},
        {"name": "bulk2(Cu)", "sld": param[8], "isld": 0, "thickness": param[9], "roughness": param[10]},
        {"name": "bulk1", "sld": param[11], "isld": 0, "thickness": param[12], "roughness": param[13]},
        {"name": "oxide", "sld": param[14], "isld": 0, "thickness": param[15], "roughness": param[16]},
        {"name": "substrate", "sld": 2.07, "isld": 0, "thickness": 0, "roughness": 0},
    ]
    experiment = r_generator.build_experiment_from_config(config)
    # compute the SLD profile
    z, sld, _ = experiment.smooth_profile()

    return z, sld


if __name__ == "__main__":
    # example usage
    dataset = get_dataset(n_dataset=10, error=0.05)
+20 −10
Original line number Diff line number Diff line
@@ -51,19 +51,29 @@ class RCurveGenerator:
        if not new_sample_config:
            return self.experiment.reflectivity()

        # build the sample
        sample = self.build_sample_from_config(new_sample_config)
        experiment = self.build_experiment_from_config(new_sample_config)
        _, r_curve = experiment.reflectivity()

        # get the probe
        probe = self.get_prob()
        return r_curve

        # get the experiment
        experiment = Experiment(probe=probe, sample=sample)
    @staticmethod
    def build_experiment_from_config(sample_config: list) -> Experiment:
        """Build an experiment instance from given sample configuration.
        
        # calculate the reflectivity curve
        _, r_curve = experiment.reflectivity()
        Parameters
        ----------
        sample_config : list
            The sample configuration dictionary.
        
        return r_curve
        Returns
        -------
        Experiment
            The experiment instance.
        """
        sample = RCurveGenerator.build_sample_from_config(sample_config)
        probe = RCurveGenerator.get_prob()
        experiment = Experiment(probe=probe, sample=sample)
        return experiment

    @staticmethod
    def build_sample_from_config(
+61 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@
import numpy as np
import matplotlib.pyplot as plt
import mlflow
from tgreft.utils.data.data_loader import param_to_rcurve
from tgreft.utils.data.data_loader import param_to_rcurve, param_to_sld


def params_cmp_heatmap(
@@ -123,7 +123,10 @@ def plot_rcurve_cmp(
                ax[i, j].plot(r_curve_pred, label="pred")
                ax[i, j].plot(r_curve_ref, label="ref")
                ax[i, j].legend()
                ax[i, j].set_xscale("log")
                ax[i, j].set_yscale("log")
                ax[i, j].set_xlabel("q [$\AA^{-1}$]")
                ax[i, j].set_ylabel("R [q]")
    fig.tight_layout()

    if mlflow_log:
@@ -132,3 +135,60 @@ def plot_rcurve_cmp(
        plt.close(fig)
    else:
        plt.show()


def plot_sld_cmp(
    params_pred: np.ndarray,
    params_ref: np.ndarray,
    n_select: int = 16,
    mlflow_log: bool = True,
    epoch: int = 0,
    device: str = "default",
):
    """Plot the SLD profile to evaluate the model.
    
    Parameters
    ----------
    params_pred : np.ndarray
        The predicted parameters.
    params_ref : np.ndarray
        The reference parameters.
    n_select : int, optional
        The number of curves to plot, by default 5.
    mlflow_log : bool, optional
        Whether to log the figure to mlflow, by default True.
    epoch : int, optional
        The epoch number, by default 0.
    device : str, optional
        The device used for training.
    """
    # randomly select n_select curves
    indices = np.random.choice(params_pred.shape[0], n_select, replace=False)
    # based on n_select, determine the number of rows and columns
    num_rows = int(np.sqrt(n_select))
    num_cols = int(np.ceil(n_select / num_rows))
    # create the figure
    fig, ax = plt.subplots(num_rows, num_cols, figsize=(4 * num_cols, 4 * num_rows))
    # plot
    for i in range(num_rows):
        for j in range(num_cols):
            index = i * num_cols + j
            if index < n_select:
                param_pred = params_pred[indices[index], :]
                z_pred, sld_pred = param_to_sld(param_pred)
                param_ref = params_ref[indices[index], :]
                z_ref, sld_ref = param_to_sld(param_ref)
                ax[i, j].plot(z_pred, sld_pred, label="pred")
                ax[i, j].plot(z_ref, sld_ref, label="ref")
                ax[i, j].legend()
                ax[i, j].set_xlabel("z [$\AA$]")
                ax[i, j].set_ylabel("SLD [$\AA^{-2}$]")
                pass
    fig.tight_layout()

    if mlflow_log:
        mlflow.log_figure(fig, f"sld_cmp/{device}/epoch_{epoch}.png")
        # close figure when running in the back
        plt.close(fig)
    else:
        plt.show()