megatron/arguments.py (+18 −0)

```diff
@@ -89,6 +89,14 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.min_lr <= args.lr
     if args.save is not None:
         assert args.save_interval is not None
+    # Parameter sharing does not work with torch DDP.
+    if (args.num_unique_layers is not None) and (args.num_layers is not None):
+        assert args.num_unique_layers <= args.num_layers
+        assert args.num_layers % args.num_unique_layers == 0, \
+            'num-layers should be divisible by num-unique-layers.'
+        if args.num_unique_layers < args.num_layers:
+            assert args.DDP_impl == 'local', \
+                'torch-DDP does not work with parameter sharing.'
     _print_args(args)
     return args

@@ -116,6 +124,16 @@ def _add_network_size_args(parser):
     group.add_argument('--num-layers', type=int, default=None,
                        help='Number of transformer layers.')
+    group.add_argument('--num-unique-layers', type=int, default=None,
+                       help='Number of unique transformer layers. '
+                       '`num-layers` should be divisible by this value.')
+    group.add_argument('--param-sharing-style', default='grouped',
+                       choices=['grouped', 'spaced'],
+                       help='Ordering of the shared parameters. For example, '
+                       'for `num-layers`=4 and `num-unique-layers`=2, we will '
+                       'have the following ordering for the two unique '
+                       'layers 1 and 2: '
+                       'grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
     group.add_argument('--hidden-size', type=int, default=None,
                        help='Transformer hidden size.')
     group.add_argument('--num-attention-heads', type=int, default=None,
```
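For a quick check of how the two new flags and the `parse_args` validation interact, here is a minimal, standalone argparse sketch. It is not Megatron code: the flag names and checks mirror the diff above, and the `--DDP-impl` flag with its `'local'` default is an assumption added so the snippet runs on its own.

```python
import argparse

# Standalone mock of the new flags and validation logic from the diff above.
parser = argparse.ArgumentParser()
parser.add_argument('--num-layers', type=int, default=None)
parser.add_argument('--num-unique-layers', type=int, default=None)
parser.add_argument('--param-sharing-style', default='grouped',
                    choices=['grouped', 'spaced'])
parser.add_argument('--DDP-impl', default='local')  # assumed flag for the demo

args = parser.parse_args(['--num-layers', '4', '--num-unique-layers', '2'])

if (args.num_unique_layers is not None) and (args.num_layers is not None):
    assert args.num_unique_layers <= args.num_layers
    assert args.num_layers % args.num_unique_layers == 0, \
        'num-layers should be divisible by num-unique-layers.'
    if args.num_unique_layers < args.num_layers:
        # Sharing is only allowed with the local DDP implementation.
        assert args.DDP_impl == 'local', \
            'torch-DDP does not work with parameter sharing.'

print(args.num_layers, args.num_unique_layers, args.param_sharing_style)
```

Passing `--num-layers 4 --num-unique-layers 3` instead would trip the divisibility assertion, which is exactly the misconfiguration the new check guards against.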
megatron/model/transformer.py (+37 −10)

```diff
@@ -360,34 +360,60 @@ class ParallelTransformer(MegatronModule):
         self.checkpoint_activations = args.checkpoint_activations
         self.checkpoint_num_layers = args.checkpoint_num_layers

-        def get_layer(layer_number):
+        # Number of layers:
+        self.num_layers = args.num_layers
+        self.num_unique_layers = args.num_unique_layers
+        if self.num_unique_layers is None:
+            self.num_unique_layers = self.num_layers
+        assert self.num_layers % self.num_unique_layers == 0, \
+            'number of layers should be divisible by number of unique layers'
+        self.param_sharing_style = args.param_sharing_style
+
+        # Transformer layers.
+        def build_layer(layer_number):
             return ParallelTransformerLayer(
                 attention_mask_func, mlp_activation_func,
                 init_method, output_layer_init_method, layer_number)
-
-        # Transformer layers.
         self.layers = torch.nn.ModuleList(
-            [get_layer(i + 1) for i in range(args.num_layers)])
+            [build_layer(i + 1) for i in range(self.num_unique_layers)])
+
+        # Print layer ordering.
+        if self.num_layers != self.num_unique_layers:
+            if torch.distributed.get_rank() == 0:
+                print('> will be using the following layer ordering:')
+                for i in range(self.num_layers):
+                    print('   layer id: {:3d} --> unique layer id: '
+                          '{:3d}'.format(i, self._get_layer_index(i)),
+                          flush=True)

         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)

+    def _get_layer_index(self, layer_number):
+        if self.param_sharing_style == 'grouped':
+            return layer_number % self.num_unique_layers
+        if self.param_sharing_style == 'spaced':
+            return layer_number // (self.num_layers // self.num_unique_layers)
+        assert False, 'should not be here'
+
+    def _get_layer(self, layer_number):
+        return self.layers[self._get_layer_index(layer_number)]
+
     def _checkpointed_forward(self, hidden_states, attention_mask):
         """Forward method with activation checkpointing."""
         def custom(start, end):
             def custom_forward(*inputs):
-                layers_ = self.layers[start:end]
                 x_ = inputs[0]
-                for layer in layers_:
+                for index in range(start, end):
+                    layer = self._get_layer(index)
                     x_ = layer(x_, inputs[1])
                 return x_
             return custom_forward

         l = 0
-        num_layers = len(self.layers)
-        while l < num_layers:
+        while l < self.num_layers:
             hidden_states = mpu.checkpoint(
                 custom(l, l + self.checkpoint_num_layers),
                 hidden_states, attention_mask)

@@ -414,10 +440,11 @@ class ParallelTransformer(MegatronModule):
         else:
             if get_key_value:
                 presents = []
-            for i, layer in enumerate(self.layers):
+            for index in range(self.num_layers):
+                layer = self._get_layer(index)
                 past = None
                 if layer_past is not None:
-                    past = layer_past[i]
+                    past = layer_past[index]
                 hidden_states = layer(hidden_states,
                                       attention_mask,
                                       layer_past=past,
```
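The index arithmetic in `_get_layer_index` is easy to verify in isolation. The sketch below reimplements just that mapping as a free function (a hypothetical helper for illustration, not part of the patch) and prints the two orderings promised by the `--param-sharing-style` help text.

```python
def get_layer_index(layer_number, num_layers, num_unique_layers, style):
    # Mirrors ParallelTransformer._get_layer_index from the diff above.
    if style == 'grouped':
        # Cycle through the unique layers: 1, 2, 1, 2, ...
        return layer_number % num_unique_layers
    if style == 'spaced':
        # Repeat each unique layer as a contiguous block: 1, 1, 2, 2, ...
        return layer_number // (num_layers // num_unique_layers)
    raise ValueError('unknown param sharing style: {}'.format(style))

for style in ('grouped', 'spaced'):
    order = [get_layer_index(i, 4, 2, style) + 1 for i in range(4)]
    print(style, order)
# grouped [1, 2, 1, 2]
# spaced  [1, 1, 2, 2]
```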
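Since `self.layers` now holds only the unique layers and `_get_layer` re-resolves the module at each depth, the same weight tensors are applied several times per forward pass. The toy stack below (plain PyTorch with hypothetical names, not the Megatron class) applies the 'grouped' scheme and makes the saving visible in the parameter count.

```python
import torch

class ToySharedStack(torch.nn.Module):
    """Toy stack that reuses num_unique_layers modules across num_layers
    forward steps, following the 'grouped' ordering from the diff."""

    def __init__(self, hidden, num_layers, num_unique_layers):
        super().__init__()
        assert num_layers % num_unique_layers == 0
        self.num_layers = num_layers
        self.num_unique_layers = num_unique_layers
        # Only the unique layers own parameters.
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(hidden, hidden)
             for _ in range(num_unique_layers)])

    def forward(self, x):
        for index in range(self.num_layers):
            # 'grouped' sharing: the layer index cycles through unique layers.
            x = self.layers[index % self.num_unique_layers](x)
        return x

shared = ToySharedStack(hidden=8, num_layers=4, num_unique_layers=2)
unshared = ToySharedStack(hidden=8, num_layers=4, num_unique_layers=4)
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(shared), count(unshared))  # 144 288: half the parameters shared
y = shared(torch.randn(3, 8))          # still four layer applications deep
```

Because a shared layer's gradients accumulate across every position where it is reused, gradient reduction hooks that assume each parameter is used once can misbehave, which is presumably why the validation in `arguments.py` restricts sharing to the local DDP implementation.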