generate_samples.py  +2 −2

@@ -319,7 +319,7 @@ def get_token_stream(model, context_tokens):
                                        group=mpu.get_model_parallel_group())
     context_length = context_length_tensor.min().item()
-    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
+    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,

@@ -469,7 +469,7 @@ def main():
     # Generate samples.
     if args.num_samples == 0:
-        assert args.batch_size == 1
+        args.batch_size = 1
         if args.sample_input_file != "":
             generate_samples_input_from_file(model)
         else:
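The first hunk drops the explicit args parameter from the get_batch() call, which only works if get_batch() now looks the parsed arguments up itself; the second hunk makes interactive generation (num_samples == 0) force a batch size of 1 instead of asserting that the user passed --batch-size 1. Below is a minimal, runnable sketch of that pattern, assuming a module-level argument store with hypothetical set_args()/get_args() helpers; it is not the actual Megatron-LM internals.

import argparse

_GLOBAL_ARGS = None  # module-level store populated once at startup


def set_args(args):
    """Record the parsed arguments so helpers can fetch them later."""
    global _GLOBAL_ARGS
    _GLOBAL_ARGS = args


def get_args():
    """Return the arguments recorded by set_args()."""
    return _GLOBAL_ARGS


def get_batch(context_tokens):
    # The caller no longer threads `args` through; this hypothetical
    # stand-in reads the global store instead, mirroring the call-site
    # change in get_token_stream().
    args = get_args()
    return context_tokens[:args.batch_size]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=None)
    parser.add_argument('--num-samples', type=int, default=0)
    args = parser.parse_args(['--num-samples', '0'])

    # Mirrors the main() hunk: interactive generation forces batch size 1
    # rather than asserting that the user supplied --batch-size 1.
    if args.num_samples == 0:
        args.batch_size = 1

    set_args(args)
    print(get_batch(['<context sentence 1>', '<context sentence 2>']))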
megatron/arguments.py  +6 −8

@@ -69,7 +69,9 @@ def parse_args(extra_args_provider=None, defaults={}):
     # Checks.
     assert args.hidden_size % args.num_attention_heads == 0
-    assert args.max_position_embeddings >= args.seq_length
-    assert args.min_lr <= args.lr
+    if args.seq_length is not None:
+        assert args.max_position_embeddings >= args.seq_length
+    if args.lr is not None:
+        assert args.min_lr <= args.lr
     if args.save is not None:
         assert args.save_interval is not None

@@ -134,7 +136,7 @@ def _add_regularization_args(parser):
 def _add_training_args(parser):
     group = parser.add_argument_group(title='training')
 
-    group.add_argument('--batch-size', type=int, required=True,
+    group.add_argument('--batch-size', type=int, default=None,
                        help='Batch size per model instance (local batch size). '
                        'Global batch size is local batch size times data '
                        'parallel size.')

@@ -301,7 +303,7 @@ def _add_data_args(parser):
                        help='Path to the vocab file.')
     group.add_argument('--merge-file', type=str, default=None,
                        help='Path to the BPE merge file.')
-    group.add_argument('--seq-length', type=int, required=True,
+    group.add_argument('--seq-length', type=int, default=None,
                        help="Maximum sequence length to process.")
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')

@@ -356,10 +358,6 @@ def _add_gpt2_args(parser):
 
 def add_data_args_(parser):
     """Train/valid/test data arguments."""
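Both --batch-size and --seq-length switch from required=True to default=None, so scripts such as generate_samples.py can be launched without them; the consistency checks in parse_args() then only fire when a value was actually supplied. The following is a minimal, runnable sketch of that pattern using made-up defaults rather than the full Megatron-LM argument set.

import argparse

parser = argparse.ArgumentParser()
# Previously required; now optional so inference-only scripts can omit them.
parser.add_argument('--batch-size', type=int, default=None,
                    help='Batch size per model instance (local batch size).')
parser.add_argument('--seq-length', type=int, default=None,
                    help='Maximum sequence length to process.')
parser.add_argument('--max-position-embeddings', type=int, default=1024)
parser.add_argument('--lr', type=float, default=None)
parser.add_argument('--min-lr', type=float, default=0.0)

args = parser.parse_args(['--seq-length', '512'])

# Checks only run when the corresponding argument was provided.
if args.seq_length is not None:
    assert args.max_position_embeddings >= args.seq_length
if args.lr is not None:
    assert args.min_lr <= args.lr

print(args)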