Commit 6dcb2322 authored by McDonnell, Marshall's avatar McDonnell, Marshall
Browse files

Adds multiple cif + scale on/off/factor

parent 26d72a0f
Loading
Loading
Loading
Loading
Loading

.gitignore

0 → 100644
+178 −0
Original line number Diff line number Diff line
.idea

# Created by https://www.toptal.com/developers/gitignore/api/python
# Edit at https://www.toptal.com/developers/gitignore?templates=python

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/python
+1 −2
Original line number Diff line number Diff line
@@ -10,7 +10,6 @@ This uses the RSE shared docker image for GSAS2 found in [this repo][rse_gsas2_i

Build the image
```
docker login code.ornl.gov:4567
docker build -t asrp-gsas2-refinement .
```

@@ -18,7 +17,7 @@ docker build -t asrp-gsas2-refinement .

Test:
```
docker run --entrypoint="" asrp-gsas2-refinement pytest
docker run --entrypoint="" asrp-gsas2-refinement pixi run pytest
```

To run (use your parameters):
+40 −14
Original line number Diff line number Diff line
@@ -5,20 +5,35 @@ from .run_gsas2_fit import run_gsas2_fit

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--cif-filename', help='Name of CIF file to load (*.cif)',type=str)
    parser.add_argument(
        '-c',
        '--cif-filenames',
        help='Name of CIF files to load (*.cif)',
        action="extend",
        nargs="+",
        type=str,
    )
    parser.add_argument('-f', '--gsas-filename', help='Name of gsas file to load (*.gsa) ',type=str)
    parser.add_argument('-i', '--instrument-params-filename', help='Name of instrument parameters file to load (*.prm)',type=str)
    parser.add_argument('-o', '--output-stem-name', help='Output stem name', type=str, default="gsas2_refinement")
    parser.add_argument('-p', '--output-directory', help='Output directory name', type=str, default="./portal")
    parser.add_argument('-s', '--scatter-type', help='Scatter type: ["N", "X"]', choices=["N", "X"], type=str)
    parser.add_argument('-b', '--bank-id', help='Index of the bank to use', type=str)
    parser.add_argument(
        '-b',
        '--bank-ids',
        help='Indices of the banks to use',
        action="extend",
        nargs="+",
        type=int,
    )
    parser.add_argument('-l', '--xmin', help='Xmin', type=str)
    parser.add_argument('-r', '--xmax', help='Xmax', type=str)
    parser.add_argument('-n', '--num-cycles', help='Number of refinement cycles', type=int)
    parser.add_argument('-v', '--initial-values', help='Initial values for refinement', type=str)
    parser.add_argument('--refine-scale-on', action='store_true')
    parser.add_argument('--refine-scale-off', action='store_true')
    parser.add_argument('--scale-factor', type=float, help="Set initial scale factor for phase")
    args = parser.parse_args()

    bank_id = int(args.bank_id)
    left_bound = float(args.xmin)
    right_bound = float(args.xmax)

@@ -28,17 +43,28 @@ if __name__ == "__main__":
        kwargs["num_cycles"] = args.num_cycles
    if args.initial_values:
        kwargs["init_vals"] = args.initial_values
    if args.refine_scale_on and args.refine_scale_off:
        raise Exception("Cannot have refine scale ON and OFF.")
    kwargs["refine_scale"] = True
    if args.refine_scale_on:
        kwargs["refine_scale"] = True
    if args.refine_scale_off:
        kwargs["refine_scale"] = False
    if args.scale_factor:
        kwargs["scale_factor"] = args.scale_factor

    print("refine scale:", args.refine_scale_off, kwargs["refine_scale"])

    # Run refinement
    run_gsas2_fit(
        args.cif_filename,
        args.gsas_filename,
        args.instrument_params_filename,
        args.output_stem_name,
        args.scatter_type,
        bank_id,
        left_bound,
        right_bound,
        args.output_directory,
    rw, x, y, ycalc, dy, bkg, cell_r, ref_list = run_gsas2_fit(
        structure_paths = args.cif_filenames,
        gsas_input_path=args.gsas_filename,
        instrument_params_path=args.instrument_params_filename,
        output_stem_fn=args.output_stem_name,
        banks=args.bank_ids,
        xmin=left_bound,
        xmax=right_bound,
        output_path=args.output_directory,
        **kwargs,
    )
+162 −69
Original line number Diff line number Diff line
import matplotlib.pyplot as plt
import os
import sys
import time
from typing import Union, Dict, Any

import GSASIIscriptable as G2sc
import numpy as np

# Callback functions
from GSASII import GSASIIscriptable as G2sc


def print_stats(gpx: G2sc.G2Project):
    """Print profile R-factors for all histograms, then every phase scale.

    Output, in order:
      1. a header line carrying the base name of ``gpx.filename``;
      2. one line per histogram with its name and ``get_wR()`` value;
      3. one line per (phase, histogram) pair with the scale value and its
         refinement flag;
      4. one ``Phase Scale`` line per (phase, histogram) pair showing the raw
         ``Scale`` entry;
      5. a blank line.

    Parameters
    ----------
    gpx
        Project whose histograms and phases are reported.
    """
    project_file = os.path.basename(gpx.filename)
    print("*** profile Rwp, " + project_file)
    for histogram in gpx.histograms():
        line = "\t{:20s}: {:.2f}".format(histogram.name, histogram.get_wR())
        print(line)

    # Detailed pass: each 'Scale' entry is a [value, refine_flag] pair.
    for phase in gpx.phases():
        label = phase['General']['Name']
        for hist_key, settings in phase['Histograms'].items():
            value, flag = settings['Scale']
            print(f"{label} - {hist_key}: scale = {value}, refine = {flag}")

    # Raw pass: dump the unparsed 'Scale' entry for each linked histogram.
    for phase in gpx.phases():
        linked = phase['Histograms']
        for hist_key in linked:
            print(f"Phase Scale: {linked[hist_key]['Scale']}")

    print("")

def plot_refinement(gpx: G2sc.G2Project, output_dir: str = "/app/portal"):
    """Save a PNG comparing observed, calculated, background and residual data.

    Only the LAST histogram returned by ``gpx.histograms()`` is plotted.
    NOTE(review): with several banks it may be preferable to write one plot
    per histogram -- confirm the intended behaviour.

    Parameters
    ----------
    gpx
        Project whose histogram data is plotted.
    output_dir
        Directory the image is written to. The file is named
        ``out_<unix timestamp>.png`` so successive calls do not overwrite
        each other.

    Returns
    -------
    None
        Nothing is written when the project has no histograms.
    """
    histograms = list(gpx.histograms())
    if not histograms:
        # Nothing to draw; without this guard the plotting code below would
        # fail on undefined data arrays.
        print("*** no histograms to plot")
        return

    hist = histograms[-1]
    x = np.array(hist.getdata("X"))
    y = np.array(hist.getdata("Yobs"))
    ycalc = np.array(hist.getdata("Ycalc"))
    dy = np.array(hist.getdata("Residual"))
    bkg = np.array(hist.getdata("Background"))

    filename = os.path.join(output_dir, f"out_{time.time()}.png")
    fig, ax = plt.subplots()
    try:
        ax.plot(x, y, label="exp", color='black', marker='x')
        ax.plot(x, ycalc, label="refinement", color='red')
        ax.plot(x, bkg, label="background", color='green')
        ax.plot(x, dy, label='residual', color='blue')
        ax.legend()
        fig.savefig(filename, dpi=300)
    finally:
        # This function runs once per refinement stage; pyplot keeps every
        # figure alive until it is closed, so release it explicitly.
        plt.close(fig)
    print(f"*** saved {filename}")

def callback(gpx: G2sc.G2Project):
    """Report statistics for ``gpx`` and then save a plot of the current fit.

    Used as the ``"call"`` entry of the refinement dictionaries.
    """
    for report in (print_stats, plot_refinement):
        report(gpx)

def run_gsas2_fit(
    structure_path,
    gsas_input_path,
    instrument_params_path,
    output_stem_fn,
    stype,
    bank,
    xmin,
    xmax,
    output_path,
    num_cycles=5,
    structure_paths: Union[ str, list[str] ],
    gsas_input_path: str,
    instrument_params_path: str,
    output_stem_fn: str,
    banks: Union[ int, list[int] ],
    xmin: float,
    xmax: float,
    output_path: str,
    refine_scale: bool = True,
    scale_factor: Union[None, float] = None,
    num_cycles: int = 5,
    init_vals: Union[None, Dict[str, Any]] = None,
):
    """
    Parameters
    ----------
    structure_path: str
        input structure cif filename.
    structure_path: str OR list[str]
        input structure cif filename(s).
    gsas_input_path: str
        input gsa filename.
    instrument_params_path: str
        input instrument profile filename.
    output_stem_fn: str
        output stem filename.
    stype: str
        scattering type
    banks: str
    banks: list[int]
        bank 1-6.
    xmin: float
        minimum x value
    xmax: float
        maximum x value
    refine_scale: bool
        turn on refinement of scale
    scale_factor: float
        initial scale factor
    output_path: str
        path to put output files
    num_cycles: int
@@ -51,49 +102,83 @@ def run_gsas2_fit(
        gsas2 .gpx project file
    """

    def HistStats(gpx):
        """prints profile rfactors for all histograms"""
        print("*** profile Rwp, " + os.path.split(gpx.filename)[1])
        for hist in gpx.histograms():
            print("\t{:20s}: {:.2f}".format(hist.name, hist.get_wR()))
        print("")

    print("INFO: Build GSAS-II Project File.")
    print("******************************")

    # start GSAS-II refinement
    # create a project file
    if os.path.exists(os.path.join(output_path, output_stem_fn + "_initial.gpx")):
        os.remove(os.path.join(output_path, output_stem_fn + "_initial.gpx"))
    gpx = G2sc.G2Project(newgpx=os.path.join(output_path, output_stem_fn + "_initial.gpx"))
    project_name = os.path.join(output_path, output_stem_fn + "_initial.gpx")

    # add six bank histograms to the project
    hists = []
    if stype == "N":
        # prmFile = "pdfitc/utils/NOMAD_2019B_Si_sixbanks_Shifter_instrument_file.prm"
        hist1 = gpx.add_powder_histogram(gsas_input_path, instrument_params_path, databank=bank, instbank=bank)
        hist1.set_refinements({"Limits": [xmin, xmax]})
    if stype == "X":
        # prmFile = "pdfitc/utils/PDFNSLSII.instprm"
        hist1 = gpx.add_powder_histogram(gsas_input_path, instrument_params_path)
        hist1.set_refinements({"Limits": [xmin, xmax]})
    if os.path.exists(project_name):
        os.remove(project_name)

    hists.append(hist1)
    gpx = G2sc.G2Project(newgpx = project_name)

    # step 2: add a phase and link it to the previous histograms
    _ = gpx.add_phase(structure_path, phasename="structure", fmthint="CIF", histograms=hists)
    print_stats(gpx)

    cell_i = gpx.phase("structure").get_cell()
    # step 1: add experiment data as histograms
    if isinstance(banks, int):
        banks = [banks]

    # step 3: increase # of cycles to improve convergence
    histograms = []
    for bank in banks:
        print(f"bank: {bank}")
        hist = gpx.add_powder_histogram(
            gsas_input_path,
            instrument_params_path,
        )
        histograms.append(hist)
    print(f"length of histograms: {len(histograms)}")

    # step 2: add phases and link it to the previous exp. data / histograms
    if isinstance(structure_paths, str):
        structure_paths = [structure_paths]

    phases = []
    for struct in structure_paths:
        phase = gpx.add_phase(
            struct,
            phasename="structure",
            fmthint="CIF",
            histograms=histograms,
        )
        phases.append(phase)

    # step 3: set the refinement engine parameters
    #   step 3.1 increase # of cycles to improve convergence
    gpx.data["Controls"]["data"]["max cyc"] = num_cycles

    # step 4: start refinement
    # refinement step 1: turn on  Histogram Scale factor
    refdict1 = {
        "set": {"Sample Parameters": ["Scale"]},
        "call": HistStats,
    }
    # step 4: create parameter limits to refine
    refine_params_list = []

    #   step 4.1: add xmin and xmax refinement limits
    for hist in histograms:
        default_limits = hist.data['Limits']
        xmin_default = default_limits[0]
        xmax_default = default_limits[1]

        if not xmin:
            xmin = xmin_default
        if not xmax:
            xmax = xmax_default

        limits = [np.float64(xmin), np.float64(xmax)]
        hist.set_refinements({"Limits": limits})

    #   step 4.2: add scale factor to refinement
    refdict1 = {"call": callback}

    for phase in gpx.phases():
        for hist in phase['Histograms']:
            if scale_factor:
                phase['Histograms'][hist]['Scale'][0] = scale_factor

            if refine_scale:
                phase['Histograms'][hist]['Scale'][1] = True
            else:
                phase['Histograms'][hist]['Scale'][1] = False

    # refinement step 2: turn on background refinement (Hist)
    if init_vals and "bkg" in init_vals:
        bkg_type = init_vals["bkg"]["Type"]
@@ -101,17 +186,17 @@ def run_gsas2_fit(
        coeffs = init_vals["bkg"]["Coeffs"]
        refdict2 = {
            "set": {"Background": {"type": bkg_type, "no. coeffs": num_coeffs, "coeffs": coeffs, "refine": True}},
            "call": HistStats,
            "call": callback,
        }
    else:
        refdict2 = {
            "set": {"Background": {"type": "chebyschev", "no. coeffs": 6, "refine": True}},
            "call": HistStats,
            "call": callback,
        }
    # refinement step 3: refine lattice parameter and Uiso refinement (Phase)
    refdict3 = {
        "set": {"Atoms": {"all": "U"}, "Cell": True},  # set the Uiso and lattice parameters to be refined
        "call": HistStats,
        "call": callback,
    }

    dictList = [refdict1, refdict2, refdict3]
@@ -120,31 +205,39 @@ def run_gsas2_fit(
    gpx.save(os.path.join(output_path, output_stem_fn + "_refined.gpx"))

    gpx.do_refinements(dictList)
    print("================")

    # save results data
    rw_list = []
    x_list = []
    y_list = []
    ycalc_list = []
    dy_list = []
    bkg_list = []
    refs_list = []

    for hist in histograms:
        rw = hist.get_wR() * 0.01
        x = np.array(hist.getdata("X"))
        y = np.array(hist.getdata("Yobs"))
        ycalc = np.array(hist.getdata("Ycalc"))
        dy = np.array(hist.getdata("Residual"))
        bkg = np.array(hist.getdata("Background"))
        refs = hist.reflections()

        rw_list.append(rw)
        x_list.append(x)
        y_list.append(y)
        ycalc_list.append(ycalc)
        dy_list.append(dy)
        bkg_list.append(bkg)
        refs_list.append(refs)

    rw = gpx.histogram(0).get_wR() * 0.01
    x = np.array(gpx.histogram(0).getdata("X"))
    y = np.array(gpx.histogram(0).getdata("Yobs"))
    ycalc = np.array(gpx.histogram(0).getdata("Ycalc"))
    dy = np.array(gpx.histogram(0).getdata("Residual"))
    bkg = np.array(gpx.histogram(0).getdata("Background"))

    refs = gpx.histogram(0).reflections()
    ref_list = refs["structure"]["RefList"]

    # output_cif_fn = os.path.join(os.getcwd(), 'data/bragg_gsasii/', output_stem_fn + "_refined.cif")
    output_cif_fn = os.path.join(output_path, output_stem_fn + "_refined.cif")
    gpx.phase("structure").export_CIF(output_cif_fn)
    cell_r = gpx.phase("structure").get_cell()
    # header = "Rw = {} \nx           ycalc           y           dy           bkg".format(rw)
    # np.savetxt(f"{output_stem_fn}bank{str(bank)}.dat",
    #            np.transpose([x, ycalc, y, dy, bkg]),
    #            fmt = '%f', delimiter=' ', header = header)
    # df = pd.DataFrame(
    #     {"rw": rw, "x": x, "y": y, "ycalc": ycalc, "dy": dy, "bkg": bkg})
    # df.update(cell)

    return rw, x, y, ycalc, dy, bkg, cell_i, cell_r, ref_list


    return rw_list, x_list, y_list, ycalc_list, dy_list, bkg_list, cell_r, ref_list
+22 −0
Original line number Diff line number Diff line
# Example launcher: run a GSAS-II refinement inside the Docker image.
# The three input files are taken from the current directory and bind-mounted
# into the container; ./output is mounted at /app/portal.
CIF=177063_P2_R-3m_NaNiFeMn111O2_CollCode177063.cif
INSTPRM=Aeris_Si_Exported_03062025.instprm
XYE=GSAS_Sim_NaNiFeMnO2_111.xye

#IMAGE=savannah.ornl.gov/asrp/gsas2_refinement@sha256:b0c351a9b9624e12dba5653a84c4c84e5246ee9e807228635eb0e89da67eb11c
# NOTE(review): the README builds the image as "asrp-gsas2-refinement";
# confirm that this tag matches the locally built image.
IMAGE=asrp-gsas2

# Expansions are quoted so the mounts still work when the working directory
# (or a file name) contains spaces. Option names are spelled out in full
# rather than relying on argparse prefix matching.
docker run \
    -v "${PWD}/${CIF}:/refinement/input.cif" \
    -v "${PWD}/${XYE}:/refinement/input.xye" \
    -v "${PWD}/${INSTPRM}:/refinement/input.instprm" \
    -v "${PWD}/output:/app/portal" \
    "${IMAGE}" \
        pixi run python -m gsas2_refinement.gsas2_refinement \
        --cif-filenames /refinement/input.cif \
        --gsas-filename /refinement/input.xye \
        --instrument-params-filename /refinement/input.instprm \
        --scatter-type X \
        --bank-ids 0 \
        --xmin 7.0 \
        --xmax 120.0
Loading