Commit 230633f8 authored by Ryan Prenger

Merge branch 'debug_harness' into 'main'

Change the API to fix LM Harness evaluation and make the API less confusing/terrible

See merge request ADLR/megatron-lm!326
parents c81bf127 593b47b4
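For context, the request body that the /generate endpoint expects after this change looks roughly like the sketch below. The values are illustrative; only the field names (sentences, tokens_to_generate, all_probs) and the 128-sentence limit come from the handler in the diff.

# Shape of the new request body (illustrative values; field names taken from the diff).
payload = {
    "sentences": ["Megatron-LM is"],   # at most 128 prompts per request
    "tokens_to_generate": 32,          # replaces the old, misleading "max_len" field
    "all_probs": False,                # when True, the response also carries full logits
}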
+16 −15
@@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import datetime
import torch
import json
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api

from megatron import get_args
from megatron import mpu
from megatron.text_generation_utils import generate
@@ -35,17 +36,20 @@ class MegatronGenerate(Resource):
     
    def put(self):
        args = get_args()
        print("request IP: " + str(request.remote_addr))
        print(json.dumps(request.get_json()),flush=True)
        print("current time: ", datetime.datetime.now())
        sentences = request.get_json()["sentences"]
        if len(sentences) > 128:
            return "Maximum number of sentences is 128", 400

-        max_len = 64  # Choosing hopefully sane default.  Full sequence is slow
-        if "max_len" in request.get_json():
-            max_len = request.get_json()["max_len"]
-            if not isinstance(max_len, int):
-                return "max_len must be an integer greater than 0"
-            if max_len < 1:
-                return "max_len must be an integer greater than 0"
+        tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
+        if "tokens_to_generate" in request.get_json():
+            tokens_to_generate = request.get_json()["tokens_to_generate"]
+            if not isinstance(tokens_to_generate, int):
+                return "tokens_to_generate must be an integer greater than 0"
+            if tokens_to_generate < 1:
+                return "tokens_to_generate must be an integer greater than 0"

        all_probs = False
        if "all_probs" in request.get_json():
@@ -54,7 +58,7 @@ class MegatronGenerate(Resource):
                return "all_probs must be a boolean value"

        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, max_len, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs)
        if all_probs:
            return jsonify({"sentences": resp_sentences,
                "segments": resp_sentences_seg,
@@ -66,15 +70,12 @@ class MegatronGenerate(Resource):
            "segments": resp_sentences_seg,
            "logits": output_logits})

-def index():
-    return current_app.send_static_file('index.html')

class MegatronServer(object):
    def __init__(self, model):
-        self.app = Flask(__name__)
-        self.app.add_url_rule('/', 'index', index)
+        self.app = Flask(__name__, static_folder='static', static_url_path='')
+        self.app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0
        api = Api(self.app)
        api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])

    def run(self, url):
-        self.app.run(url, threaded=False, debug=False)
+        self.app.run(url, threaded=True, debug=False)
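Put together, the server-side changes mean a client such as the LM harness wrapper talks to /generate with a PUT request. A minimal sketch, assuming the requests package and that the server sits on localhost with Flask's default port 5000 (the real address is whatever url MegatronServer.run() was given); the request fields and response keys come from the handler above:

import requests

URL = "http://localhost:5000/generate"   # assumed address

resp = requests.put(URL, json={
    "sentences": ["Megatron-LM is"],
    "tokens_to_generate": 32,
    "all_probs": False,
})
out = resp.json()
print(out["sentences"])   # completed prompts; "segments" and "logits" are returned alongside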
+25 −28
Original line number Diff line number Diff line
@@ -105,12 +105,12 @@ def tokenize_batch(sentences, max_len):
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
    return context_tokens_tensor, context_length_tensor 

-def send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
    """
    Needs to be synced up with receive_generate_info
    """
    # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len, all_probs]
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs]
    input_info_tensor = torch.cuda.LongTensor(input_info)
    torch.distributed.broadcast(input_info_tensor, 0)

@@ -126,7 +126,7 @@ def receive_generate_info():
    torch.distributed.broadcast(input_info_tensor, 0)
    batch_size = input_info_tensor[0].item()
    seq_len = input_info_tensor[1].item()
-    max_len = input_info_tensor[2].item()
+    tokens_to_generate = input_info_tensor[2].item()
    all_probs = input_info_tensor[3].item()
    
    context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
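The docstring's note that send_generate_info "needs to be synced up with receive_generate_info" is the real contract being renamed here: both sides pack and unpack the same four values, in the same order, through a single cuda LongTensor broadcast, which is also why the boolean all_probs comes back as an integer. A stripped-down sketch of that pairing (illustrative only; the actual functions are shown in the hunks above):

# Rank 0 packs sizes and settings into one LongTensor and broadcasts it ...
info = torch.cuda.LongTensor([batch_size, seq_len, tokens_to_generate, int(all_probs)])
torch.distributed.broadcast(info, 0)

# ... and every other rank unpacks the same positions in the same order.
batch_size, seq_len = info[0].item(), info[1].item()
tokens_to_generate = info[2].item()
all_probs = bool(info[3].item())   # arrives as 0/1 because it rode in a LongTensor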
@@ -136,16 +136,16 @@ def receive_generate_info():
    torch.distributed.broadcast(context_length_tensor, 0)
    torch.distributed.broadcast(context_tokens_tensor, 0)
    
-    return context_length_tensor, context_tokens_tensor, max_len, all_probs
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs

-def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids,
-                                                 max_len,
+                                                 tokens_to_generate,
                                                 all_probs)
    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
        context_length += 1
@@ -176,19 +176,19 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits 

-def generate(model, sentences=None, max_len=0, all_probs=False):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False):
    model.eval()
    if torch.distributed.get_rank() == 0:
-        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, max_len)
-        send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
+        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
    else:
-        context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
    
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
    
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs)
    if output is not None:
        decode_tokens, output_logits, full_logits = output
        
    if torch.distributed.get_rank() == 0:
        args = get_args()
        tokenizer = get_tokenizer()
        resp_sentences = []
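Because the non-zero ranks overwrite tokens_to_generate and all_probs with whatever receive_generate_info() hands them, every rank still has to enter generate() together; only rank 0 supplies real inputs. A hedged sketch of that calling pattern (the helper and its name are made up; the generate() signature is the one in the diff):

import torch

def run_generation_step(model, sentences, tokens_to_generate=64):
    # Hypothetical helper: every rank calls generate(); only rank 0 supplies inputs.
    if torch.distributed.get_rank() == 0:
        # Rank 0 tokenizes the prompts and broadcasts lengths/settings to the others.
        return generate(model, sentences, tokens_to_generate, all_probs=False)
    # Non-zero ranks pass nothing; generate() fills everything in via receive_generate_info().
    return generate(model)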
@@ -215,11 +215,12 @@ def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    This function is here to provide a matching API for a legacy task
    This implementation hasn't been tested yet to make sure it matches
    """
-    assert False, "Implementation untested"
+    #assert False, "Implementation untested"
    args = get_args()
    args.eos_id = eos_token_id
    raw_text_len = len(context)
    resp_sentences = generate(model, [context], max_gen_length)
    if resp_sentences:
        return resp_sentences[0][raw_text_len:]
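Commenting out the assert is what lets the LM harness path reach this wrapper at all; per its own docstring it is still untested, so the following is only a sketch of how a harness-style caller would be expected to use it (the prompt and length are made up; tokenizer stands for the result of get_tokenizer(), as used elsewhere in this file):

# Hypothetical call in the style of the legacy evaluation API.
prompt = "Question: What is the capital of France?\nAnswer:"
continuation = generate_samples_eval(model, prompt, 16, eos_token_id=tokenizer.eod)
# raw_text_len is taken from the prompt string, so the return value is meant to be
# only the newly generated text, with the prompt characters stripped off the front.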

def switch(val1, val2, boolean):
@@ -263,7 +264,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,

def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
-                          maxlen=None, all_probs=False, type_ids=None):
+                          tokens_to_generate, all_probs=False, type_ids=None):
    args = get_args()
    tokenizer = get_tokenizer()

@@ -279,7 +280,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
            eos_id = tokenizer.eod

        counter = 0
-        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
@@ -287,13 +287,11 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
        tokens = context_tokens
        output_logits = None
       
-        if maxlen is None:
-            maxlen = args.seq_length - 1
        
-        maxlen = maxlen + org_context_length
+        # Generate enough tokens for the longest sequence
+        maxlen = tokens_to_generate + context_lengths.max().item()
       
-        if maxlen > (org_context_length + args.out_seq_length):
-            maxlen = org_context_length + args.out_seq_length
+        if maxlen > args.seq_length:
+            maxlen = args.seq_length
        
        lengths = torch.ones([batch_size]).long().cuda() * maxlen
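A quick sanity check of the new cap, with made-up numbers: the generation budget now scales with the longest prompt in the batch instead of args.out_seq_length, and is clamped only by the model's sequence length.

# Illustrative recomputation of the maxlen logic above (all numbers are made up).
tokens_to_generate = 64
context_lengths = [12, 30]      # per-sample prompt lengths, in tokens
seq_length = 2048               # stands in for args.seq_length

maxlen = tokens_to_generate + max(context_lengths)   # 64 + 30 = 94
maxlen = min(maxlen, seq_length)                     # never exceed the model's sequence length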

@@ -357,7 +355,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                    if all_probs:
                        full_logits = torch.cat([full_logits, output_context], 1)
                
-                #output_logits = torch.cat([output_logits, output[:,context_length,new_tokens]], 1)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)