Commit d46aa964 authored by Jared Casper

Merge branch 'api_change' into 'main'

API improvements.

See merge request ADLR/megatron-lm!337
parents b31e1296 0694205c
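The change renames the request and response fields of the text generation API and moves the endpoint from /generate to /api. A minimal client sketch against the new interface follows, assuming the `requests` package and a server reachable at a hypothetical http://localhost:5000; the request keys (prompts, tokens_to_generate, logprobs, temperature, top_k, top_p, add_BOS) and the response keys (text, segments, logprobs) are the ones introduced in this merge request, while the host, port, and sample values are illustrative only.

import json
import requests

# Hypothetical address; the diff does not fix a host or port.
URL = "http://localhost:5000/api"

payload = {
    "prompts": ["The quick brown fox"],  # replaces the old "sentences" key
    "tokens_to_generate": 32,            # replaces the old "max_len" key
    "logprobs": True,                    # replaces the old "all_probs" flag
    "temperature": 0.9,
    "top_k": 40,                         # top_p may be supplied instead (0 < top_p <= 1.0)
    "add_BOS": False,
}

# The bundled client script uses PUT, so the same verb is used here.
resp = requests.put(URL, data=json.dumps(payload),
                    headers={"Content-Type": "application/json"})
result = resp.json()
print(result["text"][0])      # generated continuation (was "sentences")
print(result["logprobs"][0])  # per-token log-probabilities (was "logits")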
+45 −20
@@ -40,9 +40,18 @@ class MegatronGenerate(Resource):
         print(json.dumps(request.get_json()),flush=True)
         print("current time: ", datetime.datetime.now())
        
-        sentences = request.get_json()["sentences"]
-        if len(sentences) > 128:
-            return "Maximum number of sentences is 128", 400
+        if not "prompts" in request.get_json():
+            return "prompts argument required", 400
+        
+        if "max_len" in request.get_json():
+            return "max_len is no longer used.  Replace with tokens_to_generate", 400
+        
+        if "sentences" in request.get_json():
+            return "sentences is no longer used.  Replace with prompts", 400
+
+        prompts = request.get_json()["prompts"]
+        if len(prompts) > 128:
+            return "Maximum number of prompts is 128", 400
 
         tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
         if "tokens_to_generate" in request.get_json():
@@ -52,11 +61,11 @@ class MegatronGenerate(Resource):
             if tokens_to_generate < 1:
                 return "tokens_to_generate must be an integer greater than 0"
 
-        all_probs = False
-        if "all_probs" in request.get_json():
-            all_probs = request.get_json()["all_probs"]
-            if not isinstance(all_probs, bool):
-                return "all_probs must be a boolean value"
+        logprobs = False
+        if "logprobs" in request.get_json():
+            logprobs = request.get_json()["logprobs"]
+            if not isinstance(logprobs, bool):
+                return "logprobs must be a boolean value"
         
         temperature = args.temperature
         if "temperature" in request.get_json():
@@ -66,6 +75,22 @@ class MegatronGenerate(Resource):
             if not (0.0 < temperature <= 100.0):
                 return "temperature must be a positive number less than or equal to 100.0"
         
+        top_k = args.top_k
+        if "top_k" in request.get_json():
+            top_k = request.get_json()["top_k"]
+            if not (type(top_k) == int):
+                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
+            if not (0 < top_k <= 1000):
+                return "top_k must be equal to or greater than 0 and less than or equal to 1000"
+        
+        top_p = args.top_p
+        if "top_p" in request.get_json():
+            top_p = request.get_json()["top_p"]
+            if not (type(top_p) == float):
+                return "top_p must be a positive float less than or equal to 1.0"
+            if not (0 < top_p <= 1.0):
+                return "top_p must be less than or equal to 1.0"
+        
         add_BOS = False
         if "add_BOS" in request.get_json():
             add_BOS = request.get_json()["add_BOS"]
@@ -74,24 +99,24 @@ class MegatronGenerate(Resource):

         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS) 
-        
-        if all_probs:
-            return jsonify({"sentences": resp_sentences,
-                "segments": resp_sentences_seg,
-                "logits": output_logits,
-                "all_logits": full_logits,
-                "tokens": tokens})
+            response, response_seg, response_logprobs = generate(self.model,
+                                                                 prompts,
+                                                                 tokens_to_generate,
+                                                                 logprobs,
+                                                                 temperature,
+                                                                 top_k,
+                                                                 top_p,
+                                                                 add_BOS) 
         
-        return jsonify({"sentences": resp_sentences,
-            "segments": resp_sentences_seg,
-            "logits": output_logits})
+        return jsonify({"text": response,
+            "segments": response_seg,
+            "logprobs": response_logprobs})
 
 class MegatronServer(object):
     def __init__(self, model):
         self.app = Flask(__name__, static_url_path='')
         api = Api(self.app)
-        api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
+        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
         
     def run(self, url): 
         self.app.run(url, threaded=True, debug=False)
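Aside from the route rename to /api, the server wrapper is unchanged. For orientation, it is used roughly as follows (a hedged sketch: constructing and loading the model happens elsewhere and is outside this diff).

# Hypothetical usage; `model` is an already-initialized Megatron model.
server = MegatronServer(model)
server.run("0.0.0.0")  # Flask serves the MegatronGenerate resource on its default port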
+68 −68
@@ -108,12 +108,12 @@ def tokenize_batch(sentences, max_len, add_BOS):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor 
 
-def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs, temperature]
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, logprobs, temperature, top_k, top_p]
     input_info_tensor = torch.cuda.FloatTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)
 
@@ -125,13 +125,15 @@ def receive_generate_info():
    """
    Needs to be synced up with send_generate_info
    """
    input_info_tensor = torch.empty(5, dtype=torch.float32, device=torch.cuda.current_device())
    input_info_tensor = torch.empty(7, dtype=torch.float32, device=torch.cuda.current_device())
    torch.distributed.broadcast(input_info_tensor, 0)
    batch_size = int(input_info_tensor[0].item())
    seq_len = int(input_info_tensor[1].item())
    tokens_to_generate = int(input_info_tensor[2].item())
    all_probs = int(input_info_tensor[3].item())
    logprobs = bool(input_info_tensor[3].item())
    temperature = float(input_info_tensor[4].item())
    top_k = int(input_info_tensor[5].item())
    top_p = float(input_info_tensor[6].item())
    
    context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
    context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
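The widened broadcast above packs every scalar generation parameter into a single float tensor so that one collective moves them from rank 0 to all other ranks; the integers and the logprobs flag survive the float round-trip because each field is cast back on receipt. A standalone sketch of that pack/unpack pattern (illustrative only, not Megatron code; the real tensors live on CUDA and are moved with torch.distributed.broadcast):

import torch

def pack_info(batch_size, seq_len, tokens_to_generate, logprobs, temperature, top_k, top_p):
    # Everything is stored as float32 so a single tensor (and a single broadcast) suffices.
    return torch.tensor([batch_size, seq_len, tokens_to_generate,
                         logprobs, temperature, top_k, top_p], dtype=torch.float32)

def unpack_info(t):
    # Cast each slot back to its intended type, mirroring receive_generate_info.
    return (int(t[0].item()), int(t[1].item()), int(t[2].item()),
            bool(t[3].item()), float(t[4].item()), int(t[5].item()), float(t[6].item()))

packed = pack_info(4, 128, 64, True, 0.9, 40, 0.0)
print(unpack_info(packed))  # (4, 128, 64, True, ~0.9, 40, 0.0)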
@@ -140,28 +142,31 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)
     
-    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p
 
-def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
-    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
+    batch_token_iterator = sample_sequence_batch(model,
+                                                 context_tokens_tensor,
                                                  context_length_tensor,
-                                                 attention_mask, position_ids,
+                                                 attention_mask,
+                                                 position_ids,
                                                  tokens_to_generate,
-                                                 all_probs,
-                                                 temperature=temperature)
-    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
+                                                 logprobs,
+                                                 temperature,
+                                                 top_k,
+                                                 top_p)
+
+    for tokens, lengths, output_logits in batch_token_iterator:
         context_length += 1
    
+
+    if logprobs:
         if mpu.is_pipeline_last_stage():
             src = mpu.get_pipeline_model_parallel_last_rank()
             group = mpu.get_embedding_group()
             torch.distributed.broadcast(output_logits, src, group)
-            if all_probs:
-                src = mpu.get_pipeline_model_parallel_last_rank()
-                group = mpu.get_embedding_group()
-                torch.distributed.broadcast(full_logits, src, group)
 
         else:
             if mpu.is_pipeline_first_stage():
@@ -170,26 +175,20 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
                 output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
                 torch.distributed.broadcast(output_logits, src, group)
             
-            if all_probs:
-                args = get_args()
-                src = mpu.get_pipeline_model_parallel_last_rank()
-                group = mpu.get_embedding_group()
-                full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
-                torch.distributed.broadcast(full_logits, src, group)
     if tokens is not None:
-        return tokens[:, :context_length], output_logits, full_logits 
+        return tokens[:, :context_length], output_logits 
 
-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
+def generate(model, sentences=None, tokens_to_generate=0, logprobs=False, temperature=1.0, top_k=0, top_p=0.0, add_BOS=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
-        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
     else:
-        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p = receive_generate_info()
 
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
     if output is not None:
-        decode_tokens, output_logits, full_logits = output
+        decode_tokens, output_logits = output
         
         args = get_args()
         tokenizer = get_tokenizer()
@@ -197,7 +196,8 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
         resp_sentences_seg = []
         
         decode_tokens = decode_tokens.cpu().numpy().tolist()
-        for decode_token in decode_tokens:
+        
+        for i, decode_token in enumerate(decode_tokens):
             resp_sentences.append(tokenizer.detokenize(decode_token))
             words = []
             for token in decode_token:
@@ -206,11 +206,9 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
                 words.append(word)
             resp_sentences_seg.append(words)
         
+        if logprobs:
             output_logits = output_logits.cpu().numpy().tolist()
-        if all_probs:
-            full_logits = full_logits.cpu().numpy().tolist()
-       
-        return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens 
+        return resp_sentences, resp_sentences_seg, output_logits
 
 def generate_samples_eval(model, context, max_gen_length, eos_token_id):
     """
@@ -260,9 +258,17 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
     return output_tensor
 
 
-def sample_sequence_batch(model, context_tokens, context_lengths,
-                          attention_mask, position_ids,
-                          tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
+def sample_sequence_batch(model,
+                          context_tokens,
+                          context_lengths,
+                          attention_mask,
+                          position_ids,
+                          tokens_to_generate,
+                          logprobs,
+                          temperature,
+                          top_k,
+                          top_p,
+                          type_ids=None):
     args = get_args()
     tokenizer = get_tokenizer()
 
@@ -330,8 +336,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 else:
                     logits = logits.float()
                     logits /= temperature
-                    logits = top_k_logits(logits, top_k=args.top_k,
-                                          top_p=args.top_p)
+                    logits = top_k_logits(logits, top_k=top_k,
+                                          top_p=top_p)
                     log_probs = F.softmax(logits, dim=-1)
                     prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                 started = context_lengths <= context_length
@@ -344,12 +350,11 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens
                
+                if logprobs:
                     if output_logits is None:
                         output_context = F.log_softmax(output[:, :context_length, :], 2)
                         indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
                         output_logits = torch.gather(output_context, 2, indices).squeeze(2)
-                    if all_probs:
-                        full_logits = output_context
                     else:
                         output_context = F.log_softmax(output, 2)
                         indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
@@ -357,8 +362,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                        
                         # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                         output_logits = torch.cat([output_logits, new_output_logits],1)
-                    if all_probs:
-                        full_logits = torch.cat([full_logits, output_context], 1)
                 
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_embedding_group()
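The logprobs branch above computes one log-probability per generated token by taking a log-softmax over the vocabulary and gathering the entry of the token that was actually emitted. A standalone sketch of that gather pattern (illustrative shapes, not Megatron code):

import torch
import torch.nn.functional as F

batch, steps, vocab = 2, 5, 11
logits = torch.randn(batch, steps, vocab)      # stand-in for the model output
tokens = torch.randint(vocab, (batch, steps))  # stand-in for the sampled token ids

log_probs = F.log_softmax(logits, dim=2)                        # (batch, steps, vocab)
token_log_probs = torch.gather(log_probs, 2,
                               tokens.unsqueeze(2)).squeeze(2)  # (batch, steps)
print(token_log_probs.shape)  # torch.Size([2, 5])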
@@ -373,10 +376,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_pipeline_model_parallel_group()
                 torch.distributed.broadcast(done, src, group)
-                if all_probs:
-                    yield tokens, lengths, output_logits, full_logits
-                else:
-                    yield tokens, lengths, output_logits, None
+                yield tokens, lengths, output_logits
 
             else:
                 if mpu.is_pipeline_first_stage():
@@ -385,9 +385,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     new_tokens = torch.empty_like(tokens[:, context_length])
                     torch.distributed.broadcast(new_tokens, src, group)
                     tokens[:, context_length] = new_tokens
-                    yield tokens, None, None, None
+                    yield tokens, None, None
                 else:
-                    yield None, None, None, None
+                    yield None, None, None
 
                 done = torch.cuda.ByteTensor([0])
                 src = mpu.get_pipeline_model_parallel_last_rank()
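With this change top_k and top_p are passed through per request instead of being read from the global args; top_k_logits itself is untouched. For reference, a standalone sketch of the usual top-k / nucleus (top-p) filtering such a function performs (illustrative only, not Megatron's implementation):

import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    if top_k > 0:
        # Keep only the k highest-scoring tokens.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_best, filter_value)
    if top_p > 0.0:
        # Drop tokens once the cumulative probability mass exceeds top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # always keep the best token
        remove[..., 0] = False
        mask = remove.scatter(-1, sorted_idx, remove)  # map back to vocabulary order
        logits = logits.masked_fill(mask, filter_value)
    return logits

probs = F.softmax(filter_logits(torch.randn(1, 16), top_k=5, top_p=0.9), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)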
+3 −3
@@ -25,10 +25,10 @@ if __name__ == "__main__":
     url = sys.argv[1]
     while True:
         sentence = raw_input("Enter prompt: ")
-        max_len = int(input("Enter number tokens output: "))
-        data = json.dumps({"sentences": [sentence], "max_len":max_len})
+        tokens_to_generate = int(input("Enter number of tokens to generate: "))
+        data = json.dumps({"prompts": [sentence], "tokens_to_generate":tokens_to_generate})
         req = PutRequest(url, data, {'Content-Type': 'application/json'})
         response = urllib2.urlopen(req)
         resp_sentences = json.load(response)
         print("Megatron Response: ")
-        print(resp_sentences["sentences"][0])
+        print(resp_sentences["text"][0])
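The interactive script above is still Python 2 (raw_input, urllib2). For completeness, an equivalent Python 3 sketch of the updated loop using only the standard library; the PUT verb and the field names come from the diff, everything else is illustrative:

import json
import sys
import urllib.request

url = sys.argv[1]
while True:
    sentence = input("Enter prompt: ")
    tokens_to_generate = int(input("Enter number of tokens to generate: "))
    data = json.dumps({"prompts": [sentence],
                       "tokens_to_generate": tokens_to_generate}).encode("utf-8")
    req = urllib.request.Request(url, data=data, method="PUT",
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as response:
        resp = json.load(response)
    print("Megatron Response: ")
    print(resp["text"][0])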