Commit b1ac9fd3 authored by mohammad's avatar mohammad
Browse files

tested and seems to be working

parent f257d0ee
Loading
Loading
Loading
Loading
+9 −1
Original line number Diff line number Diff line
@@ -89,6 +89,14 @@ def parse_args(extra_args_provider=None, defaults={},
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Parameters sharing does not work with torch DDP.
    if (args.num_unique_layers is not None) and (args.num_layers is not None):
        assert args.num_unique_layers <= args.num_layers
        assert args.num_layers % args.num_unique_layers == 0, \
            'num-layers should be divisible by num-unique-layers.'
        if args.num_unique_layers < args.num_layers:
            assert args.DDP_impl == 'local', \
                'torch-DDP does not work with parameters sharing.'

    _print_args(args)
    return args
@@ -120,7 +128,7 @@ def _add_network_size_args(parser):
                       help='Number of unique transformer layers. '
                       '`num-layers` should be divisible by this value.')
    group.add_argument('--param-sharing-style', default='grouped',
                       choices=['grouped', 'space'],
                       choices=['grouped', 'spaced'],
                       help='Ordering of the shared parameters. For example, '
                       'for a `num-layers`=4 and `--num-unique-layers`=2, '
                       'we will have the following ordering for two unique '
+3 −2
Original line number Diff line number Diff line
@@ -383,8 +383,9 @@ class ParallelTransformer(MegatronModule):
            if torch.distributed.get_rank() == 0:
                print('> will be using the following layer ordering:')
                for i in range(self.num_layers):
                    print('   layer: {:3d} --> unique layer: {:3d}'.format(
                        i, self._get_layer_index(i)), flush=True)
                    print('   layer id: {:3d} --> unique layer id: '
                          '{:3d}'.format(i, self._get_layer_index(i)),
                          flush=True)

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(