Commit 54e74d51 authored by Zhang, Chen's avatar Zhang, Chen
Browse files

Merge branch 'add_post_processing' into 'main'

Add post processing

See merge request !5
parents 53adc8d2 2c8f39c0
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -91,3 +91,5 @@ data/IPTS-30384/REFL_206968_combined_data_auto.txt filter=lfs diff=lfs merge=lfs
data/IPTS-30384/REFL_207087_combined_data_auto.txt filter=lfs diff=lfs merge=lfs -text
data/IPTS-30384/REFL_207304_combined_data_auto.txt filter=lfs diff=lfs merge=lfs -text
data/IPTS-30384/REFL_207358_combined_data_auto.txt filter=lfs diff=lfs merge=lfs -text
models/model_gpt_d2048_h32_l6.pt filter=lfs diff=lfs merge=lfs -text
data/IPTS-29196 filter=lfs diff=lfs merge=lfs -text
+3 −0
Original line number Diff line number Diff line
version https://git-lfs.github.com/spec/v1
oid sha256:8cf1cfe569126426ea643696668d6f9fcf5379aa48413b6bd7cdf001249ee6b5
size 646837393
+170 −0

File added.

Preview size limit exceeded, changes collapsed.

+0 −0

Empty file added.

+275 −0
Original line number Diff line number Diff line
"""Functions used to evaluate the performance of the model."""
#!/usr/bin/env python
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional
from refl1d.names import Experiment, QProbe, Parameter
from refl1d.names import FitProblem
from bumps.fitters import fit
from tgreft.utils.data.data_synthesis import RCurveGenerator
from tgreft.utils.data.data_loader import param_to_rcurve
from tgreft.analysis.utils import interpolate_data, load_csv


