Commit 42e83ee0 authored by rprenger

Changing the interface to the lm eval harness and fixing bugs caused by misunderstanding out_seq_length
parent 8fe6f9bf
+14 −13
@@ -187,7 +187,6 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
     if output is not None:
         decode_tokens, output_logits, full_logits = output
-
     if torch.distributed.get_rank() == 0:
         args = get_args()
         tokenizer = get_tokenizer()
         resp_sentences = []
@@ -214,12 +213,15 @@ def generate_samples_eval(model, context, max_gen_length, eos_token_id):
     This function is here to provide a matching API for a legacy task
     This implementation hasn't been tested yet to make sure it matches
     """
-    assert False, "Implementation untested"
+    #assert False, "Implementation untested"
     args = get_args()
     args.eos_id = eos_token_id
     raw_text_len = len(context)
     resp_sentences = generate(model, [context], max_gen_length)
     if resp_sentences:
         return resp_sentences[0][raw_text_len:]
     else:
         return [None]  # This is horrible

 def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)
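With the assert commented out, the legacy wrapper above is now callable. A minimal usage sketch, assuming a Megatron GPT model already set up for inference; the caller, prompt, and token id below are illustrative, not part of this commit:

# Hypothetical eval-harness-style call of the wrapper changed above.
prompt = "The capital of France is"
completion = generate_samples_eval(
    model,                # assumed: a model prepared for inference
    prompt,               # raw context string; its length is stripped from the output
    max_gen_length=32,    # forwarded to generate() as the generation budget
    eos_token_id=50256)   # stored on args.eos_id by the wrapper
print(completion)         # text after the prompt, or [None] if generation failed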
@@ -262,7 +264,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,

 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
-                          maxlen=None, all_probs=False, type_ids=None):
+                          maxlen, all_probs=False, type_ids=None):
     args = get_args()
     tokenizer = get_tokenizer()
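The hunk above drops the default for maxlen, so existing call sites must now pass it explicitly, and per the hunk below it is interpreted as a budget of new tokens beyond the prompt rather than an absolute sequence length. An illustrative call, assuming the usual context tensors have already been built; how the output is consumed is not shown in this diff:

# maxlen is now required and counts generated tokens, not total length.
sample_sequence_batch(model, context_tokens, context_lengths,
                      attention_mask, position_ids,
                      maxlen=64,          # new tokens to generate past the prompt
                      all_probs=False,
                      type_ids=None)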

@@ -286,13 +288,12 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         tokens = context_tokens
         output_logits = None

-        if maxlen is None:
-            maxlen = args.seq_length - 1
+        # TODO(rprenger) maxlen should be named a different parameter
+        maxlen = maxlen + org_context_length

-        if maxlen > (org_context_length + args.out_seq_length):
-            maxlen = org_context_length + args.out_seq_length
+        # TODO(rprenger) Need a better understanding of what args.seq_length vs args.out_seq_length mean (shouldn't be "args")
+        if maxlen > args.seq_length:
+            maxlen = args.seq_length

         lengths = torch.ones([batch_size]).long().cuda() * maxlen
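A worked example of the clamp change, with made-up numbers, showing one plausible failure mode behind the commit message: the old bound of org_context_length + args.out_seq_length can exceed the model's positional limit, while the new code clamps to args.seq_length directly:

# Illustrative numbers only, not taken from the repo.
seq_length = 2048           # args.seq_length: the model's positional limit
out_seq_length = 256        # args.out_seq_length, as the old code used it
org_context_length = 1900   # prompt length in tokens

# Old bound: prompt plus out_seq_length, which can pass the model's limit.
old_bound = org_context_length + out_seq_length     # 2156 > 2048: can index past the embeddings
# New bound: the model's hard limit, regardless of prompt length.
new_bound = seq_length                              # 2048

maxlen = 400                                        # requested new tokens (new semantics)
stop = min(org_context_length + maxlen, new_bound)  # min(2300, 2048) -> 2048
print(old_bound, new_bound, stop)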