Commit b46482e8 authored by rprenger's avatar rprenger

Fixes a bug in broadcasting that was causing a hang: the control token is now broadcast from global rank 0 over the default process group instead of the tensor-model-parallel group.

parent 593b47b4
+1 −3
@@ -30,9 +30,7 @@ class MegatronGenerate(Resource):
    @staticmethod
    def send_do_generate():
        choice = torch.cuda.LongTensor([GENERATE_NUM])
-        torch.distributed.broadcast(choice,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.broadcast(choice, 0)
     
    def put(self):
        args = get_args()
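
For reference, torch.distributed.broadcast(tensor, src, group=None) blocks every caller until all ranks of the group (the default world group when omitted) have entered the call, and src is always a global rank. The one-line replacement above therefore involves every rank in the job, which is what removes the hang. A minimal sketch of those semantics, assuming init_process_group() has already run with one CUDA device per process; get_choice_from_rank0 and the GENERATE_NUM value are illustrative, not from the repo:

import torch
import torch.distributed as dist

GENERATE_NUM = 0  # illustrative stand-in for the server's constant

def get_choice_from_rank0():
    # Rank 0 seeds the buffer with the control token; every other rank
    # allocates an empty one-element buffer that the broadcast fills.
    if dist.get_rank() == 0:
        choice = torch.cuda.LongTensor([GENERATE_NUM])
    else:
        choice = torch.cuda.LongTensor(1)
    dist.broadcast(choice, 0)  # identical src and group on every rank
    return choice[0].item()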
+0 −4
@@ -141,7 +141,6 @@ def receive_generate_info():
def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
-
    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids,
@@ -172,7 +171,6 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
                group = mpu.get_embedding_group()
                full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
                torch.distributed.broadcast(full_logits, src, group)
-
    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits 

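The context lines in this hunk also show the safe form of a group-scoped collective: every rank in mpu.get_embedding_group() allocates an identically shaped full_logits buffer before all of them name the same src and group. A minimal sketch of that pattern, with a hypothetical helper name and arguments:

import torch
import torch.distributed as dist

def broadcast_full_logits(shape, src, group):
    # Mirrors the full_logits lines above: each rank in `group`
    # allocates the same buffer, then all enter the broadcast with
    # matching src/group arguments.
    full_logits = torch.empty(*shape, dtype=torch.float32,
                              device=torch.device("cuda"))
    dist.broadcast(full_logits, src, group=group)
    return full_logits
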
@@ -310,7 +308,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                if type_ids is not None:
                    types2use = type_ids[:, context_length - 1].view(
                        batch_size, -1)
-
            output, layer_past = forward_step(model, tokens2use,
                                              positions2use,
                                              attention_mask,
@@ -332,7 +329,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                          top_p=args.top_p)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)
-
                started = context_lengths <= context_length

                new_tokens = switch(
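
Aside: the sampling step in the hunk above draws the next token by multinomial sampling over a softmax distribution (despite the local name log_probs, F.softmax returns probabilities, not log-probabilities). A standalone sketch of just that step, with the preceding top-k/top-p filtering omitted:

import torch
import torch.nn.functional as F

def sample_next_token(logits):
    # logits: [batch_size, vocab_size] scores for the next position.
    probs = F.softmax(logits, dim=-1)  # probabilities, not log-probs
    return torch.multinomial(probs, num_samples=1).view(-1)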
+1 −3
@@ -78,8 +78,6 @@ if __name__ == "__main__":

    while True:
        choice = torch.cuda.LongTensor(1)
-        torch.distributed.broadcast(choice,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.broadcast(choice, 0)
        if choice[0].item() == 0:
            generate(model)
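
Putting the two halves together: rank 0 serves HTTP requests and calls send_do_generate(), while every other rank spins in the loop above. After this commit both sides issue the identical broadcast(choice, 0), so the ranks stay in lockstep. A hedged sketch of the pairing, with hypothetical function names and assuming the generate function itself opens with matching collectives:

import torch
import torch.distributed as dist

def server_signal_generate():
    # Rank 0, called from the HTTP handler (cf. send_do_generate above).
    choice = torch.cuda.LongTensor([0])  # 0 routes workers into generate
    dist.broadcast(choice, 0)

def worker_wait_and_generate(model, generate_fn):
    # Every other rank blocks here until rank 0 signals; both sides
    # execute the same collective, which is what prevents the hang.
    choice = torch.cuda.LongTensor(1)
    dist.broadcast(choice, 0)
    if choice[0].item() == 0:
        generate_fn(model)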