Loading src/tgreft/models/refl_gpt.py +25 −0 Original line number Diff line number Diff line #!/usr/bin/env python3 """Model definition for Transformer-based models.""" import torch import numpy as np import torch.nn as nn from tgreft.nn.transformer import PositionalEncoding Loading Loading @@ -44,9 +45,31 @@ class REFL_GPT(nn.Module): self.input_dim = input_dim self.output_dim = output_dim # if use log, then we need to generate the weight # NOTE: # 1. we are using fixed q range for now # 2. we are using fixed number of q points for now # 3. this is to match the q range in utils.data.data_synthesis if self.to_log: q_range = np.logspace( np.log10(0.009), np.log10(0.18), num=150, ) weights = q_range**2 weights /= weights[0] self.weights = torch.tensor(weights).float() def forward(self, src): """Forward pass.""" if self.to_log: # find the batch size batch_size = src.shape[0] // self.weights.shape[0] # repeat the weights weights = self.weights.repeat(batch_size) # multiply the weights src = src * weights # take the log src = torch.log(src) src = self.embedding(src) Loading @@ -58,4 +81,6 @@ class REFL_GPT(nn.Module): def to(self, *args, **kwargs): super(REFL_GPT, self).to(*args, **kwargs) # move the weights to the same device as the model self.weights = self.weights.to(*args, **kwargs) return self Loading
src/tgreft/models/refl_gpt.py +25 −0 Original line number Diff line number Diff line #!/usr/bin/env python3 """Model definition for Transformer-based models.""" import torch import numpy as np import torch.nn as nn from tgreft.nn.transformer import PositionalEncoding Loading Loading @@ -44,9 +45,31 @@ class REFL_GPT(nn.Module): self.input_dim = input_dim self.output_dim = output_dim # if use log, then we need to generate the weight # NOTE: # 1. we are using fixed q range for now # 2. we are using fixed number of q points for now # 3. this is to match the q range in utils.data.data_synthesis if self.to_log: q_range = np.logspace( np.log10(0.009), np.log10(0.18), num=150, ) weights = q_range**2 weights /= weights[0] self.weights = torch.tensor(weights).float() def forward(self, src): """Forward pass.""" if self.to_log: # find the batch size batch_size = src.shape[0] // self.weights.shape[0] # repeat the weights weights = self.weights.repeat(batch_size) # multiply the weights src = src * weights # take the log src = torch.log(src) src = self.embedding(src) Loading @@ -58,4 +81,6 @@ class REFL_GPT(nn.Module): def to(self, *args, **kwargs): super(REFL_GPT, self).to(*args, **kwargs) # move the weights to the same device as the model self.weights = self.weights.to(*args, **kwargs) return self