Commit 3fc035d7 authored by Mohammad Shoeybi

Merge branch 'pipeline_refactor' into 'main'

Pipeline refactor

See merge request ADLR/megatron-lm!254
parents f2d64c00 e270f68a
megatron/model/__init__.py  +3 −11
@@ -15,16 +15,8 @@

from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm

-from .distributed import *
-from .bert_model import (BertModel,
-                         BertModelFirstStage,
-                         BertModelIntermediateStage,
-                         BertModelLastStage)
-from .gpt_model import (GPTModel,
-                        GPTModelFirstStage,
-                        GPTModelIntermediateStage,
-                        GPTModelLastStage)
+from .distributed import DistributedDataParallel
+from .bert_model import BertModel
+from .gpt_model import GPTModel
from .language_model import get_language_model
from .module import Float16Module
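
Note: with the per-stage classes removed from the package namespace, every pipeline rank imports the same BertModel/GPTModel and selects its role through the new pre_process/post_process flags. A minimal sketch of what construction looks like under this API; the provider function below is illustrative, not code from this merge request, and assumes megatron's global args have already been initialized:

    from megatron import mpu
    from megatron.model import GPTModel

    def gpt_model_provider_sketch():
        # Each rank builds the same class; only the flags differ.
        # pre_process owns the embedding, post_process owns the output head/loss.
        return GPTModel(
            num_tokentypes=0,
            parallel_output=True,
            pre_process=mpu.is_pipeline_first_stage(),
            post_process=mpu.is_pipeline_last_stage())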

megatron/model/bert_model.py  +35 −84
@@ -121,17 +121,23 @@ def post_language_model_processing(lm_output, pooled_output,
        return lm_loss, binary_logits


-class BertModelBase(MegatronModule):
+class BertModel(MegatronModule):
    """Bert Language model."""

-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModelBase, self).__init__()
+    def __init__(self,
+                 num_tokentypes=2,
+                 add_binary_head=True,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True):
+        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
@@ -142,10 +148,12 @@ class BertModelBase(MegatronModule):
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
@@ -156,26 +164,30 @@ class BertModelBase(MegatronModule):
                                                    init_method)
                self._binary_head_key = 'binary_head'

+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.language_model.set_input_tensor(input_tensor)
+
    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)

-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = bert_model_input
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [bert_model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        input_ids = bert_model_input
+        position_ids = bert_position_ids(input_ids)
+
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids
+        )
+
+        if self.post_process and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
@@ -194,15 +206,15 @@ class BertModelBase(MegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        if self.post_process and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_
@@ -212,74 +224,13 @@ class BertModelBase(MegatronModule):

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
-
-
-class BertModel(BertModelBase):
-
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModel, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None, lm_labels=None):
-        return super(BertModel, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            lm_labels=lm_labels)
-
-
-class BertModelFirstStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(BertModelFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
-        return super(BertModelFirstStage, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class BertModelIntermediateStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(BertModelIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(BertModelIntermediateStage, self).forward(
-            hidden_state,
-            attention_mask)
-
-
-class BertModelLastStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModelLastStage, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-
-    def forward(self, hidden_state, attention_mask,
-                lm_labels=None):
-        return super(BertModelLastStage, self).forward(
-            hidden_state,
-            attention_mask,
-            lm_labels=lm_labels)
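
Note: the new set_input_tensor() hook is how the pipeline schedule hands an intermediate or last stage its incoming activation. forward() keeps the same token-level signature on every rank; stages built with pre_process=False presumably ignore those arguments and consume the tensor set here instead. A rough sketch of the calling pattern, with a made-up driver function (the real per-stage driver is not shown on this page):

    def bert_forward_step_sketch(model, batch, input_tensor):
        # input_tensor is None on the first stage; on later stages it is the
        # activation received from the previous pipeline rank.
        model.set_input_tensor(input_tensor)
        tokens, attention_mask, tokentype_ids, lm_labels = batch
        # Identical call on every stage; pre_process/post_process decide whether
        # the embedding and the LM/binary heads actually run.
        return model(tokens, attention_mask,
                     tokentype_ids=tokentype_ids, lm_labels=lm_labels)
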
megatron/model/classification.py  +30 −71
@@ -28,13 +28,19 @@ from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule


-class ClassificationBase(MegatronModule):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationBase, self).__init__(share_word_embeddings=False)
+class Classification(MegatronModule):
+
+    def __init__(self,
+                 num_classes,
+                 num_tokentypes=2,
+                 pre_process=True,
+                 post_process=True):
+        super(Classification, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.num_classes = num_classes
+        self.pre_process = pre_process
+        self.post_process = post_process
        init_method = init_method_normal(args.init_method_std)

        self.language_model, self._language_model_key = get_language_model(
@@ -43,31 +49,36 @@ class ClassificationBase(MegatronModule):
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers))
+                                                         args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

        # Multi-choice head.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
            self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.classification_head = get_linear_layer(args.hidden_size,
                                                        self.num_classes,
                                                        init_method)
            self._classification_head_key = 'classification_head'

+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.language_model.set_input_tensor(input_tensor)
+
    def forward(self, model_input, attention_mask, tokentype_ids=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)

-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = model_input
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage():
+        input_ids = model_input
+        position_ids = bert_position_ids(input_ids)
+
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids
+        )
+
+        if self.post_process:
            _, pooled_output = lm_output
            classification_output = self.classification_dropout(pooled_output)
            classification_logits = self.classification_head(classification_output)
@@ -87,7 +98,7 @@ class ClassificationBase(MegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
            state_dict_[self._classification_head_key] \
                = self.classification_head.state_dict(
                    destination, prefix, keep_vars)
@@ -98,7 +109,7 @@ class ClassificationBase(MegatronModule):

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
            if self._classification_head_key in state_dict:
                self.classification_head.load_state_dict(
                    state_dict[self._classification_head_key], strict=strict)
@@ -106,55 +117,3 @@ class ClassificationBase(MegatronModule):
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._classification_head_key))
-
-
-class Classification(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(Classification, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
-        return super(Classification, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class ClassificationFirstStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationFirstStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
-        return super(ClassificationFirstStage, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class ClassificationIntermediateStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationIntermediateStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationIntermediateStage, self).forward(
-            hidden_state,
-            attention_mask)
-
-
-class ClassificationLastStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationLastStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationLastStage, self).forward(
-            hidden_state,
-            attention_mask)
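
Note: checkpointing follows the same flags. Every rank saves its slice of the language model, while the classification head is saved and loaded only where post_process is True, with a warning printed if the head is missing from the checkpoint. A small sketch of what that means for the saved keys, assuming the language-model entry is stored under 'language_model' (the key returned by get_language_model, not shown in this diff):

    state = model.state_dict_for_save_checkpoint()
    # Present on every pipeline stage.
    assert 'language_model' in state
    # Only the last (post_process) stage owns the head, so only it saves the key.
    if model.post_process:
        assert 'classification_head' in state
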
megatron/model/gpt_model.py  +26 −86
@@ -57,14 +57,20 @@ def post_language_model_processing(lm_output, labels, logit_weights,
        return loss


-class GPTModelBase(MegatronModule):
+class GPTModel(MegatronModule):
    """GPT-2 Language model."""

-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModelBase, self).__init__()
+    def __init__(self,
+                 num_tokentypes=0,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True):
+        super(GPTModel, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        self.language_model, self._language_model_key = get_language_model(
@@ -73,24 +79,28 @@ class GPTModelBase(MegatronModule):
            encoder_attn_mask_type=AttnMaskType.causal,
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers))
+                                                         args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)

-    def forward(self, gpt_model_input, attention_mask, labels=None,
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.language_model.set_input_tensor(input_tensor)
+
+    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):

-        kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
-        if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = gpt_model_input
-            args = [input_ids, position_ids, attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [gpt_model_input, attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage():
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            attention_mask,
+            layer_past=layer_past,
+            get_key_value=get_key_value)
+
+        if self.post_process:
            return post_language_model_processing(
                lm_output, labels,
                self.word_embeddings_weight(),
@@ -109,7 +119,7 @@ class GPTModelBase(MegatronModule):
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        # Save word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_
@@ -118,79 +128,9 @@ class GPTModelBase(MegatronModule):
        """Customized load."""

        # Load word_embeddings.
        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)


class GPTModel(GPTModelBase):

    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(GPTModel, self).__init__(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)

    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):
        return super(GPTModel, self).forward(
            (input_ids, position_ids),
            attention_mask,
            labels=labels,
            tokentype_ids=tokentype_ids,
            layer_past=layer_past,
            get_key_value=get_key_value,
            forward_method_parallel_output=forward_method_parallel_output)


class GPTModelFirstStage(GPTModelBase):

    def __init__(self, num_tokentypes=0):
        super(GPTModelFirstStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, input_ids, position_ids, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False):
        return super(GPTModelFirstStage, self).forward(
            (input_ids, position_ids),
            attention_mask,
            tokentype_ids=tokentype_ids,
            layer_past=layer_past,
            get_key_value=get_key_value)


class GPTModelIntermediateStage(GPTModelBase):

    def __init__(self, num_tokentypes=0):
        super(GPTModelIntermediateStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask,
                layer_past=None, get_key_value=False):
        return super(GPTModelIntermediateStage, self).forward(
            hidden_state,
            attention_mask,
            layer_past=layer_past,
            get_key_value=get_key_value)


class GPTModelLastStage(GPTModelBase):

    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(GPTModelLastStage, self).__init__(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)

    def forward(self, hidden_state, attention_mask, labels=None,
                layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):
        return super(GPTModelLastStage, self).forward(
            hidden_state,
            attention_mask,
            labels=labels,
            layer_past=layer_past,
            get_key_value=get_key_value,
            forward_method_parallel_output=forward_method_parallel_output)
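
Note: GPTModel.forward() now takes input_ids, position_ids and attention_mask on every stage instead of a stage-dependent gpt_model_input. A toy call with made-up shapes, assuming model was built as in the earlier sketch and that the mask follows Megatron's convention of True marking positions to mask out; the real training loop builds these tensors in its batch-preparation step:

    import torch

    micro_batch, seq_len, vocab_size = 4, 1024, 50257
    input_ids = torch.randint(vocab_size, (micro_batch, seq_len), device='cuda')
    position_ids = torch.arange(seq_len, device='cuda').unsqueeze(0).expand(micro_batch, -1)
    # Causal mask: entries above the diagonal are True, i.e. masked out.
    attention_mask = torch.ones(1, 1, seq_len, seq_len, device='cuda').tril() < 0.5

    # Labels here reuse the inputs for brevity; real code shifts them by one token.
    output = model(input_ids, position_ids, attention_mask, labels=input_ids)
    # With labels on a post_process stage this is the per-token loss; on earlier
    # stages it is the hidden state handed to the next pipeline rank.
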
+42 −162

Preview size limit exceeded, changes collapsed.