Commit f96d0dd2 authored by Zhang, Chen's avatar Zhang, Chen
Browse files

add helper function to facilitate model eval

parent 53adc8d2
Loading
Loading
Loading
Loading
+170 −0

File added.

Preview size limit exceeded, changes collapsed.

+0 −0

Empty file added.

+275 −0
Original line number Diff line number Diff line
"""Functions used to evaluate the performance of the model."""
#!/usr/bin/env python
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional
from refl1d.names import Experiment, QProbe, Parameter
from refl1d.names import FitProblem
from bumps.fitters import fit
from tgreft.utils.data.data_synthesis import RCurveGenerator
from tgreft.utils.data.data_loader import param_to_rcurve
from tgreft.analysis.utils import interpolate_data, load_csv


def parameters_refine(
    param: np.ndarray,
    r_curve: np.ndarray,
    q_range: Optional[np.ndarray] = None,
    dq: Optional[np.ndarray] = None,
    error: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Refine the parameters with fitting.

    Parameters
    ----------
    param : np.ndarray
        The parameters to refine.
        0: electrolyte_sld
        1: electrolyte_roughness
        2: sei_sld
        3: sei_thickness
        4: sei_roughness
        5: material_sld
        6: material_thickness
        7: material_roughness
        8: cu_sld
        9: cu_thickness
        10: cu_roughness
        11: ti_sld
        12: ti_thickness
        13: ti_roughness
        14: oxide_sld
        15: oxide_thickness
        16: oxide_roughness
    r_curve : np.ndarray
        The r_curve to fit.
    q_range : Optional[np.ndarray], optional
        The q values matching r_curve; defaults to 150 log-spaced points
        between 0.009 and 0.18 (1/A), matching the synthetic data range.
    dq : Optional[np.ndarray], optional
        The q resolution per point; defaults to a constant 2.5% FWHM
        resolution converted to sigma (divide by 2.355).
    error : Optional[np.ndarray], optional
        The error of the r_curve, by default None (7% of r_curve).

    Returns
    -------
    np.ndarray
        The refined parameters.
    np.ndarray
        The z data from guess sample.
    np.ndarray
        The sld data from guess sample.
    np.ndarray
        The z data from fitted sample.
    np.ndarray
        The sld data from fitted sample.
    """
    # create a fresh generator
    r_generator = RCurveGenerator()

    # build the sample from the parameters (layer order: top to bottom)
    config = [
        {"name": "electrolyte", "sld": param[0], "isld": 0, "thickness": 0, "roughness": param[1]},
        {"name": "SEI", "sld": param[2], "isld": 0, "thickness": param[3], "roughness": param[4]},
        {"name": "material", "sld": param[5], "isld": 0, "thickness": param[6], "roughness": param[7]},
        {"name": "Cu", "sld": param[8], "isld": 0, "thickness": param[9], "roughness": param[10]},
        {"name": "Ti", "sld": param[11], "isld": 0, "thickness": param[12], "roughness": param[13]},
        {"name": "oxide", "sld": param[14], "isld": 0, "thickness": param[15], "roughness": param[16]},
        {"name": "substrate", "sld": 2.07, "isld": 0, "thickness": 0, "roughness": 0},
    ]
    sample = r_generator.build_sample_from_config(config)

    # set parameter refine ranges (bounds the fitter may explore)
    sample["electrolyte"].material.rho.range(0, 7)
    sample["electrolyte"].interface.range(5, 300)

    sample["SEI"].material.rho.range(-3, 10)
    sample["SEI"].interface.range(5, 100)
    sample["SEI"].thickness.range(10, 1000)

    sample["material"].material.rho.range(-3, 10)
    sample["material"].interface.range(5, 50)
    sample["material"].thickness.range(10, 1000)

    sample["Cu"].material.rho.range(-3, 10)
    sample["Cu"].interface.range(0, 50)
    sample["Cu"].thickness.range(10, 1000)

    sample["Ti"].material.rho.range(-3, 3)
    sample["Ti"].interface.range(0, 50)
    sample["Ti"].thickness.range(10, 1000)

    sample["oxide"].material.rho.range(0, 5)
    sample["oxide"].interface.range(0, 25)
    sample["oxide"].thickness.range(0, 60)

    data = r_curve

    # make a probe; fill in defaults for any missing measurement metadata
    if q_range is None:
        q_range = np.logspace(
            np.log10(0.009),
            np.log10(0.18),
            num=150,
        )

    if dq is None:
        q_resolution = 0.025
        # convert FWHM resolution to Gaussian sigma
        dq = q_resolution * q_range / 2.355

    if error is None:
        error = r_curve * 0.07  # 7% error, matching the generated synthetic data

    probe = QProbe(q_range, dq, data=(data, error))
    probe.background = Parameter(value=0.0, name="background")

    # make an experiment
    experiment = Experiment(probe=probe, sample=sample)

    # get the sld profile from guessed sample (before refinement)
    z_guessed, sld_guessed, _ = experiment.smooth_profile()

    # refine
    problem = FitProblem(experiment)

    # Try amoeba first (faster)
    try:
        results = fit(problem, method="amoeba", steps=10000, verbose=False)
    except Exception as exc:  # do not shadow the `error` parameter
        print(exc)
        print("!!Switch to dream!!")
        print(f"--> guess: {param}")
        # NOTE(review): the dream results are discarded by the re-raise below;
        # the fallback fit only serves as a diagnostic run — confirm intent.
        results = fit(problem, method="dream", steps=1000, burn=1000, pop=20, verbose=False)
        # re-throw the error so that the caller can handle it,
        # preserving the original traceback
        raise

    # get the sld profile along z (after refinement)
    z_fitted, sld_fitted, _ = experiment.smooth_profile()

    # Results are in the wrong order. It's interface, rho, thickness...
    fit_pars = [results.x[1], results.x[0]]
    n_layers = int((len(results.x) - 1) / 3)
    for i in range(n_layers):
        fit_pars.extend([results.x[3 * i + 3], results.x[3 * i + 4], results.x[3 * i + 2]])

    return np.array(fit_pars), z_guessed, sld_guessed, z_fitted, sld_fitted


def extract_parameters(
    csv_file: str,
    model: torch.nn.Module,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Extract the parameters from reflectometry data (CSV) via inference and refinement.

    Parameters
    ----------
    csv_file : str
        The path to the CSV file containing the data.
    model : torch.nn.Module
        The model to use for inference.

    Returns
    -------
    param : np.ndarray
        The refined parameters (None when inference produced NaN).
    guess : np.ndarray
        The parameters from inference.
    z_guessed : np.ndarray
        The z data from guess sample.
    sld_guessed : np.ndarray
        The sld data from guess sample.
    z_fitted : np.ndarray
        The z data from fitted sample.
    sld_fitted : np.ndarray
        The sld data from fitted sample.
    """
    # read the raw measurement, then resample onto the native q grid
    q, r, dr, dq = load_csv(csv_file)
    q_range, r_interpolated, dr_interpolated = interpolate_data(q, r, dr)

    # run inference on the interpolated curve
    curve = torch.tensor(r_interpolated, dtype=torch.float32)
    with torch.no_grad():
        # replace non-positive values with 1 so that they become 0 after log
        curve[curve <= 0] = 1
        guess = model(curve.unsqueeze(0)).squeeze(0).numpy()

    # bail out with None placeholders when inference yields NaN
    if np.isnan(guess).any():
        return None, guess, None, None, None, None

    # refine the parameters
    # NOTE: use the data from the CSV file (not the interpolated curve)
    param, z_guessed, sld_guessed, z_fitted, sld_fitted = parameters_refine(
        param=guess,
        r_curve=r,
        q_range=q,
        dq=dq,
        error=dr,
    )
    return param, guess, z_guessed, sld_guessed, z_fitted, sld_fitted


def evaluate(
    csv_file: str,
    model: torch.nn.Module,
    save: bool = False,
):
    """Evaluate the model on the data from the CSV file.

    Produces a two-panel figure: reflectivity (data vs guess vs refined)
    and the SLD profiles before/after refinement.

    Parameters
    ----------
    csv_file : str
        The path to the CSV file containing the data.
    model : torch.nn.Module
        The model to use for inference.
    save : bool, optional
        Whether to save the plot, by default False.
    """
    # run inference + refinement
    param, guess, z_guessed, sld_guessed, z_fitted, sld_fitted = extract_parameters(csv_file, model)

    # error handling
    # NOTE: we need to figure out why we are getting NaN here
    if param is None or np.isnan(param).any():
        print(f"For {csv_file}:\nparam: {param}\nguess: {guess}\n")
        return

    # reload the measurement and the native q grid for plotting
    q, r, dr, dq = load_csv(csv_file)
    q_range, r_interpolated, dr_interpolated = interpolate_data(q, r, dr)
    # simulate r curves from the guessed and refined parameters
    r_guess = param_to_rcurve(guess)
    r_refined = param_to_rcurve(param)

    # build the two-panel figure
    fig, (ax_refl, ax_sld) = plt.subplots(1, 2, figsize=(12, 4))
    data_kwargs = {"label": "data", "color": "blue", "fmt": "+", "alpha": 0.5}
    guess_kwargs = {"label": "guess", "color": "orange"}
    refined_kwargs = {"label": "refined", "color": "green"}

    # left panel: reflectivity curves on a log scale
    ax_refl.errorbar(q, r, yerr=dr, **data_kwargs)
    ax_refl.plot(q_range, r_guess, **guess_kwargs)
    ax_refl.plot(q_range, r_refined, **refined_kwargs)
    ax_refl.set_yscale("log")
    ax_refl.set_xlabel("q (1/A)")
    ax_refl.set_ylabel("r")
    ax_refl.legend()

    # right panel: SLD profiles
    ax_sld.plot(z_guessed, sld_guessed, **guess_kwargs)
    ax_sld.plot(z_fitted, sld_fitted, **refined_kwargs)
    ax_sld.set_xlabel("z")
    ax_sld.set_ylabel("sld")
    ax_sld.legend()

    fig.suptitle(csv_file)

    if save:
        plt.savefig(csv_file.replace(".txt", ".pdf"), dpi=120)
        plt.close()
    else:
        plt.show()


if __name__ == "__main__":
    # Library module: no standalone CLI behavior is defined yet.
    pass
+95 −0
Original line number Diff line number Diff line
"""Utility functions useful for evaluting the performance of a model."""
#!/usr/bin/env python
import numpy as np
from typing import Tuple


def load_csv(csv_file: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load the data from the CSV file and return the q, r, dr, and dq data as numpy arrays.

    Parameters
    ----------
    csv_file : str
        The path to the CSV file containing the data.

    Returns
    -------
    q : np.ndarray
        The q data (1/A).
    r : np.ndarray
        The r data.
    dr : np.ndarray
        The dr data (uncertainty on r).
    dq : np.ndarray
        The dq data (q resolution, 1/A).

    Examples
    --------
    >>> q, r, dr, dq = load_csv("data/IPTS-29196/REFL_201083_combined_data_auto.txt")
    """
    # load the data from the CSV file
    # NOTE: default arg is sufficient for the data format (whitespace-separated)
    data = np.loadtxt(csv_file)
    # separate into columns
    data_q = data[:, 0]  # 1/A
    data_r = data[:, 1]  # reflectivity
    data_dr = data[:, 2]  # uncertainty on r
    data_dq = data[:, 3]  # 1/A
    return data_q, data_r, data_dr, data_dq


def interpolate_data(q: np.ndarray, r: np.ndarray, dr: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Interpolate the data to the native q_range.

    The native grid is 150 log-spaced q points from 0.009 to 0.18 (1/A);
    r and dr are linearly interpolated onto it.

    Parameters
    ----------
    q : np.ndarray
        The q data.
    r : np.ndarray
        The r data.
    dr : np.ndarray
        The dr data.

    Returns
    -------
    q : np.ndarray
        The q data, interpolated to the native q_range.
    r : np.ndarray
        The r data, interpolated to the native q_range.
    dr : np.ndarray
        The dr data, interpolated to the native q_range.
    """
    # build the native log-spaced q grid
    q_min, q_max, n_points = 0.009, 0.18, 150
    q_range = np.logspace(np.log10(q_min), np.log10(q_max), num=n_points)
    # resample both the reflectivity and its uncertainty onto the grid
    r_on_grid = np.interp(q_range, q, r)
    dr_on_grid = np.interp(q_range, q, dr)
    return q_range, r_on_grid, dr_on_grid


def get_rcurve_from_csv(csv_file: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load the data from the CSV file and return the interpolated r and dr data as numpy arrays.

    The dq column of the file is ignored; the returned arrays are resampled
    onto the same native q range used for the simulated data.

    Parameters
    ----------
    csv_file : str
        The path to the CSV file containing the data.

    Returns
    -------
    r : np.ndarray
        The r data, interpolated to the same q range as the simulated data.
    dr : np.ndarray
        The dr (uncertainty) data, interpolated to the same q range as the simulated data.
    """
    # load the data from the CSV file (dq is not needed here)
    q, r, dr, _ = load_csv(csv_file)
    # interpolate the data to the native q_range; the grid itself is discarded
    _, r, dr = interpolate_data(q, r, dr)
    return r, dr


if __name__ == "__main__":
    # Library module: no standalone CLI behavior is defined yet.
    pass