Commit c81bf127 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'memory_save' into 'main'

Reducing memory usage during inference

See merge request ADLR/megatron-lm!320
parents 8fe6f9bf faf58b77
Loading
Loading
Loading
Loading
+8 −7
Original line number Diff line number Diff line
@@ -85,21 +85,22 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):

    return logits

def pad_batch(batch, pad_id, args):
def pad_batch(batch, pad_id, max_len):
    context_lengths = []
    max_context_length = max([len(tokens) for tokens in batch])
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        if context_length < max_context_length + max_len:
            tokens.extend([pad_id] * (max_context_length + max_len - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths

def tokenize_batch(sentences):
def tokenize_batch(sentences, max_len):
    args = get_args()
    tokenizer = get_tokenizer()
    context_tokens = [tokenizer.tokenize(s) for s in sentences]
    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)
                                                tokenizer.eod, max_len)
    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
    return context_tokens_tensor, context_length_tensor 
@@ -178,7 +179,7 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
def generate(model, sentences=None, max_len=0, all_probs=False):
    model.eval()
    if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, max_len)
        send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
    else:
        context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info()