Commit 41141526 authored by Raul Puri's avatar Raul Puri
Browse files

Merge branch 'staging-gen-fix' into 'staging'

Update generate_samples.py

See merge request ADLR/megatron-lm!34
parents 29d10a36 323e75c4
Loading
Loading
Loading
Loading
+8 −7
Original line number Diff line number Diff line
@@ -366,12 +366,13 @@ def switch(val1, val2, boolean):
    return (1-boolean)*val1 + boolean*val2

def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
    if isinstance(model, DDP):
        model = model.module
    if isinstance(model, FP16_Module):
        model = model.module
    original_output_parallel = model.parallel_output
    model.parallel_output = False
    actual_model = model
    if isinstance(actual_model, DDP):
        actual_model = actual_model.module
    if isinstance(actual_model, FP16_Module):
        actual_model = actual_model.module
    original_output_parallel = actual_model.parallel_output
    actual_model.parallel_output = False
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
@@ -438,7 +439,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
            yield tokens, lengths
            if done:
                break
    model.parallel_output = original_output_parallel
    actual_model.parallel_output = original_output_parallel

def prepare_tokenizer(args):