Commit 6c40f892 authored by mshoeybi

working

parent 25f9c3f0
+56 −0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference API."""


import torch

from .communication import broadcast_float_list
from .generation import generate_tokens_probs_and_return_on_first_stage
from .tokenization import tokenize_prompts


def generate(model,
             prompts=None,
             tokens_to_generate=0,
             return_output_log_probs=False,
             return_all_log_probs=False,
             temperature=1.0):
    """TO DO ..."""

    # Make sure input params are available to all ranks.
    values = [tokens_to_generate, return_output_log_probs,
              return_all_log_probs, temperature]
    values_float_tensor = broadcast_float_list(4, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    return_all_log_probs = bool(values_float_tensor[2].item())
    temperature = values_float_tensor[3].item()

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcast to all ranks.
    if torch.distributed.get_rank() == 0:
        assert prompts is not None
    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate)

    # Main inference function.
    # Note that the outputs are available on the first stage.
    return generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
        return_all_log_probs=return_all_log_probs,
        temperature=temperature)
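
A minimal usage sketch of the new `generate` entry point (hypothetical driver code; the module path and prompts are assumptions, and it presumes Megatron's distributed state and the model are already initialized). Only rank 0 needs real prompts; the scalar parameters are broadcast to all ranks inside `generate`:

import torch

# Assumed module path for the file above; adjust to the actual layout.
from megatron.text_generation.api import generate


def run_generation(model):
    # Only rank 0 supplies prompts; other ranks pass None and receive
    # the tokenized prompts via the broadcasts inside generate().
    prompts = None
    if torch.distributed.get_rank() == 0:
        prompts = ['Deep learning is', 'The capital of France is']
    tokens, lengths, output_log_probs, all_log_probs = generate(
        model,
        prompts=prompts,
        tokens_to_generate=32,
        return_output_log_probs=True,
        temperature=0.9)
    # Per the function's contract, outputs are available on the
    # first pipeline stage.
    if torch.distributed.get_rank() == 0:
        print(tokens.shape, lengths, output_log_probs.shape)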
+39 −8
@@ -40,6 +40,33 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    return tensor



def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from last stage into the first stage."""

    # Only first and last stage pipeline stages need to be involved.
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    if is_last_stage or is_first_stage:
        if is_last_stage:
            assert tensor is not None
            assert tensor.is_cuda
            assert tensor.is_contiguous()
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
                                 device=torch.cuda.current_device())
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor, src, group)
    else:
        tensor = None

    return tensor



def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place."""
@@ -48,11 +75,15 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    if is_last_stage or is_first_stage:
        assert tensor is not None
        assert tensor.is_cuda
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        if is_contiguous:
            tensor_ = tensor
        else:
            if is_last_stage:
                tensor_ = tensor.contiguous()
            else:
                tensor_ = torch.empty(size,
@@ -61,7 +92,7 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor_, src, group)
        # Update the first stage tensor
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_


+104 −33
@@ -19,19 +19,44 @@
import torch
import torch.nn.functional as F

from megatron import get_args, get_tokenizer, mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage)
from .forward_step import forward_step
from .sampling import sample


def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
        return_all_log_probs=False,
        temperature=1.0):
    """Main token generation function."""
    """Main token generation function.
    Arguments:
        model: XXX
        tokens: prompt tokens extended to be of size [b, max-sequence-length]
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            after logits are modifed for sampling.
        return_all_log_probs: flag to calculate the log probability of across
            all the tokens (vocab size). Note that the log probability is the
            one after logits are modifed for sampling.
        temperature: sampling temperature.
    Note: Outside of model, other parameters only need to be available on
          rank 0.
    Outputs: Note that is size is adjusted to a lower value than
             max-sequence-length if generation is terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
        all_log_probs: log probability of all the tokens.
            size: [b, s, vocab-size]
    """

    args = get_args()
    tokenizer = get_tokenizer()
@@ -52,11 +77,24 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Log probability of all tokens for the sequence.
    all_log_probs = None
    all_log_probs_size = (batch_size, max_sequence_length - 1,
                          args.padded_vocab_size)
    # Lengths of generated sequences including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
        if return_all_log_probs:
            all_log_probs = torch.empty(all_log_probs_size,
                                        dtype=torch.float32,
                                        device=torch.cuda.current_device())
        generated_sequence_lengths = torch.ones(
            batch_size, dtype=torch.int64,
            device=torch.cuda.current_device()) * max_sequence_length
@@ -64,6 +102,10 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    # =============
    # Run inference
    # =============

    attention_mask, position_ids = _build_attention_mask_and_position_ids(
        tokens)

@@ -114,14 +156,24 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
                if return_output_log_probs or return_all_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
                    if return_all_log_probs:
                        all_log_probs[:,
                                      prev_context_length:context_length,
                                      :] = log_probs
                    if return_output_log_probs:
                        # Pick the tokens that we need to get the log
                        # probabilities for. Note that the next input token
                        # is the token which we selected in the current
                        # logits, so shift by 1.
                        indices = torch.unsqueeze(
                            tokens[
                                :,
                                (prev_context_length + 1):(context_length + 1)],
                            2)
                        output_log_probs[:,
                                         prev_context_length:context_length] = \
                            torch.gather(log_probs, 2, indices).squeeze(2)

            # Update the tokens on the first stage so the next input to
@@ -147,17 +199,36 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
            if done:
                break

    # ===================================================
    # Update the lengths based on the max generated length.
    # ===================================================

    tokens = tokens[:, :(context_length + 1)]
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]
        if return_all_log_probs:
            all_log_probs = all_log_probs[:, :context_length, :]

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================

    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
        batch_size, torch.int64, generated_sequence_lengths)
    if return_output_log_probs:
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)
    if return_all_log_probs:
        all_log_probs_size = (batch_size, context_length,
                              args.padded_vocab_size)
        all_log_probs = broadcast_from_last_to_first_pipeline_stage(
            all_log_probs_size, torch.float32, all_log_probs)

    return tokens, generated_sequence_lengths, output_log_probs, \
        all_log_probs


def _build_attention_mask_and_position_ids(tokens):
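
The shift-by-one indexing in the log-probability block of `generate_tokens_probs_and_return_on_first_stage` is subtle, so here is a self-contained sketch with made-up sizes: the logits at position i score the token at position i + 1, so each sampled token's log probability is gathered from the preceding position's distribution:

import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 5, 11                     # hypothetical sizes
logits = torch.randn(batch, seq, vocab)          # scores for positions 0..seq-1
tokens = torch.randint(vocab, (batch, seq + 1))  # token ids for positions 0..seq

log_probs = F.log_softmax(logits, dim=2)
# The token at position i + 1 was sampled from the logits at position i,
# hence the shift by 1 when building the gather indices.
indices = torch.unsqueeze(tokens[:, 1:seq + 1], 2)
token_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
print(token_log_probs.shape)  # torch.Size([2, 5])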
+33 −0
@@ -23,6 +23,39 @@ from megatron import get_tokenizer
from .communication import broadcast_int_list, broadcast_tensor


def detokenize_generations(tokens_gpu_tensor,
                           lengths_gpu_tensor,
                           return_segments):
    """Detokenize the generated tokens."""

    tokenizer = get_tokenizer()

    prompts_plus_generations = []
    if return_segments:
        prompts_plus_generations_segments = []

    tokens = tokens_gpu_tensor.cpu().numpy().tolist()
    lengths = lengths_gpu_tensor.cpu().numpy().tolist()
    for sequence_tokens, length in zip(tokens, lengths):
        sequence_tokens = sequence_tokens[:length]
        prompts_plus_generations.append(
            tokenizer.detokenize(sequence_tokens))
        if return_segments:
            words = []
            for token in sequence_tokens:
                word = tokenizer.tokenizer.decoder[token]
                word = bytearray(
                    [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
                        'utf-8', errors='replace')
                words.append(word)
            prompts_plus_generations_segments.append(words)

    if return_segments:
        return tokens, prompts_plus_generations, \
            prompts_plus_generations_segments

    return tokens, prompts_plus_generations
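
The per-token segment decoding above relies on GPT-2-style byte-level BPE: each token id maps to a string over a byte alphabet, and `byte_decoder` maps every character of that string back to a raw byte. A toy illustration with made-up tables (the real `decoder` and `byte_decoder` live on the GPT-2 tokenizer):

# Token id -> BPE string; 'Ġ' is GPT-2's byte-alphabet character for a space.
decoder = {0: 'Hello', 1: 'Ġworld'}
byte_decoder = {c: ord(c) for c in 'Helowrld'}
byte_decoder['Ġ'] = ord(' ')

words = []
for token in [0, 1]:
    word = decoder[token]
    # Map each byte-alphabet character back to a byte, then decode as UTF-8.
    word = bytearray([byte_decoder[c] for c in word]).decode(
        'utf-8', errors='replace')
    words.append(word)
print(words)  # ['Hello', ' world']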


def tokenize_prompts(prompts=None, tokens_to_generate=None, rank=0):
    """Tokenize prompts and make them avaiable on all ranks."""
+11 −2
@@ -153,8 +153,12 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
    if mpu.is_pipeline_last_stage():
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        print('last rank output size {} {} | \n'.format(output_logits.size(0), output_logits.size(1)))
        torch.distributed.broadcast(output_logits, src, group)
        if all_probs:
            print('last rank full size {} {} {} | \n'.format(full_logits.size(0),
                                                             full_logits.size(1),
                                                             full_logits.size(2)))
            src = mpu.get_pipeline_model_parallel_last_rank()
            group = mpu.get_embedding_group()
            torch.distributed.broadcast(full_logits, src, group)
@@ -164,13 +168,18 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
            src = mpu.get_pipeline_model_parallel_last_rank()
            group = mpu.get_embedding_group()
            output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
            print('first rank output size {} {} | \n'.format(output_logits.size(0), output_logits.size(1)))
            torch.distributed.broadcast(output_logits, src, group)
            
            if all_probs:
                args = get_args()
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                full_logits = torch.empty(tokens.size(0), context_length-1, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
                print('first rank full size {} {} {} | \n'.format(full_logits.size(0),
                                                                  full_logits.size(1),
                                                                  full_logits.size(2)))
                
                torch.distributed.broadcast(full_logits, src, group)
    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits 
@@ -204,7 +213,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe

        output_logits = output_logits.cpu().numpy().tolist()
        if all_probs:
            full_logits = full_logits.cpu().numpy() #.tolist()
       
        return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens