Commit b6b7ba4d authored by rprenger's avatar rprenger
Browse files

Added generate_samples_eval function

parent e718810e
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -151,6 +151,7 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
        return tokens[:, :context_length]

def generate(model, sentences=None, max_len=0):
    model.eval()
    if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
        send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
@@ -168,6 +169,18 @@ def generate(model, sentences=None, max_len=0):
            resp_sentences.append(tokenizer.detokenize(decode_token))
        return resp_sentences

def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """
    This function is here to provide an a matching API for a legacy task
    This implementation hasn't been tested yet to make sure it matches
    """
    assert False, "Implementation untested"
    args = get_args()
    args.eos_id = eos_token_id
    raw_text_len = len(context)
    resp_sentences = generate(model, [context], max_gen_length)
    return resp_sentences[0][raw_text_len:]

def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2