Commit 4ad5717e authored by Raul Puri

Merge branch 'model_parallel_generation' into 'master'

force output gathering

See merge request ADLR/megatron-lm!17
parents 2d76d065 7a6d630e
+8 −0
@@ -362,6 +362,12 @@ def switch(val1, val2, boolean):
     return (1-boolean)*val1 + boolean*val2
 
 def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
+    if isinstance(model, DDP):
+        model = model.module
+    if isinstance(model, FP16_Module):
+        model = model.module
+    original_output_parallel = model.parallel_output
+    model.parallel_output = False
     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
@@ -404,6 +410,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
             if args.greedy:
                 prev = torch.argmax(logits, dim=-1).view(-1)
             else:
+                logits = logits.float()
                 logits /= args.temperature
                 logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
                 log_probs = F.softmax(logits, dim=-1)
@@ -427,6 +434,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
             yield tokens, lengths
             if done:
                 break
+    model.parallel_output = original_output_parallel
 
 def prepare_tokenizer(args):
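
For context on the merge title: Megatron's model-parallel output layer shards the vocabulary dimension across ranks, and the `parallel_output` flag controls whether each rank keeps only its logits shard (cheap when feeding a vocab-parallel training loss) or gathers the full logits tensor. Sampling needs the complete distribution on every rank, so the first hunk saves the flag and forces it to `False` for the duration of generation, the last hunk restores it after the generator loop, and the `DDP`/`FP16_Module` unwrapping simply reaches the attribute on the underlying module. Below is a minimal sketch of the shard-vs-gather distinction, assuming an initialized `torch.distributed` process group; the class name `VocabParallelOutput` is illustrative, not Megatron's actual API, and `world_size` stands in for the model-parallel group size.

```python
import torch
import torch.distributed as dist

class VocabParallelOutput(torch.nn.Module):
    # Illustrative output projection whose weight is sharded over the
    # vocabulary dimension, one shard per model-parallel rank.
    def __init__(self, hidden_size, vocab_size, parallel_output=True):
        super().__init__()
        self.world_size = dist.get_world_size()
        assert vocab_size % self.world_size == 0
        # Each rank owns vocab_size // world_size rows of the projection.
        self.weight = torch.nn.Parameter(
            torch.empty(vocab_size // self.world_size, hidden_size))
        self.parallel_output = parallel_output

    def forward(self, hidden):
        # Every rank computes only its slice of the logits:
        # [..., vocab_size // world_size].
        logits_shard = hidden @ self.weight.t()
        if self.parallel_output:
            # Training path: leave the shards local; a vocab-parallel
            # cross-entropy can consume them without communication.
            return logits_shard
        # Generation path ("force output gathering"): all-gather the
        # shards so each rank holds the full [..., vocab_size] logits
        # before argmax or multinomial sampling.
        shards = [torch.empty_like(logits_shard)
                  for _ in range(self.world_size)]
        dist.all_gather(shards, logits_shard)
        return torch.cat(shards, dim=-1)
```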
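The one-line second hunk, `logits = logits.float()`, casts to full precision before temperature scaling; with an `FP16_Module`-wrapped model the logits arrive in half precision, where the division, softmax, and multinomial draw are prone to overflow and to rounding the tail of the distribution to zero. A stripped-down sketch of that sampling branch follows, with a hypothetical `sample_next_token` wrapper and a top-k-only stand-in for the repository's `top_k_logits` (the real helper also implements top-p filtering):

```python
import torch
import torch.nn.functional as F

def top_k_logits(logits, top_k=0, filter_value=-float('inf')):
    # Top-k-only stand-in for the helper the diff calls; the real
    # version also supports top-p (nucleus) filtering.
    if top_k > 0:
        # Mask every logit strictly below the k-th largest per row.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, filter_value)
    return logits

def sample_next_token(logits, temperature=1.0, top_k=0, greedy=False):
    # logits: [batch, vocab], possibly fp16 under half-precision inference.
    if greedy:
        return torch.argmax(logits, dim=-1).view(-1)
    # Cast to fp32 first, as the diff does: in fp16 the temperature
    # division and softmax can overflow or flush small tail
    # probabilities to zero before multinomial sees them.
    logits = logits.float() / temperature
    logits = top_k_logits(logits, top_k=top_k)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(-1)
```

A call such as `sample_next_token(logits, args.temperature, args.top_k, args.greedy)` mirrors the greedy/sampling branch the hunk sits in.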