Commit 752eeae3 authored by Mohammad's avatar Mohammad
Browse files

code runs

parent a6ba254f
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -319,7 +319,7 @@ def get_token_stream(model, context_tokens):
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
@@ -469,7 +469,7 @@ def main():

    # Generate samples.
    if args.num_samples == 0:
        assert args.batch_size == 1
        args.batch_size = 1
        if args.sample_input_file != "":
            generate_samples_input_from_file(model)
        else:
+6 −8
Original line number Diff line number Diff line
@@ -69,7 +69,9 @@ def parse_args(extra_args_provider=None, defaults={}):

    # Checks.
    assert args.hidden_size % args.num_attention_heads == 0
    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
@@ -134,7 +136,7 @@ def _add_regularization_args(parser):
def _add_training_args(parser):
    group = parser.add_argument_group(title='training')

    group.add_argument('--batch-size', type=int, required=True,
    group.add_argument('--batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size.')
@@ -301,7 +303,7 @@ def _add_data_args(parser):
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
    group.add_argument('--seq-length', type=int, required=True,
    group.add_argument('--seq-length', type=int, default=None,
                       help="Maximum sequence length to process.")
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
@@ -356,10 +358,6 @@ def _add_gpt2_args(parser):







def add_data_args_(parser):
    """Train/valid/test data arguments."""