Commit d3f5a157 authored by Zhang, Chen's avatar Zhang, Chen
Browse files

light refactor

parent ec1cff2a
Loading
Loading
Loading
Loading
+42 −38
Original line number Diff line number Diff line
"""Config parser module for synthetic data generation."""
from refl1d.names import (
    Slab,
    SLD,
@@ -5,7 +6,7 @@ from refl1d.names import (


class DataLoader:
    """Add description of the class here"""
    """Load config and make sample from config dict."""

    default_dict = {
        "parameters": {
@@ -31,6 +32,7 @@ class DataLoader:
        """

        self.param = input_dict.get("parameters", self.default_dict["parameters"])

        if self.param:
            self.q_min = self.param.get("q_min", self.default_dict["parameters"]["q_min"])
            self.q_max = self.param.get("q_max", self.default_dict["parameters"]["q_max"])
@@ -44,9 +46,10 @@ class DataLoader:
                "transition_midpoint", self.default_dict["parameters"]["transition_midpoint"]
            )

        self.sample = self.create_sample_from_dict(input_dict.get("materials"))
        self.sample = create_sample_from_dict(input_dict.get("materials"))


    def create_slab_chunk(self, material: dict) -> "Slab":
def create_slab_chunk(material: dict) -> "Slab":
    """Return a slab chunk built from given configuration.

    Parameters
@@ -72,7 +75,8 @@ class DataLoader:

    return slab

    def create_sample_from_dict(self, input_dict: list) -> Slab:

def create_sample_from_dict(input_dict: list) -> Slab:
    """Return a slab instance built from given configuration.

    Parameters
@@ -86,8 +90,8 @@ class DataLoader:
        The slab instance.
    """

        sample = self.create_slab_chunk(input_dict[0])
    sample = create_slab_chunk(input_dict[0])
    for i in range(1, len(input_dict)):
            sample = sample | self.create_slab_chunk(input_dict[i])
        sample = sample | create_slab_chunk(input_dict[i])

    return sample
+24 −22
Original line number Diff line number Diff line
import json

import numpy as np
import math

import numpy as np
from refl1d.names import QProbe, Experiment

from dataloader import DataLoader
from tgreft.utils.data.dataloader import DataLoader


class RCurveGenerator:
@@ -21,11 +18,11 @@ class RCurveGenerator:
        self.data = data
        self.probe, self.tr_probe = create_probes(data.q_min, data.q_max, data.step_size)
        self.expt = Experiment(probe=self.probe, sample=self.data.sample)
        self.r_final, self.r_init = get_states(
        self._r_final, self._r_init = get_states(
            self.expt, self.data.final_sei_sld, self.data.simulate_errors, self.data.err
        )
        self.tr_expt = Experiment(probe=self.tr_probe, sample=self.data.sample)
        self.tr_data, self.rho_data = generate_tNR(
        self._tr_data, self._rho_data = generate_tNR(
            self.tr_expt,
            self.data.total_time,
            self.data.transition_time,
@@ -36,7 +33,7 @@ class RCurveGenerator:
        )

    @property
    def get_r_final(self) -> list:
    def r_final(self) -> list:
        """Returns a list from the RCurveGenerator object instance.

        Returns
@@ -45,10 +42,10 @@ class RCurveGenerator:
            r_final list
        """

        return self.r_final
        return self._r_final

    @property
    def get_r_init(self) -> list:
    def r_init(self) -> list:
        """Returns a list from the object instance.

        Returns
@@ -57,10 +54,10 @@ class RCurveGenerator:
            r_init list
        """

        return self.r_init
        return self._r_init

    @property
    def get_tr_data(self) -> list:
    def tr_data(self) -> list:
        """Returns a list from the object instance.

        Returns
@@ -69,10 +66,10 @@ class RCurveGenerator:
            tr_data list
        """

        return self.tr_data
        return self._tr_data

    @property
    def get_rho_data(self) -> list:
    def rho_data(self) -> list:
        """Returns a list from the object instance.

        Returns
@@ -81,7 +78,7 @@ class RCurveGenerator:
            rho_data list
        """

        return self.rho_data
        return self._rho_data


def create_probes(q_min: float, q_max: float, step_size: float, resolution: float = 0.028) -> tuple:
@@ -174,16 +171,17 @@ def generate_tNR(
    tr_expt : Experiment
        Subset of experiment
    total_time : int

        total time to run
    transition_time : int

        time to transition
    transition_midpoint : int

        midpoint of transition
    final_sei_sld : float

        final sei sld constant
    simulate_errors : bool

        boolean indicating whether you want to include error or not in simulation
    err : float
        error constant

    Returns
    -------
@@ -213,8 +211,12 @@ def generate_tNR(


if __name__ == "__main__":
    f = open("data.json", "r")
    import sys

    json_file_name = sys.argv[1]

    f = open(json_file_name, "r")
    data = json.load(f)
    dataloader = DataLoader(data)
    rc = RCurveGenerator(dataloader)
    print(rc.get_r_init)
    print(rc.r_init, rc.r_final, rc.tr_data, rc.rho_data)