Commit a4bb87c1 authored by Zhang, Chen's avatar Zhang, Chen
Browse files

Merge branch 'add-new-r-curve-generator' into 'main'

add dataloader and r curve generator

See merge request !9
parents fc726995 d3f5a157
Loading
Loading
Loading
Loading
+58 −0
Original line number Diff line number Diff line
{
    "parameters": {
        "q_min": 0.008,
        "q_max": 0.2,
        "step_size": 0.015,
        "FINAL_SEI_SLD": 6.13,
        "SIMULATE_ERRORS": true,
        "ERR": 0.15,
        "total_time": 50,
        "transition_time": 3,
        "transition_midpoint": 15
    },
    "materials": [
        {
            "name": "Si",
            "rho": 2.07,
            "irho": 0,
            "thickness": 0,
            "roughness": 0
        },
        {
            "name": "THF",
            "rho": 6.13,
            "irho": 0,
            "thickness": 0,
            "roughness": 43.77
        },
        {
            "name": "Ti",
            "rho": -1.238,
            "irho": 0,
            "thickness": 52.91,
            "roughness": 12.7
        },
        {
            "name": "Cu",
            "rho": 6.446,
            "irho": 0,
            "thickness": 566.1,
            "roughness": 9.736
        },
        {
            "name": "material",
            "rho": -1.648,
            "irho": 0,
            "thickness": 21.73,
            "roughness": 18.22
        },
        {
            "name": "SEI",
            "rho": 4.581,
            "irho": 0,
            "thickness": 177.7,
            "roughness": 23.04
        }
    ]

}
+97 −0
Original line number Diff line number Diff line
"""Config parser module for synthetic data generation."""
from refl1d.names import (
    Slab,
    SLD,
)


class DataLoader:
    """Load config and make sample from config dict."""

    default_dict = {
        "parameters": {
            "q_min": 0.008,
            "q_max": 0.2,
            "step_size": 0.015,
            "FINAL_SEI_SLD": 6.13,
            "SIMULATE_ERRORS": True,
            "ERR": 0.15,
            "total_time": 50,
            "transition_time": 3,
            "transition_midpoint": 15,
        }
    }

    def __init__(self, input_dict: dict):
        """Initialize the class.

        Parameters
        ----------
        input_dict : dictionary
            Dictionary of all parameters and materials to stack in a sample.
        """

        self.param = input_dict.get("parameters", self.default_dict["parameters"])

        if self.param:
            self.q_min = self.param.get("q_min", self.default_dict["parameters"]["q_min"])
            self.q_max = self.param.get("q_max", self.default_dict["parameters"]["q_max"])
            self.step_size = self.param.get("step_size", self.default_dict["parameters"]["step_size"])
            self.final_sei_sld = self.param.get("FINAL_SEI_SLD", self.default_dict["parameters"]["FINAL_SEI_SLD"])
            self.simulate_errors = self.param.get("SIMULATE_ERRORS", self.default_dict["parameters"]["SIMULATE_ERRORS"])
            self.err = self.param.get("ERR", self.default_dict["parameters"]["ERR"])
            self.total_time = self.param.get("total_time", self.default_dict["parameters"]["total_time"])
            self.transition_time = self.param.get("transition_time", self.default_dict["parameters"]["transition_time"])
            self.transition_midpoint = self.param.get(
                "transition_midpoint", self.default_dict["parameters"]["transition_midpoint"]
            )

        self.sample = create_sample_from_dict(input_dict.get("materials"))


def create_slab_chunk(material: dict) -> "Slab":
    """Return a slab chunk built from given configuration.

    Parameters
    ----------
    material : dictionary
        Single material dictionary to turn into slab chunk.

    Returns
    -------
    Slab
        The slab chunk instance.
    """

    slab = Slab(
        material=SLD(
            name=material["name"],
            rho=material["rho"],
            irho=material["irho"],
        ),
        thickness=material["thickness"],
        interface=material["roughness"],
    )

    return slab


def create_sample_from_dict(input_dict: list) -> Slab:
    """Return a slab instance built from given configuration.

    Parameters
    ----------
    json_object : list
        The sample configuration dictionary list.

    Returns
    -------
    Slab
        The slab instance.
    """

    sample = create_slab_chunk(input_dict[0])
    for i in range(1, len(input_dict)):
        sample = sample | create_slab_chunk(input_dict[i])

    return sample
+222 −0
Original line number Diff line number Diff line
import json
import math
import numpy as np
from refl1d.names import QProbe, Experiment
from tgreft.utils.data.dataloader import DataLoader


