Commit c8c40598 authored by Mathieu Doucet's avatar Mathieu Doucet
Browse files

update fit ranges

parent d2abc0f0
Loading
Loading
Loading
Loading
+22 −42
Original line number Diff line number Diff line
%% Cell type:markdown id: tags:

# Overview

This notebook shows how to use the trained model to extract the parameters from reflectivity curves.

%% Cell type:code id: tags:

``` python
import sys
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
```

%% Cell type:markdown id: tags:

## Make example data

%% Cell type:code id: tags:

``` python
from tgreft.utils.data.data_loader import generate_data, param_to_rcurve
```

%% Cell type:code id: tags:

``` python
generate_data?
```

%% Output


%% Cell type:markdown id: tags:

params:
- electolyte_sld
- sei_sld
- sei_thickness
- sei_roughness
- si_sld
- si_thickness
- si_roughness
- cu_sld
- cu_thickness
- cu_roughness
- oxide_sld
- oxide_thickness
- oxide_roughness

%% Cell type:code id: tags:

``` python
params, rcurves = generate_data(n_dataset=5, error=0.07)
q = np.logspace(np.log10(0.009), np.log10(0.18), num=150)
```

%% Cell type:code id: tags:

``` python
rcurves.shape, rcurves.dtype
```

%% Output

    ((5, 150), dtype('float64'))

%% Cell type:code id: tags:

``` python
# plot the curves generated (with error)
plt.figure(figsize=(8, 6))
for i in range(len(rcurves)):
    plt.plot(q, rcurves[i], label=f"curve {i}")
plt.yscale("log")
plt.xscale("log")
plt.legend()
plt.show()
```

%% Output


%% Cell type:markdown id: tags:

## Load the model

%% Cell type:code id: tags:

``` python
from tgreft.models.refl_gpt import REFL_GPT
```

%% Cell type:code id: tags:

``` python
model = REFL_GPT(
    d_model= 1024,
    nhead=8,
    num_encoder_layers=4,
    input_dim=150,
    output_dim=17,
    to_log=True,
)

model
```

