Commit 3d718bfc authored by rprenger's avatar rprenger
Browse files

Fixing merge conflict

parent 9939fb58
Loading
Loading
Loading
Loading
+1 −17
Original line number Diff line number Diff line
@@ -175,21 +175,11 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits 

<<<<<<< HEAD
def generate(model, sentences=None, max_len=0, all_probs=False):
    if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
        c = context_length_tensor[0]
        b = context_tokens_tensor.size(0)
        start = time.time()
        send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
=======
def generate(model, sentences=None, max_len=0):
    model.eval()
    if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
        send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
>>>>>>> server
        send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
    else:
        context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info()
    
@@ -206,7 +196,6 @@ def generate(model, sentences=None, max_len=0):
        decode_tokens = decode_tokens.cpu().numpy().tolist()
        for decode_token in decode_tokens:
            resp_sentences.append(tokenizer.detokenize(decode_token))
<<<<<<< HEAD
            words = []
            for token in decode_token:
                word = tokenizer.tokenizer.decoder[token]
@@ -218,12 +207,7 @@ def generate(model, sentences=None, max_len=0):
        if all_probs:
            full_logits = full_logits.cpu().numpy().tolist()

        end = time.time()
        print(str(b)+","+str(c)+","+str(len(decode_tokens[0]))+","+str(end-start), flush=True)
        return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens 
=======
        return resp_sentences
>>>>>>> server

def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """