Commit 410ec167 authored by Zhang, Chen's avatar Zhang, Chen
Browse files

fix errors in broadcasting

parent 0d40c194
Loading
Loading
Loading
Loading
+2 −6
Original line number Diff line number Diff line
@@ -63,12 +63,8 @@ class REFL_GPT(nn.Module):
    def forward(self, src):
        """Forward pass."""
        if self.to_log:
            # find the batch size
            batch_size = src.shape[0] // self.weights.shape[0]
            # repeat the weights
            weights = self.weights.repeat(batch_size)
            # multiply the weights
            src = src * weights
            # multiply the weights, use tensor broadcasting to deal with batch
            src = src * self.weights
            # take the log
            src = torch.log(src)