Commit 811183f0 authored by rprenger

Got it working on the full big model

parent ddd36145
+3 −2
@@ -54,12 +54,13 @@ class MegatronGenerate(Resource):
                return "all_probs must be a boolean value"

        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits = generate(self.model, sentences, max_len, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, max_len, all_probs)
        if all_probs:
            return jsonify({"sentences": resp_sentences,
                "segments": resp_sentences_seg,
                "logits": output_logits,
                "all_logits": full_logits})
                "all_logits": full_logits,
                "tokens": tokens})
        
        return jsonify({"sentences": resp_sentences,
            "segments": resp_sentences_seg,
+8 −6
@@ -121,7 +121,7 @@ def receive_generate_info():
    """
    Needs to be synced up with send_generate_info
    """
-    input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.device("cuda"))
+    input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.device("cuda"))
    torch.distributed.broadcast(input_info_tensor, 0)
    batch_size = input_info_tensor[0].item()
    seq_len = input_info_tensor[1].item()
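
torch.distributed.broadcast fills a pre-allocated buffer, so the tensor must have the same number of elements on every rank; growing this one from 3 to 4 only works if the matching send_generate_info packs four int64 values on rank 0. A paired sketch, assuming the extra slots carry max_len and the all_probs flag (only batch_size and seq_len are visible above):

    import torch

    def send_generate_info(batch_size, seq_len, max_len, all_probs):
        # Rank 0 packs four int64 values; the field order past seq_len is assumed.
        info = torch.tensor([batch_size, seq_len, max_len, int(all_probs)],
                            dtype=torch.int64, device="cuda")
        torch.distributed.broadcast(info, 0)

    def receive_generate_info():
        # Every other rank allocates the same element count before the broadcast.
        info = torch.empty(4, dtype=torch.int64, device="cuda")
        torch.distributed.broadcast(info, 0)
        return tuple(v.item() for v in info)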
@@ -166,9 +166,10 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
            torch.distributed.broadcast(output_logits, src, group)
            
            if all_probs:
+                args = get_args()
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
-                full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size(), dtype=torch.float32, device=torch.device("cuda"))
+                full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
                torch.distributed.broadcast(full_logits, src, group)
     
    if tokens is not None:
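
Two fixes land in this hunk: args was never bound in this scope, so the removed line raised a NameError before the broadcast could run, and padded_vocab_size is an integer attribute on Megatron's parsed arguments, not a method, so calling it raised a TypeError. In isolation (a sketch that assumes an initialized Megatron process):

    from megatron import get_args

    args = get_args()                # previously missing: "args" was unbound here
    vocab = args.padded_vocab_size   # int attribute; args.padded_vocab_size()
                                     # raises TypeError: 'int' object is not callable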
@@ -193,8 +194,9 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
        tokenizer = get_tokenizer()
        resp_sentences = []
        resp_sentences_seg = []
-        for i in range(decode_tokens.size(0)):
-            decode_token = decode_tokens[i,:].cpu().numpy().tolist()
+
+        decode_tokens = decode_tokens.cpu().numpy().tolist()
+        for decode_token in decode_tokens:
            resp_sentences.append(tokenizer.detokenize(decode_token))
            words = []
            for token in decode_token:
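
The rewritten loop hoists the device-to-host copy out of the per-row loop, converting the whole batch once, and leaves decode_tokens as a plain list of lists, which the new return statement and the len(decode_tokens[0]) below depend on. Before/after with stand-in shapes:

    import torch

    decode_tokens = torch.randint(0, 50257, (4, 16))  # stand-in [batch, seq] token ids

    # Before: one device-to-host transfer per row.
    rows = [decode_tokens[i, :].cpu().numpy().tolist()
            for i in range(decode_tokens.size(0))]

    # After: a single transfer for the whole batch.
    decode_tokens = decode_tokens.cpu().numpy().tolist()
    assert decode_tokens == rows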
@@ -208,8 +210,8 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
            full_logits = full_logits.cpu().numpy().tolist()

        end = time.time()
-        print(str(b)+","+str(c)+","+str(decode_tokens.size(1))+","+str(end-start), flush=True)
-        return resp_sentences, resp_sentences_seg, output_logits, full_logits
+        print(str(b)+","+str(c)+","+str(len(decode_tokens[0]))+","+str(end-start), flush=True)
+        return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens

def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
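
Since generate() now returns five values, any caller that unpacked the old four-tuple must be updated to match; the handler change in the first file above is the corresponding server-side update:

    resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = \
        generate(model, sentences, max_len, all_probs)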