Commit 4d15f4b1 authored by Zhang, Chen
Browse files

new log method

parent d79bb994
Loading
Loading
Loading
Loading
+25 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
"""Model definition for Transformer-based models."""
import torch
import numpy as np
import torch.nn as nn
from tgreft.nn.transformer import PositionalEncoding

@@ -44,9 +45,31 @@ class REFL_GPT(nn.Module):
        self.input_dim = input_dim
        self.output_dim = output_dim

        # if the log transform is enabled, we need to generate the weights
        # NOTE:
        #  1. we are using a fixed q range for now
        #  2. we are using a fixed number of q points for now
        #  3. this is to match the q range in utils.data.data_synthesis
        if self.to_log:
            q_range = np.logspace(
                np.log10(0.009),
                np.log10(0.18),
                num=150,
                )
            weights = q_range**2
            weights /= weights[0]
            self.weights = torch.tensor(weights).float()

    def forward(self, src):
        """Forward pass."""
        if self.to_log:
            # find the batch size
            batch_size = src.shape[0] // self.weights.shape[0]
            # repeat the weights
            weights = self.weights.repeat(batch_size)
            # multiply the weights
            src = src * weights
            # take the log
            src = torch.log(src)

        src = self.embedding(src)
@@ -58,4 +81,6 @@ class REFL_GPT(nn.Module):

    def to(self, *args, **kwargs):
        """Move the model, and its log weights when present, to a device/dtype.

        Accepts the same positional and keyword arguments as
        ``torch.nn.Module.to`` and forwards them unchanged, first to the
        parent implementation and then to the ``weights`` tensor.

        Returns
        -------
        REFL_GPT
            ``self``, so calls can be chained.
        """
        super(REFL_GPT, self).to(*args, **kwargs)
        # ``weights`` is a plain tensor attribute (not a parameter), so the
        # parent ``to`` does not move it. It is only created when ``to_log``
        # is enabled, so guard the access to avoid an AttributeError for
        # models constructed without the log transform.
        weights = getattr(self, "weights", None)
        if weights is not None:
            self.weights = weights.to(*args, **kwargs)
        return self