Commit 905c6bac authored by Eiffert, Brett's avatar Eiffert, Brett
Browse files

style and code improvements

parent f34f2592
Loading
Loading
Loading
Loading
+46 −24
Original line number Diff line number Diff line
from refl1d.names import *
from refl1d.names import (
    Slab,
    SLD,
)


class DataLoader:
    """Add description of the class here"""

    def __init__(self, json_dict: dict):
    default_dict = {
        "parameters": {
            "q_min": 0.008,
            "q_max": 0.2,
            "step_size": 0.015,
            "FINAL_SEI_SLD": 6.13,
            "SIMULATE_ERRORS": True,
            "ERR": 0.15,
            "total_time": 50,
            "transition_time": 3,
            "transition_midpoint": 15,
        }
    }

    def __init__(self, input_dict: dict):
        """Initialize the class.

        Parameters
        ----------
        json_dict : json dictionary
            json dictionary of all parameters and materials to stack in a sample.
        input_dict : dictionary
            Dictionary of all parameters and materials to stack in a sample.
        """
        self.json_dict = json_dict
        self.q_min = json_dict["parameters"]["q_min"]
        self.q_max = json_dict["parameters"]["q_max"]
        self.step_size = json_dict["parameters"]["step"]
        self.final_sei_sld = json_dict["parameters"]["FINAL_SEI_SLD"]
        self.simulate_errors = json_dict["parameters"]["SIMULATE_ERRORS"]
        self.err = json_dict["parameters"]["ERR"]
        self.total_time = json_dict["parameters"]["total_time"]
        self.transition_time = json_dict["parameters"]["transition_time"]
        self.transition_midpoint = json_dict["parameters"]["transition_midpoint"]
        self.sample = self.create_sample_from_json(json_dict["materials"])

    def create_slab_chunk(self, material : dict) -> "Slab":
        self.param = input_dict.get("parameters", self.default_dict["parameters"])
        if self.param:
            self.q_min = self.param.get("q_min", self.default_dict["parameters"]["q_min"])
            self.q_max = self.param.get("q_max", self.default_dict["parameters"]["q_max"])
            self.step_size = self.param.get("step_size", self.default_dict["parameters"]["step_size"])
            self.final_sei_sld = self.param.get("FINAL_SEI_SLD", self.default_dict["parameters"]["FINAL_SEI_SLD"])
            self.simulate_errors = self.param.get("SIMULATE_ERRORS", self.default_dict["parameters"]["SIMULATE_ERRORS"])
            self.err = self.param.get("ERR", self.default_dict["parameters"]["ERR"])
            self.total_time = self.param.get("total_time", self.default_dict["parameters"]["total_time"])
            self.transition_time = self.param.get("transition_time", self.default_dict["parameters"]["transition_time"])
            self.transition_midpoint = self.param.get(
                "transition_midpoint", self.default_dict["parameters"]["transition_midpoint"]
            )

        self.sample = self.create_sample_from_dict(input_dict.get("materials"))

    def create_slab_chunk(self, material: dict) -> "Slab":
        """Return a slab chunk built from given configuration.

        Parameters
@@ -49,8 +72,7 @@ class DataLoader:

        return slab

    def create_sample_from_json(self, json_object : list) -> "Slab | Stack":

    def create_sample_from_dict(self, input_dict: list) -> Slab:
        """Return a slab instance built from given configuration.

        Parameters
@@ -64,8 +86,8 @@ class DataLoader:
            The slab instance.
        """

        sample = self.create_slab_chunk(json_object[0])
        for i in range(1, len(json_object)):
            sample = sample | self.create_slab_chunk(json_object[i])  
        sample = self.create_slab_chunk(input_dict[0])
        for i in range(1, len(input_dict)):
            sample = sample | self.create_slab_chunk(input_dict[i])

        return sample
+152 −142
Original line number Diff line number Diff line
@@ -3,14 +3,13 @@ import json
import numpy as np
import math

from refl1d.names import *
from refl1d.names import QProbe, Experiment

from dataloader import DataLoader

class RCurveGenerator:

class RCurveGenerator:
    def __init__(self, data: DataLoader):

        """Initialize the class.

        Parameters