%% Output

    /Users/m2d/miniconda3/envs/tgreft_dev/lib/python3.11/site-packages/torch/nn/modules/transformer.py:282: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
      warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")

    REFL_GPT(
      (embedding): Linear(in_features=150, out_features=1024, bias=True)
      (positional_encoding): PositionalEncoding()
      (transformer_encoder): TransformerEncoder(
        (layers): ModuleList(
          (0-3): 4 x TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
            )
            (linear1): Linear(in_features=1024, out_features=2048, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (linear2): Linear(in_features=2048, out_features=1024, bias=True)
            (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (dropout1): Dropout(p=0.1, inplace=False)
            (dropout2): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (decoder): Linear(in_features=1024, out_features=17, bias=True)
    )

%% Cell type:code id: tags:

``` python
# load the weights
pre_trained_weights = "../models/model_gpt.pt"
model.load_state_dict(
    torch.load(
        pre_trained_weights,
        map_location=torch.device("cpu"),  # model is trained on GPU, so need to map to CPU if no GPU is available
        ),
)
```

%% Output

    <All keys matched successfully>

%% Cell type:code id: tags:

``` python
# print total number of parameters
sum(p.numel() for p in model.parameters() if p.requires_grad)
```

%% Output

    33771537

%% Cell type:markdown id: tags:

## Evaluate the model

%% Cell type:code id: tags:

``` python
# Map R(q) to training q range

d = np.loadtxt(os.path.expanduser("~/git/hygnn2/data/REFL_207282_combined_data_auto.txt")).T
#d = np.loadtxt(os.path.expanduser("~/git/hygnn2/data/REFL_207268_combined_data_auto.txt")).T
#d = np.loadtxt(os.path.expanduser("~/git/hygnn2/data/REFL_206915_combined_data_auto.txt")).T
d = np.loadtxt(os.path.expanduser("~/git/hygnn2/data/REFL_206938_combined_data_auto.txt")).T


scale = 1/1.1
meas = scale * np.interp(q, d[0], d[1])
d_meas = scale * np.interp(q, d[0], d[2])
```

%% Cell type:code id: tags:

``` python
input_data = torch.from_numpy(meas).float()
print(meas.shape)
# inference
params_pred = model(input_data).detach().numpy()[0]


print("\tparam\tinference")
print(f"\telectolyte_sld \t{params_pred[0]}")
print(f"\telectolyte_roughness \t{params_pred[1]}")
print(f"\tsei_sld \t{params_pred[2]}")
print(f"\tsei_thickness \t{params_pred[3]}")
print(f"\tsei_roughness \t{params_pred[4]}")
print(f"\tplated_sld \t{params_pred[5]}")
print(f"\tplated_thickness \t{params_pred[6]}")
print(f"\tplated_roughness \t{params_pred[7]}")
print(f"\tcu_sld \t{params_pred[8]}")
print(f"\tcu_thickness \t{params_pred[9]}")
print(f"\tcu_roughness \t{params_pred[10]}")
print(f"\tTi_sld \t{params_pred[11]}")
print(f"\tTi_thickness \t{params_pred[12]}")
print(f"\tTi_roughness \t{params_pred[13]}")
print(f"\tSiOx_sld \t{params_pred[14]}")
print(f"\tSiOx_thickness \t{params_pred[15]}")
print(f"\tSiOx_roughness \t{params_pred[16]}")



#params_pred[0]=6.13
#params_pred[1]=54
#params_pred[4]=30
#params_pred[2]=-2.88
#params_pred[3]=252

params_pred[7]=9.7
#params_pred[5]=0.94
#params_pred[16]=2.9
#params_pred[14]=2.7
#params_pred[11]=-2
rq = param_to_rcurve(params_pred)
print(rq.shape)

```

%% Output

    (150,)
    	param	inference
    	electolyte_sld 	6.065761566162109
    	electolyte_roughness 	80.83588409423828
    	sei_sld 	0.95870041847229
    	sei_thickness 	181.23532104492188
    	sei_roughness 	37.07640838623047
    	plated_sld 	2.433926582336426
    	plated_thickness 	105.890380859375
    	plated_roughness 	15.010866165161133
    	cu_sld 	6.467490196228027
    	cu_thickness 	563.5714721679688
    	cu_roughness 	12.106481552124023
    	Ti_sld 	-1.6616637706756592
    	Ti_thickness 	55.24467468261719
    	Ti_roughness 	10.041732788085938
    	SiOx_sld 	3.1505918502807617
    	SiOx_thickness 	17.92501449584961
    	SiOx_roughness 	4.21012544631958
    (150,)

%% Cell type:code id: tags:

``` python
plt.figure(figsize=(8, 6))
plt.errorbar(q, meas, yerr=d_meas, label="Data")
plt.plot(q, rq, label="Prediction")
plt.yscale("log")
plt.xscale("log")
plt.legend()
plt.show()
```

%% Output


%% Cell type:raw id: tags:

	param	           inference
	electolyte_sld   	5.927164077758789          6.13
	electolyte_roughness 	84.17691802978516      54
	sei_sld 	        4.5298752784729            3.8
	sei_thickness    	181.52833557128906         158
	sei_roughness     	16.39092445373535          13
	plated_sld 	        1.6382867097854614         2.7
	plated_thickness 	38.12439727783203          54
	plated_roughness 	14.366933822631836         9.7
	cu_sld 	            6.530182838439941          6.5
	cu_thickness 	    566.5345458984375          561
	cu_roughness 	    12.822381973266602          15.3
	Ti_sld           	-1.198122501373291         -2
	Ti_thickness    	54.66168975830078          50
	Ti_roughness 	    15.442959785461426         11.1
	SiOx_sld 	        2.677332639694214          3.2
	SiOx_thickness 	    20.91607666015625          18.8
	SiOx_roughness 	    4.87614107131958           2.9


%% Cell type:code id: tags:

``` python
from refl1d.names import Experiment, FitProblem
from bumps.fitters import fit
from tgreft.utils.data.data_synthesis import RCurveGenerator

r_generator = RCurveGenerator()

param = params_pred
config = [
        {"name": "electrolyte", "sld": param[0], "isld": 0, "thickness": 0, "roughness": param[1]},
        {"name": "SEI", "sld": param[2], "isld": 0, "thickness": param[3], "roughness": param[4]},
        {"name": "material", "sld": param[5], "isld": 0, "thickness": param[6], "roughness": param[7]},
        {"name": "Cu", "sld": param[8], "isld": 0, "thickness": param[9], "roughness": param[10]},
        {"name": "Ti", "sld": param[11], "isld": 0, "thickness": param[12], "roughness": param[13]},
        {"name": "oxide", "sld": param[14], "isld": 0, "thickness": param[15], "roughness": param[16]},
        {"name": "substrate", "sld": 2.07, "isld": 0, "thickness": 0, "roughness": 0},
    ]

sample = r_generator.build_sample_from_config(config)

# Ranges
value = sample['electrolyte'].material.rho.value
sample['electrolyte'].material.rho.range(value*0.85, value*1.5)
value = sample['electrolyte'].interface.value
sample['electrolyte'].interface.range(value*0.1, value*2)

value = sample['SEI'].material.rho.value
sample['SEI'].material.rho.range(value*0.5, value*1.5)
value = sample['SEI'].interface.value
sample['SEI'].interface.range(value*0.1, value*2)
value = sample['SEI'].thickness.value
sample['SEI'].thickness.range(value*0.1, value*2)

value = sample['material'].material.rho.value
sample['material'].material.rho.range(value*0.5, value*1.5)
value = sample['material'].interface.value
sample['material'].interface.range(value*0.1, value*2)
value = sample['material'].thickness.value
sample['material'].thickness.range(value*0.1, value*2)

value = sample['Cu'].material.rho.value
sample['Cu'].material.rho.range(value*0.5, value*1.5)
value = sample['Cu'].interface.value
sample['Cu'].interface.range(value*0.1, value*2)
value = sample['Cu'].thickness.value
sample['Cu'].thickness.range(value*0.1, value*2)

value = sample['Ti'].material.rho.value
sample['Ti'].material.rho.range(value*0.5, value*1.5)
value = sample['Ti'].interface.value
sample['Ti'].interface.range(value*0.1, value*2)
value = sample['Ti'].thickness.value
sample['Ti'].thickness.range(value*0.1, value*2)

value = sample['oxide'].material.rho.value
sample['oxide'].material.rho.range(value*0.5, value*1.5)
value = sample['oxide'].interface.value
sample['oxide'].interface.range(value*0.1, value*2)
value = sample['oxide'].thickness.value
sample['oxide'].thickness.range(value*0.1, value*2)
sample['electrolyte'].material.rho.range(param[0]*0.85, param[0]*1.5)
sample['electrolyte'].interface.range(param[1]*0.1, param[1]*2)

sample['SEI'].material.rho.range(param[2]*0.5, param[2]*1.5)
sample['SEI'].interface.range(param[4]*0.1, param[4]*2)
sample['SEI'].thickness.range(param[3]*0.1, param[3]*2)

sample['material'].material.rho.range(param[5]*0.5, param[5]*1.5)
sample['material'].interface.range(param[7]*0.1, param[7]*2)
sample['material'].thickness.range(param[6]*0.1, param[6]*2)

sample['Cu'].material.rho.range(param[8]*0.5, param[8]*1.5)
sample['Cu'].interface.range(param[10]*0.1, param[10]*2)
sample['Cu'].thickness.range(param[9]*0.1, param[9]*2)

sample['Ti'].material.rho.range(param[11]*0.5, param[11]*1.5)
sample['Ti'].interface.range(param[13]*0.1, param[13]*2)
sample['Ti'].thickness.range(param[12]*0.1, param[12]*2)

sample['oxide'].material.rho.range(param[14]*0.5, param[14]*1.5)
sample['oxide'].interface.range(param[16]*0.1, param[16]*2)
sample['oxide'].thickness.range(param[15]*0.1, param[15]*2)



# get the probe
probe = r_generator.get_prob(meas, d_meas)

# get the experiment
experiment = Experiment(probe=probe, sample=sample)
_, r_curve = experiment.reflectivity()


#].thickness.range(20.0, 950.0)
#        sample['l%s' % i].interface.range(20.0, 60.0)

problem = FitProblem(experiment)
results = fit(problem, method='amoeba', samples=2000, burn=2000, pop=20, verbose=True)

# Results are in the wrong order. It's interface, rho, thickness...
fit_pars = [results.x[1], results.x[0]]

n_layers = int((len(results.x)-1)/3)
for i in range(n_layers):
    fit_pars.extend([results.x[3*i+3], results.x[3*i+4], results.x[3*i+2]])



param = results.x
print(fit_pars)
rq = param_to_rcurve(fit_pars)
```

%% Output

    step 1 cost 11.59(14)
                       electrolyte interface ....|.....    80.8359 in (8.08359,161.672)
                             electrolyte rho ..|.......    6.06576 in (5.1559,9.09864)
                               SEI interface ....|.....    37.0764 in (3.70764,74.1528)
                                     SEI rho ....|.....     0.9587 in (0.47935,1.43805)
                               SEI thickness ....|.....    181.235 in (18.1235,362.471)
                          material interface ....|.....        9.7 in (0.97,19.4)
                                material rho ....|.....    2.43393 in (1.21696,3.65089)
                          material thickness ....|.....     105.89 in (10.589,211.781)
                                Cu interface ....|.....    12.1065 in (1.21065,24.213)
                                      Cu rho ....|.....    6.46749 in (3.23375,9.70124)
                                Cu thickness ....|.....    563.571 in (56.3571,1127.14)
                                Ti interface ....|.....    10.0417 in (1.00417,20.0835)
                                      Ti rho ....|.....   -1.66166 in (-2.4925,-0.830832)
                                Ti thickness ....|.....    55.2447 in (5.52447,110.489)
                             oxide interface ....|.....    4.21013 in (0.421013,8.42025)
                                   oxide rho ......|...    3.62318 in (1.5753,4.72589)
                             oxide thickness ....|.....     17.925 in (1.7925,35.85)
    final chisq 1.68(14)
    === Uncertainty from curvature:     name   value(unc.) ===
                       electrolyte interface   82(17)
                             electrolyte rho   5.16(64)
                               SEI interface   58(28)
                                     SEI rho   0.5(29)
                               SEI thickness   138(49)
                          material interface   10.03(62)
                                material rho   2.45(13)
                          material thickness   126(26)
                                Cu interface   9.5(11)
                                      Cu rho   6.501(49)
                                Cu thickness   563.91(81)
                                Ti interface   0.010(18)e3
                                      Ti rho   -2.27(40)
                                Ti thickness   53(29)
                             oxide interface   0.008(26)e3
                                   oxide rho   0.004(15)e3
                             oxide thickness   0.02(11)e3
    ==========================================================
    [5.158087476873803, 82.34090488378104, 0.4800620087807319, 138.1971129524162, 58.494522716723885, 2.4541655780713767, 125.51315859240731, 10.025993328729458, 6.501346774112443, 563.9077166681218, 9.519062407360048, -2.2697454815218476, 53.37201495928419, 9.721640385165621, 3.5665434323387517, 20.371161850439968, 8.4196499703433]

%% Cell type:code id: tags:

``` python
plt.figure(figsize=(8, 6))
plt.errorbar(q, meas, yerr=d_meas, label="Data")
plt.plot(q, rq, label="Prediction")
plt.yscale("log")
plt.xscale("log")
plt.legend()
plt.show()
```

%% Output


%% Cell type:code id: tags:

``` python
```