pretrain_ict.py  +10 −63

@@ -57,7 +57,6 @@ def model_provider():
     return general_model_provider(False, False)
 
 
-
 def get_group_world_size_rank():
     group = mpu.get_data_parallel_group()

@@ -67,23 +66,10 @@ def get_group_world_size_rank():
     return group, rank, world_size
 
 
-def get_rank_chunk_along_first_dim(tensor):
-    group, rank, world_size = get_group_world_size_rank()
-
-    assert tensor.shape[0] % world_size == 0
-    dim_size = tensor.shape[0] // world_size
-    output_list = torch.split(tensor, dim_size, dim=0)
-    output = output_list[rank].contiguous()
-    return output
-
-
 class AllgatherFromDataParallelRegion(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input_):
         assert input_.dim() == 2
         group, rank, world_size = get_group_world_size_rank()

@@ -98,32 +84,17 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad_output):
-        return get_rank_chunk_along_first_dim(grad_output)
-
-
-class AllReduceFromDataParallelRegion(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input_):
-        assert input_.dim() == 2
         group, rank, world_size = get_group_world_size_rank()
-
-        tensor_list = [torch.zero_like(input_) for _ in range(world_size)]
-        tensor_list[rank] = input_
-        output = torch.cat(tensor_list, dim=0).contiguous()
-        torch.distributed.all_reduce(output, group=group)
+
+        assert grad_output.shape[0] % world_size == 0
+        dim_size = grad_output.shape[0] // world_size
+        output_list = torch.split(grad_output, dim_size, dim=0)
+
+        # get chunk from this rank
+        output = output_list[rank].contiguous()
         return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return get_rank_chunk_along_first_dim(grad_output)
 
 
 def get_batch(data_iterator):
     # Items and their type.
     keys = ['query_tokens', 'query_pad_mask',

@@ -159,39 +130,15 @@ def forward_step(data_iterator, model):
     block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
     timers('batch generator').stop()
 
     # Forward model.
     query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
 
+    local_batch_size = query_logits.shape[0]
+    global_batch_size = dist.get_world_size() * local_batch_size  # recall we assert that model_parallel_size == 1
+
-    IMPLEMENTATION = 'original'
-    if IMPLEMENTATION == 'original':
-        data_parallel_size = dist.get_world_size() / args.model_parallel_size
-        batch_size = query_logits.shape[0]
-        global_batch_size = int(batch_size * data_parallel_size)
-        all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
-        all_query_logits = torch.cuda.FloatTensor(*all_logits_shape).type(query_logits.dtype).fill_(0.0)
-        all_block_logits = all_query_logits.clone()
-
-        # record this processes' data
-        all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
-        all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
-
-        # merge data from all processes
-        dist.all_reduce(all_query_logits)
-        dist.all_reduce(all_block_logits)
-
-    elif IMPLEMENTATION == 'allgather':
-        all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-        all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
-
-    elif IMPLEMENTATION == 'allreduce':
-        all_query_logits = AllReduceFromDataParallelRegion.apply(query_logits)
-        all_block_logits = AllReduceFromDataParallelRegion.apply(block_logits)
-
-    else:
-        raise Exception('should not be here.')
+    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
 
     # scores are inner products between query and block embeddings
     retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
     softmaxed = F.softmax(retrieval_scores, dim=1)
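A note on why the surviving AllgatherFromDataParallelRegion exists at all: torch.distributed.all_gather is opaque to autograd, so the backward pass has to be written by hand. Since the forward gathers and concatenates along dim 0, the gradient with respect to this rank's input is simply this rank's chunk of grad_output. (Incidentally, the deleted AllReduce variant contained a typo that would have crashed at runtime: torch.zero_like is not a PyTorch function; it would need torch.zeros_like.) Below is a minimal, runnable sketch of the same gather-with-gradient idea; the AllgatherWithGrad name and the single-process gloo setup are illustrative, not Megatron's:

import os
import torch
import torch.distributed as dist


class AllgatherWithGrad(torch.autograd.Function):
    """Sketch of an all-gather whose backward returns this rank's grad chunk."""

    @staticmethod
    def forward(ctx, input_):
        world_size = dist.get_world_size()
        # Gather every rank's tensor, then concatenate along the batch dim.
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(tensor_list, input_)
        return torch.cat(tensor_list, dim=0).contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        # Gradient of "concat along dim 0" w.r.t. this rank's slice.
        chunk = grad_output.shape[0] // world_size
        return grad_output[rank * chunk:(rank + 1) * chunk].contiguous()


if __name__ == '__main__':
    # Single-process group so the sketch runs without a launcher.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    x = torch.randn(4, 8, requires_grad=True)
    y = AllgatherWithGrad.apply(x)
    y.sum().backward()
    # With world_size == 1 the gathered output equals x, so d(sum)/dx is all ones.
    assert torch.equal(x.grad, torch.ones_like(x))
    dist.destroy_process_group()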
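Why gather at all: the retrieval scores at the end of the last hunk implement ICT's in-batch softmax, where each query is scored against every block in the global batch, so pooling logits across the data-parallel group enlarges the set of in-batch negatives at no extra cost. A rough sketch of the objective those scores typically feed into, assuming the standard diagonal-positive cross-entropy (the actual loss code sits below the visible hunks):

import torch
import torch.nn.functional as F

# Hypothetical sizes; in the real script these come from the gathered logits.
global_batch_size, hidden = 16, 128
all_query_logits = torch.randn(global_batch_size, hidden)
all_block_logits = torch.randn(global_batch_size, hidden)

# scores[i, j] = <query_i, block_j>; the matching block shares the row index.
retrieval_scores = all_query_logits.float().matmul(all_block_logits.t().float())
targets = torch.arange(global_batch_size)  # positive for query i is block i
loss = F.cross_entropy(retrieval_scores, targets)
print(loss.item())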