Commit 386923b5 authored by Ryan Prenger's avatar Ryan Prenger

Merge branch 'fix_temp' into 'main'

Fixing bug where temperature was never actually broadcast

See merge request ADLR/megatron-lm!330
parents 5ac5571b 02c00ce6
+4 −3
@@ -61,9 +61,10 @@ class MegatronGenerate(Resource):
         temperature = args.temperature
         if "temperature" in request.get_json():
             temperature = request.get_json()["temperature"]
-            if not isinstance(temperature, float) or not \
-               0.0 < temperature <= 100.0:
-                return "temperature must be a positive float less than or equal to 100.0"
+            if not (type(temperature) == int or type(temperature) == float):
+                return "temperature must be a positive number less than or equal to 100.0"
+            if not (0.0 < temperature <= 100.0):
+                return "temperature must be a positive number less than or equal to 100.0"
 
         add_BOS = False
         if "add_BOS" in request.get_json():
+12 −11
@@ -108,13 +108,13 @@ def tokenize_batch(sentences, max_len, add_BOS):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
 
-def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs]
-    input_info_tensor = torch.cuda.LongTensor(input_info)
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs, temperature]
+    input_info_tensor = torch.cuda.FloatTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)
 
     # Send variables to all ranks
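
The dtype switch from LongTensor to FloatTensor is what lets temperature ride along: appending a float to an int64 tensor truncates it toward zero. A minimal CPU illustration with hypothetical values (the real code uses torch.cuda tensors):

import torch

vals = [4, 128, 64, 0, 0.7]  # batch, seq_len, tokens_to_generate, all_probs, temperature
as_long = torch.tensor(vals, dtype=torch.int64)
as_float = torch.tensor(vals, dtype=torch.float32)
print(as_long[4].item())   # 0    -- the temperature is destroyed
print(as_float[4].item())  # ~0.7 -- survives; small ints are exact in float32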
@@ -125,12 +125,13 @@ def receive_generate_info():
     """
     Needs to be synced up with send_generate_info
     """
-    input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(5, dtype=torch.float32, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, 0)
-    batch_size = input_info_tensor[0].item()
-    seq_len = input_info_tensor[1].item()
-    tokens_to_generate = input_info_tensor[2].item()
-    all_probs = input_info_tensor[3].item()
+    batch_size = int(input_info_tensor[0].item())
+    seq_len = int(input_info_tensor[1].item())
+    tokens_to_generate = int(input_info_tensor[2].item())
+    all_probs = int(input_info_tensor[3].item())
+    temperature = float(input_info_tensor[4].item())
 
     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
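
On the receiving side every field of the float32 tensor comes back from .item() as a Python float, so the new int() casts are not cosmetic: the integer metadata is used as tensor shapes just below, and shapes must be ints. A small sketch of why, with the same hypothetical values:

import torch

info = torch.tensor([4, 128, 64, 0, 0.7], dtype=torch.float32)
batch_size = int(info[0].item())     # 4 as an int, usable as a shape below
temperature = float(info[4].item())  # ~0.7
tokens = torch.empty(batch_size, dtype=torch.int64)  # torch.empty(4.0) would raise a TypeError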
@@ -139,7 +140,7 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)
 
-    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature
 
 def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     context_length = context_length_tensor.min().item()
@@ -182,9 +183,9 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
-        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
     else:
-        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature = receive_generate_info()
 
     output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
     if output is not None:
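
Taken together, the fix makes the client-requested temperature reach every rank: before it, receive_generate_info returned only four values, so non-zero ranks ran synced_generate with the default temperature argument of generate() rather than the one sent by the client. A self-contained, single-process sketch of the broadcast round trip (gloo backend and CPU tensors are stand-ins for the real multi-rank CUDA setup):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# Rank 0 packs four integer fields and the float temperature together.
info = torch.tensor([4, 128, 64, 0, 0.7], dtype=torch.float32)
dist.broadcast(info, 0)  # every rank ends up with the same values

# Each rank unpacks, casting back to the original types.
batch_size = int(info[0].item())
temperature = float(info[4].item())
assert batch_size == 4
assert abs(temperature - 0.7) < 1e-6  # float32 round-off, not exactly 0.7

dist.destroy_process_group()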