Commit 61f50c45 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'onnx-erf' into 'master'

ONNX export fix, including one for gelu()

See merge request ADLR/megatron-lm!101
parents 4ac91436 35bea728
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -158,6 +158,8 @@ def _add_network_size_args(parser):
                       help='Use OpenAIs GeLU implementation. This option'
                       'should not be used unless for backward compatibility'
                       'reasons.')
    group.add_argument('--onnx-safe', action='store_true',
                       help='Use workarounds for known problems with Torch ONNX exporter')

    return parser

+3 −1
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ from megatron import mpu
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import openai_gelu
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
@@ -95,6 +95,8 @@ class BertLMHead(MegatronModule):
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu
        elif args.onnx_safe:
            self.gelu = erf_gelu

    def forward(self, hidden_states, word_embeddings_weight):
        hidden_states = self.dense(hidden_states)
+3 −1
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import openai_gelu
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer


@@ -52,6 +52,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
    gelu = F.gelu
    if args.openai_gelu:
        gelu = openai_gelu
    elif args.onnx_safe:
        gelu = erf_gelu
    
    # Language model.
    language_model = TransformerLanguageModel(
+4 −1
Original line number Diff line number Diff line
@@ -48,7 +48,6 @@ def get_linear_layer(rows, columns, init_method):
        layer.bias.zero_()
    return layer


@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
@@ -57,6 +56,10 @@ def gelu_impl(x):
def openai_gelu(x):
    return gelu_impl(x)

#This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@torch.jit.script
def erf_gelu(x):
    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))

def get_params_for_weight_decay_optimization(module):
    """Divide params into with-weight-decay and without-weight-decay groups.
+14 −9
Original line number Diff line number Diff line
@@ -110,11 +110,12 @@ class VocabParallelEmbedding(torch.nn.Module):
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.model_parallel_size = get_model_parallel_world_size()
        # Divide the weight matrix along the vocaburaly dimension.
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_model_parallel_rank(),
                get_model_parallel_world_size())
                self.model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

@@ -127,18 +128,22 @@ class VocabParallelEmbedding(torch.nn.Module):
            self.num_embeddings_per_partition, 0, init_method)

    def forward(self, input_):
        if self.model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
            # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_model_parallel_region(output_parallel)
Loading