Commit f772fbc9 authored by Jared Casper
Browse files

Only create task heads on last pipeline stage.

parent 6fa36844
Loading
Loading
Loading
Loading
+19 −16
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@

import torch

from megatron import get_args, print_rank_0
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
@@ -45,6 +45,7 @@ class ClassificationBase(PipelinedMegatronModule):
                                                         args.num_layers))

        # Multi-choice head.
        if mpu.is_pipeline_last_stage():
            self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.classification_head = get_linear_layer(args.hidden_size,
                                                        self.num_classes,
@@ -85,6 +86,7 @@ class ClassificationBase(PipelinedMegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage():
            state_dict_[self._classification_head_key] \
                = self.classification_head.state_dict(
                    destination, prefix, keep_vars)
@@ -95,11 +97,12 @@ class ClassificationBase(PipelinedMegatronModule):

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if mpu.is_pipeline_last_stage():
            if self._classification_head_key in state_dict:
                self.classification_head.load_state_dict(
                    state_dict[self._classification_head_key], strict=strict)
            else:
            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._classification_head_key))

+18 −15
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@

import torch

from megatron import get_args, print_rank_0
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
@@ -44,6 +44,7 @@ class MultipleChoiceBase(PipelinedMegatronModule):
                                                         args.num_layers))

        # Multi-choice head.
        if mpu.is_pipeline_last_stage():
            self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                     init_method)
@@ -97,6 +98,7 @@ class MultipleChoiceBase(PipelinedMegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage():
            state_dict_[self._multichoice_head_key] \
                = self.multichoice_head.state_dict(
                    destination, prefix, keep_vars)
@@ -107,11 +109,12 @@ class MultipleChoiceBase(PipelinedMegatronModule):

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if mpu.is_pipeline_last_stage():
            if self._multichoice_head_key in state_dict:
                self.multichoice_head.load_state_dict(
                    state_dict[self._multichoice_head_key], strict=strict)
            else:
            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._multichoice_head_key))