Loading src/tgreft/models/refl_gpt.py +2 −6 Original line number Diff line number Diff line Loading @@ -63,12 +63,8 @@ class REFL_GPT(nn.Module): def forward(self, src): """Forward pass.""" if self.to_log: # find the batch size batch_size = src.shape[0] // self.weights.shape[0] # repeat the weights weights = self.weights.repeat(batch_size) # multiply the weights src = src * weights # multiply the weights, use tensor broadcasting to deal with batch src = src * self.weights # take the log src = torch.log(src) Loading Loading
src/tgreft/models/refl_gpt.py +2 −6 Original line number Diff line number Diff line Loading @@ -63,12 +63,8 @@ class REFL_GPT(nn.Module): def forward(self, src): """Forward pass.""" if self.to_log: # find the batch size batch_size = src.shape[0] // self.weights.shape[0] # repeat the weights weights = self.weights.repeat(batch_size) # multiply the weights src = src * weights # multiply the weights, use tensor broadcasting to deal with batch src = src * self.weights # take the log src = torch.log(src) Loading