def parameters_refine(
    param: np.ndarray,
    r_curve: np.ndarray,
    q_range: Optional[np.ndarray] = None,
    dq: Optional[np.ndarray] = None,
    error: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Refine the parameters with fitting.

    Parameters
    ----------
    param : np.ndarray
        The parameters to refine.
        0: electrolyte_sld
        1: electrolyte_roughness
        2: sei_sld
        3: sei_thickness
        4: sei_roughness
        5: material_sld
        6: material_thickness
        7: material_roughness
        8: cu_sld
        9: cu_thickness
        10: cu_roughness
        11: ti_sld
        12: ti_thickness
        13: ti_roughness
        14: oxide_sld
        15: oxide_thickness
        16: oxide_roughness
    r_curve : np.ndarray
        The r_curve to fit.
    q_range : Optional[np.ndarray], optional
        The q values of ``r_curve``; by default a 150-point log-spaced grid
        over [0.009, 0.18] is used (the native synthetic-data grid).
    dq : Optional[np.ndarray], optional
        The q resolution; by default 2.5% dQ/Q, converted from FWHM to
        sigma via the 2.355 factor.
    error : Optional[np.ndarray], optional
        The error of the r_curve, by default None (7% of r_curve).

    Returns
    -------
    np.ndarray
        The refined parameters.
    np.ndarray
        The z data from guess sample.
    np.ndarray
        The sld data from guess sample.
    np.ndarray
        The z data from fitted sample.
    np.ndarray
        The sld data from fitted sample.
    """
    # create a fresh generator
    r_generator = RCurveGenerator()

    # build the sample from the parameters
    config = [
        {"name": "electrolyte", "sld": param[0], "isld": 0, "thickness": 0, "roughness": param[1]},
        {"name": "SEI", "sld": param[2], "isld": 0, "thickness": param[3], "roughness": param[4]},
        {"name": "material", "sld": param[5], "isld": 0, "thickness": param[6], "roughness": param[7]},
        {"name": "Cu", "sld": param[8], "isld": 0, "thickness": param[9], "roughness": param[10]},
        {"name": "Ti", "sld": param[11], "isld": 0, "thickness": param[12], "roughness": param[13]},
        {"name": "oxide", "sld": param[14], "isld": 0, "thickness": param[15], "roughness": param[16]},
        {"name": "substrate", "sld": 2.07, "isld": 0, "thickness": 0, "roughness": 0},
    ]
    sample = r_generator.build_sample_from_config(config)

    # set parameter refine ranges
    sample["electrolyte"].material.rho.range(0, 7)
    sample["electrolyte"].interface.range(5, 300)

    sample["SEI"].material.rho.range(-3, 10)
    sample["SEI"].interface.range(5, 100)
    sample["SEI"].thickness.range(10, 1000)

    sample["material"].material.rho.range(-3, 10)
    sample["material"].interface.range(5, 50)
    sample["material"].thickness.range(10, 1000)

    sample["Cu"].material.rho.range(-3, 10)
    sample["Cu"].interface.range(0, 50)
    sample["Cu"].thickness.range(10, 1000)

    sample["Ti"].material.rho.range(-3, 3)
    sample["Ti"].interface.range(0, 50)
    sample["Ti"].thickness.range(10, 1000)

    sample["oxide"].material.rho.range(0, 5)
    sample["oxide"].interface.range(0, 25)
    sample["oxide"].thickness.range(0, 60)

    data = r_curve

    # make a probe
    if q_range is None:
        q_range = np.logspace(
            np.log10(0.009),
            np.log10(0.18),
            num=150,
        )

    if dq is None:
        q_resolution = 0.025
        dq = q_resolution * q_range / 2.355

    if error is None:
        error = r_curve * 0.07  # 7% error, matching the generated synthetic data

    probe = QProbe(q_range, dq, data=(data, error))
    probe.background = Parameter(value=0.0, name="background")

    # make an experiment
    experiment = Experiment(probe=probe, sample=sample)

    # get the sld profile from guessed sample
    z_guessed, sld_guessed, _ = experiment.smooth_profile()

    # refine
    problem = FitProblem(experiment)

    # Try the fast local optimizer first; fall back to the slower but more
    # robust dream sampler when amoeba raises.
    # NOTE(fix): the handler previously re-raised unconditionally, discarding
    # the dream result it had just computed, and `except ... as error`
    # shadowed (and unbound) the `error` argument.  Now the fallback result
    # is used; an exception propagates only if dream itself fails.
    try:
        results = fit(problem, method="amoeba", steps=10000, verbose=False)
    except Exception as exc:  # renamed so the `error` parameter is not shadowed
        print(exc)
        print("!!Switch to dream!!")
        print(f"--> guess: {param}")
        results = fit(problem, method="dream", steps=1000, burn=1000, pop=20, verbose=False)

    # get the sld profile along z
    z_fitted, sld_fitted, _ = experiment.smooth_profile()

    # Results are in the wrong order. It's interface, rho, thickness...
    fit_pars = [results.x[1], results.x[0]]
    n_layers = int((len(results.x) - 1) / 3)
    for i in range(n_layers):
        fit_pars.extend([results.x[3 * i + 3], results.x[3 * i + 4], results.x[3 * i + 2]])

    return np.array(fit_pars), z_guessed, sld_guessed, z_fitted, sld_fitted


def extract_parameters(
    csv_file: str,
    model: torch.nn.Module,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Extract the parameters from reflectometry data (CSV) via inference and refinement.

    Parameters
    ----------
    csv_file : str
        The path to the CSV file containing the data.
    model : torch.nn.Module
        The model to use for inference.

    Returns
    -------
    param : np.ndarray
        The refined parameters (None if inference produced NaN).
    guess : np.ndarray
        The parameters from inference.
    z_guessed : np.ndarray
        The z data from guess sample.
    sld_guessed : np.ndarray
        The sld data from guess sample.
    z_fitted : np.ndarray
        The z data from fitted sample.
    sld_fitted : np.ndarray
        The sld data from fitted sample.
    """
    # load the data from the CSV file
    q, r, dr, dq = load_csv(csv_file)
    # interpolate the data to the native q_range expected by the model
    q_range, r_interpolated, dr_interpolated = interpolate_data(q, r, dr)
    # convert to torch tensors
    r_interpolated = torch.tensor(r_interpolated, dtype=torch.float32)
    # infer the parameters
    with torch.no_grad():
        # replace non-positive values with 1 so that they become 0 after log
        r_interpolated[r_interpolated <= 0] = 1
        guess = model(r_interpolated.unsqueeze(0)).squeeze(0).numpy()
    # if nan in the guess, report the filename and skip refinement
    # NOTE(fix): the old comment promised this print but the code never
    # emitted it, so NaN inferences failed silently here.
    if np.isnan(guess).any():
        print(f"NaN in inference result for {csv_file}; skipping refinement.")
        param = None
        z_guessed, sld_guessed, z_fitted, sld_fitted = None, None, None, None
    else:
        # refine the parameters
        # NOTE: use the data from the CSV file for refinement
        param, z_guessed, sld_guessed, z_fitted, sld_fitted = parameters_refine(
            param=guess,
            r_curve=r,
            q_range=q,
            dq=dq,
            error=dr,
        )
    return param, guess, z_guessed, sld_guessed, z_fitted, sld_fitted


def evaluate(
    csv_file: str,
    model: torch.nn.Module,
    save: bool = False,
):
    """Evaluate the model on the data from the CSV file.

    Runs inference and refinement, then draws a two-panel figure: the
    measured reflectivity against the guessed and refined model curves on
    the left, and the corresponding SLD profiles on the right.

    Parameters
    ----------
    csv_file : str
        The path to the CSV file containing the data.
    model : torch.nn.Module
        The model to use for inference.
    save : bool, optional
        Whether to save the plot, by default False.
    """
    # extract the parameters
    param, guess, z_guessed, sld_guessed, z_fitted, sld_fitted = extract_parameters(csv_file, model)

    # error handling
    # NOTE: we need to figure out why we are getting NaN here
    if param is None or np.isnan(param).any():
        print(f"For {csv_file}:\nparam: {param}\nguess: {guess}\n")
        return

    # reload the measured curve and its native q grid for plotting
    q, r, dr, dq = load_csv(csv_file)
    q_range, r_interpolated, dr_interpolated = interpolate_data(q, r, dr)
    # model curves for the guessed and refined parameter sets
    r_guess = param_to_rcurve(guess)
    r_refined = param_to_rcurve(param)

    # shared styles for the two model curves
    guess_style = {"label": "guess", "color": "orange"}
    refined_style = {"label": "refined", "color": "green"}

    fig, (ax_refl, ax_sld) = plt.subplots(1, 2, figsize=(12, 4))

    # left panel: reflectivity vs q (log scale)
    ax_refl.errorbar(q, r, yerr=dr, label="data", color="blue", fmt="+", alpha=0.5)
    ax_refl.plot(q_range, r_guess, **guess_style)
    ax_refl.plot(q_range, r_refined, **refined_style)
    ax_refl.set_yscale("log")
    ax_refl.set_xlabel("q (1/A)")
    ax_refl.set_ylabel("r")
    ax_refl.legend()

    # right panel: SLD profiles along z
    ax_sld.plot(z_guessed, sld_guessed, **guess_style)
    ax_sld.plot(z_fitted, sld_fitted, **refined_style)
    ax_sld.set_xlabel("z")
    ax_sld.set_ylabel("sld")
    ax_sld.legend()

    fig.suptitle(csv_file)

    if save:
        plt.savefig(csv_file.replace(".txt", ".pdf"), dpi=120)
        plt.close()
    else:
        plt.show()


if __name__ == "__main__":
    # Placeholder entry point: this module is intended to be imported and
    # driven by a caller that supplies a trained model; nothing runs directly.
    pass
Loading