Commit 59414b33 authored by rprenger

Made some simplifications that keep it from hanging

parent 7a9c4a03
+5 −3
@@ -41,9 +41,11 @@ class MegatronGenerate(Resource):

        max_len = 64  # Choosing hopefully sane default.  Full sequence is slow
        if "max_len" in request.get_json():
-            input_max_len = request.get_json()["max_len"]
-            if input_max_len < args.seq_length:
-                max_len = input_max_len
+            max_len = request.get_json()["max_len"]
+            if not isinstance(max_len, int):
+                return "max_len must be an integer greater than 0"
+            if max_len < 1:
+                return "max_len must be an integer greater than 0"

        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
        resp_sentences = generate(self.model, sentences, max_len) 
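
Note: with this change the request's max_len is taken directly as the number of new tokens to generate, and non-integer or non-positive values are rejected before the other ranks are signalled. A minimal client sketch follows; the endpoint path, port, and HTTP verb are assumptions, only the JSON fields "sentences" and "max_len" come from the handler above.

    # Hedged client sketch -- URL, port, and verb are assumed, not part of this commit.
    import requests

    resp = requests.put(
        "http://localhost:5000/generate",                       # assumed endpoint
        json={"sentences": ["Megatron-LM is"], "max_len": 32},  # max_len must be an int >= 1
    )
    print(resp.text)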
+20 −27
@@ -104,21 +104,6 @@ def tokenize_batch(sentences):
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
    return context_tokens_tensor, context_length_tensor 

-def get_token_stream(model, context_tokens_tensor, context_length_tensor):
-    context_length = context_length_tensor.min().item()
-    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
-
-    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
-                                                 context_length_tensor,
-                                                 attention_mask, position_ids)
-    for tokens, lengths in batch_token_iterator:
-        context_length += 1
-        if tokens is not None:
-            yield tokens[:, :context_length], lengths
-        else:
-            yield None, None
-
-
def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
    """
    Needs to be synced up with receive_generate_info
@@ -151,13 +136,19 @@ def receive_generate_info():
    
    return context_length_tensor, context_tokens_tensor, max_len

-def synced_generate(model, context_length_tensor, context_tokens_tensor, max_len):
-    token_stream = get_token_stream(model, context_tokens_tensor, context_length_tensor)
-    for i, decode_tokens in enumerate(token_stream):
-        if i == max_len-1:
-            break
-        pass
-    return decode_tokens
+def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len):
+    context_length = context_length_tensor.min().item()
+    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
+
+    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
+                                                 context_length_tensor,
+                                                 attention_mask, position_ids,
+                                                 max_len)
+    for tokens, lengths in batch_token_iterator:
+        context_length += 1
+
+    if tokens is not None:
+        return tokens[:, :context_length]

def generate(model, sentences=None, max_len=0):
    if torch.distributed.get_rank() == 0:
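
Note: get_token_stream is removed (first hunk above) and its body now lives directly in the rewritten synced_generate, which passes max_len into sample_sequence_batch and drains the iterator completely instead of breaking out at i == max_len-1. An early break can leave ranks stopped at different steps mid-generation, which is a plausible source of the hang the commit message mentions. A minimal sketch of the drain-fully pattern, not the actual Megatron code:

    # Sketch only: let the iterator decide when to stop and consume it fully,
    # so every rank executes the same number of forward steps.
    def drain(step_iterator, start_length):
        tokens = None
        length = start_length
        for tokens, _lengths in step_iterator:  # no early break
            length += 1                         # one new token per step
        return None if tokens is None else tokens[:, :length]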
@@ -169,12 +160,11 @@ def generate(model, sentences=None, max_len=0):
    else:
        context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
    
-    decode_tokens = synced_generate(model, context_length_tensor, context_tokens_tensor, max_len)
+    decode_tokens = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
    
    if torch.distributed.get_rank() == 0:
        args = get_args()
        tokenizer = get_tokenizer()
-        decode_tokens, _ = decode_tokens
        resp_sentences = []
        for i in range(decode_tokens.size(0)):
            decode_token = decode_tokens[i,:].cpu().numpy().tolist()
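
Note: the call in generate() now passes the tensors in the order the new synced_generate expects (tokens first, then lengths), and the tuple unpack is dropped because synced_generate now returns the token tensor directly rather than a (tokens, lengths) pair. The detokenization that follows on rank 0 is not part of this hunk; a hedged sketch of how the decode loop typically finishes with a Megatron-style tokenizer (the detokenize method name is an assumption, not shown in this diff):

    # Hedged sketch of the rank-0 decode loop continuation; method name assumed.
    for i in range(decode_tokens.size(0)):
        ids = decode_tokens[i, :].cpu().numpy().tolist()
        resp_sentences.append(tokenizer.detokenize(ids))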
@@ -248,6 +238,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
        
+        maxlen = maxlen + org_context_length
        
        if maxlen > (org_context_length + args.out_seq_length):
            maxlen = org_context_length + args.out_seq_length
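
Note: inside sample_sequence_batch the incoming maxlen is now treated as a budget of new tokens: it is shifted by the (minimum) prompt length in the batch and then capped at org_context_length + args.out_seq_length. A worked example with illustrative values (the out_seq_length value is assumed):

    # Illustrative numbers only.
    org_context_length = 10    # shortest prompt in the batch
    maxlen = 64                # max_len passed down from the request
    out_seq_length = 256       # stand-in for args.out_seq_length

    maxlen = maxlen + org_context_length                        # 74 total positions
    maxlen = min(maxlen, org_context_length + out_seq_length)   # cap at 266, so stays 74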