Commit 628bf0dd authored by Neel Kant's avatar Neel Kant
Browse files

Use the new allgather implementation

parent 98feae4e
Loading
Loading
Loading
Loading
+10 −63
Original line number Diff line number Diff line
@@ -57,7 +57,6 @@ def model_provider():
    return general_model_provider(False, False)



def get_group_world_size_rank():

    group = mpu.get_data_parallel_group()
@@ -67,23 +66,10 @@ def get_group_world_size_rank():
    return group, rank, world_size


def get_rank_chunk_along_first_dim(tensor):

    group, rank, world_size = get_group_world_size_rank()

    assert tensor.shape[0] % world_size == 0
    dim_size = tensor.shape[0] // world_size
    output_list = torch.split(tensor, dim_size, dim=0)
    
    output = output_list[rank].contiguous()
    return output


class AllgatherFromDataParallelRegion(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_):

        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

@@ -98,32 +84,17 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):

    @staticmethod
    def backward(ctx, grad_output):

        return get_rank_chunk_along_first_dim(grad_output)


class AllReduceFromDataParallelRegion(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_):

        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

        tensor_list = [torch.zero_like(input_) for _ in range(world_size)]
        tensor_list[rank] = input_
        output = torch.cat(tensor_list, dim=0).contiguous() 
        torch.distributed.all_reduce(output, group=group)
        assert grad_output.shape[0] % world_size == 0
        dim_size = grad_output.shape[0] // world_size
        output_list = torch.split(grad_output, dim_size, dim=0)

        # get chunk from this rank
        output = output_list[rank].contiguous()
        return output


    @staticmethod
    def backward(ctx, grad_output):

        return get_rank_chunk_along_first_dim(grad_output)


def get_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_pad_mask',
@@ -159,39 +130,15 @@ def forward_step(data_iterator, model):
    block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
    timers('batch generator').stop()


    # Forward model.
    query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
    local_batch_size = query_logits.shape[0]
    global_batch_size = dist.get_world_size() * local_batch_size  # recall we assert that model_parallel_size == 1

    IMPLEMENTATION = 'original'

    if IMPLEMENTATION == 'original':
        data_parallel_size = dist.get_world_size() / args.model_parallel_size
        batch_size = query_logits.shape[0]
        global_batch_size = int(batch_size * data_parallel_size)
        
        all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
        all_query_logits = torch.cuda.FloatTensor(*all_logits_shape).type(query_logits.dtype).fill_(0.0)
        all_block_logits = all_query_logits.clone()

        # record this processes' data
        all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
        all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits

        # merge data from all processes
        dist.all_reduce(all_query_logits)
        dist.all_reduce(all_block_logits)

    elif IMPLEMENTATION == 'allgather':
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)

    elif IMPLEMENTATION == 'allreduce':
        all_query_logits = AllReduceFromDataParallelRegion.apply(query_logits)
        all_block_logits = AllReduceFromDataParallelRegion.apply(block_logits)

    else:
        raise Exception('should not be here.')

    # scores are inner products between query and block embeddings
    retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
    softmaxed = F.softmax(retrieval_scores, dim=1)