class RCurveGenerator:
    def __init__(self, data: DataLoader):
        """Initialize the class.

        Parameters
        ----------
        data : DataLoader object
            Object of parsed data to use with RCurveGenerator and related functions
        """

        self.data = data
        self.probe, self.tr_probe = create_probes(data.q_min, data.q_max, data.step_size)
        self.expt = Experiment(probe=self.probe, sample=self.data.sample)
        self._r_final, self._r_init = get_states(
            self.expt, self.data.final_sei_sld, self.data.simulate_errors, self.data.err
        )
        self.tr_expt = Experiment(probe=self.tr_probe, sample=self.data.sample)
        self._tr_data, self._rho_data = generate_tNR(
            self.tr_expt,
            self.data.total_time,
            self.data.transition_time,
            self.data.transition_midpoint,
            self.data.final_sei_sld,
            self.data.simulate_errors,
            self.data.err,
        )

    @property
    def r_final(self) -> list:
        """Returns a list from the RCurveGenerator object instance.

        Returns
        -------
        list
            r_final list
        """

        return self._r_final

    @property
    def r_init(self) -> list:
        """Returns a list from the object instance.

        Returns
        -------
        list
            r_init list
        """

        return self._r_init

    @property
    def tr_data(self) -> list:
        """Returns a list from the object instance.

        Returns
        -------
        list
            tr_data list
        """

        return self._tr_data

    @property
    def rho_data(self) -> list:
        """Returns a list from the object instance.

        Returns
        -------
        list
            rho_data list
        """

        return self._rho_data


def create_probes(q_min: float, q_max: float, step_size: float, resolution: float = 0.028) -> tuple:
    """Returns two probes, one using the full data and one with a slice, in the form of a tuple.

    Parameters
    ----------
    q_min : float
        minimum of the q range

    q_max : float
        maximum of the q range

    res : float
        step size between min and max

    Returns
    -------
    tuple
        tuple containing both relevant probes
    """

    q = np.arange(np.log(q_min), np.log(q_max), step_size)
    q = np.exp(q)

    dq = resolution * q

    # Time-resolved Q range
    q_tr = q[50:130]
    dq_tr = dq[50:130]

    probe = QProbe(q, dq, data=None)

    # If using data for fits, replace data by the data to fit:
    # probe = QProbe(q, dq, data=(data, errors))

    tr_probe = QProbe(q_tr, dq_tr, data=None)

    return probe, tr_probe


def get_states(expt: Experiment, final_sei_sld: float, simulate_errors: bool, err: float) -> tuple:
    """Returns a tuple containing both the initial and final r states.

    Parameters
    ----------
    expt : Experiment
        Experiment object - full range
    final_sei_sld : float
        final sei sld contant
    simulate_errors : boolean
        boolean indicating whether you want to include error or not in simulation
    err : float
        error constant

    Returns
    -------
    tuple
        Tuple of both r_final and r_init lists
    """

    # Final state ##################################################################
    _, r_final = expt.reflectivity()
    if simulate_errors:
        r_final = np.random.normal(r_final, err * r_final)

    # Initial state ################################################################
    expt.sample["SEI"].material.rho.value = final_sei_sld
    expt.update()
    _, r_init = expt.reflectivity()
    if simulate_errors:
        r_init = np.random.normal(r_init, err * r_init)

    return r_final, r_init


def generate_tNR(
    tr_expt: Experiment,
    total_time: int,
    transition_time: int,
    transition_midpoint: int,
    final_sei_sld: float,
    simulate_errors: bool,
    err: float,
) -> tuple:
    """Return a tuple built containing tr_data and rho_data.

    Parameters
    ----------
    tr_expt : Experiment
        Subset of experiment
    total_time : int
        total time to run
    transition_time : int
        time to transition
    transition_midpoint : int
        midpoint of transition
    final_sei_sld : float
        final sei sld constant
    simulate_errors : bool
        boolean indicating whether you want to include error or not in simulation
    err : float
        error constant

    Returns
    -------
    tuple
        tuple of both tr_data and rho_data lists
    """

    tr_data = []
    rho_data = []
    for i in range(total_time):
        # This is a simple ERF transition, but it could be any function!
        rho_sei = (
            final_sei_sld - (final_sei_sld - 4.581) * (1 + math.erf((i - transition_midpoint) / transition_time)) / 2
        )

        tr_expt.sample["SEI"].material.rho.value = rho_sei
        tr_expt.update()
        _, _r = tr_expt.reflectivity()
        if simulate_errors:
            _r = np.random.normal(_r, err * _r)

        rho_data.append(rho_sei)
        tr_data.append(_r)

    tr_data = np.asarray(tr_data)
    return tr_data, rho_data


if __name__ == "__main__":
    import sys

    json_file_name = sys.argv[1]

    f = open(json_file_name, "r")
    data = json.load(f)
    dataloader = DataLoader(data)
    rc = RCurveGenerator(dataloader)
    print(rc.r_init, rc.r_final, rc.tr_data, rc.rho_data)