@@ -20,14 +19,72 @@ class RCurveGenerator:
        """

        self.data = data
        self.probe, self.tr_probe = self.create_probes(data.q_min, data.q_max, res=data.step)
        self.probe, self.tr_probe = create_probes(data.q_min, data.q_max, data.step_size)
        self.expt = Experiment(probe=self.probe, sample=self.data.sample)
        self.r_final, self.r_init = self.get_states(self.expt, self.data.final_sei_sld, self.data.simulate_errors, self.data.err)
        self.r_final, self.r_init = get_states(
            self.expt, self.data.final_sei_sld, self.data.simulate_errors, self.data.err
        )
        self.tr_expt = Experiment(probe=self.tr_probe, sample=self.data.sample)
        self.tr_data, self.rho_data = self.generate_tNR(self.tr_expt, self.data.total_time, self.data.transition_time, self.data.transition_midpoint, self.data.final_sei_sld, self.data.simulate_errors, self.data.err)
        self.tr_data, self.rho_data = generate_tNR(
            self.tr_expt,
            self.data.total_time,
            self.data.transition_time,
            self.data.transition_midpoint,
            self.data.final_sei_sld,
            self.data.simulate_errors,
            self.data.err,
        )

    @property
    def get_r_final(self) -> list:
        """Returns a list from the RCurveGenerator object instance.

        Returns
        -------
        list
            r_final list
        """

        return self.r_final

    def create_probes(self, q_min : float, q_max : float, step_size : float, resolution : float = 0.028) -> tuple:
    @property
    def get_r_init(self) -> list:
        """Returns a list from the object instance.

        Returns
        -------
        list
            r_init list
        """

        return self.r_init

    @property
    def get_tr_data(self) -> list:
        """Returns a list from the object instance.

        Returns
        -------
        list
            tr_data list
        """

        return self.tr_data

    @property
    def get_rho_data(self) -> list:
        """Returns a list from the object instance.

        Returns
        -------
        list
            rho_data list
        """

        return self.rho_data


def create_probes(q_min: float, q_max: float, step_size: float, resolution: float = 0.028) -> tuple:
    """Returns two probes, one using the full data and one with a slice, in the form of a tuple.

    Parameters
@@ -65,8 +122,8 @@ class RCurveGenerator:

    return probe, tr_probe

    def get_states(self, expt, final_sei_sld, simulate_errors, err) -> tuple:

def get_states(expt: Experiment, final_sei_sld: float, simulate_errors: bool, err: float) -> tuple:
    """Returns a tuple containing both the initial and final r states.

    Parameters
@@ -91,9 +148,8 @@ class RCurveGenerator:
    if simulate_errors:
        r_final = np.random.normal(r_final, err * r_final)


    # Initial state ################################################################
        expt.sample['SEI'].material.rho.value = final_sei_sld
    expt.sample["SEI"].material.rho.value = final_sei_sld
    expt.update()
    _, r_init = expt.reflectivity()
    if simulate_errors:
@@ -101,7 +157,16 @@ class RCurveGenerator:

    return r_final, r_init

    def generate_tNR(self, tr_expt, total_time, transition_time, transition_midpoint, final_sei_sld, simulate_errors, err) -> tuple:

def generate_tNR(
    tr_expt: Experiment,
    total_time: int,
    transition_time: int,
    transition_midpoint: int,
    final_sei_sld: float,
    simulate_errors: bool,
    err: float,
) -> tuple:
    """Return a tuple built containing tr_data and rho_data.

    Parameters
@@ -130,9 +195,11 @@ class RCurveGenerator:
    rho_data = []
    for i in range(total_time):
        # This is a simple ERF transition, but it could be any function!
            rho_sei = final_sei_sld - (final_sei_sld-4.581) * (1 + math.erf((i-transition_midpoint)/transition_time)) / 2
        rho_sei = (
            final_sei_sld - (final_sei_sld - 4.581) * (1 + math.erf((i - transition_midpoint) / transition_time)) / 2
        )

            tr_expt.sample['SEI'].material.rho.value = rho_sei
        tr_expt.sample["SEI"].material.rho.value = rho_sei
        tr_expt.update()
        _, _r = tr_expt.reflectivity()
        if simulate_errors:
@@ -144,67 +211,10 @@ class RCurveGenerator:
    tr_data = np.asarray(tr_data)
    return tr_data, rho_data

    def get_r_final(self) -> list:

        """Returns a list from the RCurveGenerator object instance.

        Returns
        -------
        list
            r_final list
        """

        return self.r_final

    def get_r_init(self) -> list:

        """Returns a list from the object instance.

        Returns
        -------
        list
            r_init list
        """

        return self.r_init

    def get_tr_data(self) -> list:
        
        """Returns a list from the object instance.

        Returns
        -------
        list
            tr_data list
        """

        return self.tr_data

    def get_rho_data(self) -> list:

        """Returns a list from the object instance.

        Returns
        -------
        list
            rho_data list
        """

        return self.rho_data
    
""" USE EXAMPLE
if __name__ == "__main__":
    f = open("data.json", "r")
    data = json.load(f)
    dataloader = DataLoader(data)
    rc = RCurveGenerator(dataloader)
    print(rc.get_r_init())
    
    #gen = RCurveGenerator(0.008, 0.2)
    #print(gen.q)
"""




    print(rc.get_r_init)