Commit e9eef962 authored by Raul Puri

Merge branch 'master_params_sharing' into 'master'

Parameter sharing

See merge request ADLR/megatron-lm!82
parents 3ee811be c9c69c1e
+18 −0
@@ -89,6 +89,14 @@ def parse_args(extra_args_provider=None, defaults={},
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
+    # Parameter sharing does not work with torch DDP.
+    if (args.num_unique_layers is not None) and (args.num_layers is not None):
+        assert args.num_unique_layers <= args.num_layers
+        assert args.num_layers % args.num_unique_layers == 0, \
+            'num-layers should be divisible by num-unique-layers.'
+        if args.num_unique_layers < args.num_layers:
+            assert args.DDP_impl == 'local', \
+                'torch-DDP does not work with parameter sharing.'

    _print_args(args)
    return args
@@ -116,6 +124,16 @@ def _add_network_size_args(parser):

    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
+    group.add_argument('--num-unique-layers', type=int, default=None,
+                       help='Number of unique transformer layers. '
+                       '`num-layers` should be divisible by this value.')
+    group.add_argument('--param-sharing-style', default='grouped',
+                       choices=['grouped', 'spaced'],
+                       help='Ordering of the shared parameters. For example, '
+                       'for a `num-layers`=4 and `--num-unique-layers`=2, '
+                       'we will have the following ordering for two unique '
+                       'layers 1 and 2: '
+                       '    grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--num-attention-heads', type=int, default=None,
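
For illustration only, here is a minimal standalone sketch of how the new options and the validation added above fit together. This is a hypothetical argparse parser, not Megatron-LM's parse_args; only the argument names and checks mirror the diff.

import argparse

# Hypothetical standalone parser that mirrors the new flags and the checks
# added to parse_args above.
parser = argparse.ArgumentParser()
parser.add_argument('--num-layers', type=int, default=None)
parser.add_argument('--num-unique-layers', type=int, default=None)
parser.add_argument('--param-sharing-style', default='grouped',
                    choices=['grouped', 'spaced'])
parser.add_argument('--DDP-impl', default='local')
args = parser.parse_args(['--num-layers', '4', '--num-unique-layers', '2'])

# Unique layers must evenly divide the total layer count, and sharing
# requires the local DDP implementation.
if (args.num_unique_layers is not None) and (args.num_layers is not None):
    assert args.num_unique_layers <= args.num_layers
    assert args.num_layers % args.num_unique_layers == 0, \
        'num-layers should be divisible by num-unique-layers.'
    if args.num_unique_layers < args.num_layers:
        assert args.DDP_impl == 'local', \
            'torch-DDP does not work with parameter sharing.'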
+37 −10
@@ -360,34 +360,60 @@ class ParallelTransformer(MegatronModule):
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

-        def get_layer(layer_number):
+        # Number of layers:
+        self.num_layers = args.num_layers
+        self.num_unique_layers = args.num_unique_layers
+        if self.num_unique_layers is None:
+            self.num_unique_layers = self.num_layers
+        assert self.num_layers % self.num_unique_layers == 0, \
+            'number of layers should be divisible by number of unique layers'
+        self.param_sharing_style = args.param_sharing_style
+
+        # Transformer layers.
+        def build_layer(layer_number):
            return ParallelTransformerLayer(
                attention_mask_func, mlp_activation_func,
                init_method, output_layer_init_method, layer_number)

-        # Transformer layers.
        self.layers = torch.nn.ModuleList(
-            [get_layer(i + 1) for i in range(args.num_layers)])
+            [build_layer(i + 1) for i in range(self.num_unique_layers)])

+        # Print layer ordering.
+        if self.num_layers != self.num_unique_layers:
+            if torch.distributed.get_rank() == 0:
+                print('> will be using the following layer ordering:')
+                for i in range(self.num_layers):
+                    print('   layer id: {:3d} --> unique layer id: '
+                          '{:3d}'.format(i, self._get_layer_index(i)),
+                          flush=True)

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

+    def _get_layer_index(self, layer_number):
+        if self.param_sharing_style == 'grouped':
+            return layer_number % self.num_unique_layers
+        if self.param_sharing_style == 'spaced':
+            return layer_number // (self.num_layers // self.num_unique_layers)
+        assert False, 'should not be here'
+
+    def _get_layer(self, layer_number):
+        return self.layers[self._get_layer_index(layer_number)]

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
-                layers_ = self.layers[start:end]
                x_ = inputs[0]
-                for layer in layers_:
+                for index in range(start, end):
+                    layer = self._get_layer(index)
                    x_ = layer(x_, inputs[1])
                return x_
            return custom_forward

        l = 0
-        num_layers = len(self.layers)
-        while l < num_layers:
+        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask)
@@ -414,10 +440,11 @@ class ParallelTransformer(MegatronModule):
        else:
            if get_key_value:
                presents = []
-            for i, layer in enumerate(self.layers):
+            for index in range(self.num_layers):
+                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
-                    past = layer_past[i]
+                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      layer_past=past,
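
For reference, here is a minimal sketch of the layer-to-parameter mapping that the two --param-sharing-style choices produce, mirroring _get_layer_index above. It is a standalone function for illustration, using the 0-based layer ids that the forward loops pass in.

def get_layer_index(layer_number, num_layers, num_unique_layers, style):
    # Map a 0-based layer id to the unique layer that supplies its parameters.
    if style == 'grouped':
        return layer_number % num_unique_layers
    if style == 'spaced':
        return layer_number // (num_layers // num_unique_layers)
    raise ValueError('unknown param-sharing-style: {}'.format(style))

# With num-layers=4 and num-unique-layers=2:
#   grouped -> [0, 1, 0, 1]  (the two unique layers alternate)
#   spaced  -> [0, 0, 1, 1]  (each unique layer covers a contiguous block)
for style in ('grouped', 'spaced'):
    print(style, [get_layer_index(i, 4, 2, style) for i in range(4)])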