tasks/orqa/supervised/finetune.py  +56 −33

@@ -33,6 +33,44 @@
 from tasks.orqa.supervised.eval_utils import accuracy_func_provider
 from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
 from tasks.orqa.evaluate_utils import ORQAEvaluator
+import torch.nn.functional as F
 
+# input_ has a first dimension that may differ across ranks (2D in practice).
+def check_and_append_tensor_for_gather(group, rank, world_size, input_):
+    # Gather the size of the first dimension of the tensor from all ranks.
+    current_length = input_.size()[0]
+    first_dim = torch.tensor([[current_length]],
+                             device=torch.cuda.current_device())
+    input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
+    input_list[rank].copy_(first_dim)
+    torch.distributed.all_gather(input_list, first_dim, group=group)
+    all_input_list = torch.cat(input_list, dim=0).contiguous()
+    max_length = torch.max(all_input_list)
+
+    # If some rank holds a longer tensor, right-pad this rank's first
+    # dimension up to the global maximum so that a subsequent all_gather
+    # sees identically shaped tensors on every rank.
+    if max_length > current_length:
+        padding = tuple([0] * (input_.dim() * 2 - 1)) + \
+            tuple([max_length - current_length])
+        input_ = F.pad(input=input_, pad=padding)
+
+    return input_
+
 
 def orqa(Dataset):
 
     def cross_entropy_forward_step(batch, model):

@@ -56,7 +94,6 @@ def orqa(Dataset):
         timers('batch generator').stop()
 
         local_batch_size = query_tokens.shape[0]
-        #print("rank {} query_tokens {} context_tokens {} batch {} neg_context_tokens {}".format(rank, query_tokens.size(), context_tokens.size(), local_batch_size, neg_context_tokens.size()), flush=True)
 
         # Text representation of query and context
         query_list, context_list = [], []

@@ -64,44 +101,30 @@ def orqa(Dataset):
             query_list.append(tokenizer.decode(query_tokens[i].tolist()))
             context_list.append(tokenizer.decode(context_tokens[i].tolist()))
 
-        if neg_context_tokens.size()[0] > 200:
-            current_length = neg_context_tokens.size()[0]
-            first_dim = torch.tensor([[neg_context_tokens.size()[0]]],
-                                     device=torch.cuda.current_device())
-            neg_context_list = [torch.empty_like(first_dim) for _ in range(world_size)]
-            neg_context_list[rank].copy_(first_dim)
-            torch.distributed.all_gather(neg_context_list, first_dim, group=group)
-            all_neg_context_list = torch.cat(neg_context_list, dim=0).contiguous()
-            max_length = torch.max(all_neg_context_list)
-            torch.set_printoptions(profile="full")
-            if max_length > current_length:
-                print("rank {} before pad neg_context_tokens {}".format(
-                    rank, neg_context_tokens[current_length-1]), flush=True)
-                neg_context_tokens = torch.nn.functional.pad(
-                    input=neg_context_tokens,
-                    pad=(0, 0, 0, max_length - neg_context_tokens.size()[0]))
-            input_ = torch.empty_like(neg_context_tokens).copy_(
-                neg_context_tokens).detach_()
-            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-            tensor_list[rank].copy_(input_)
-            torch.distributed.all_gather(tensor_list, input_, group=group)
-            if max_length > current_length:
-                print("rank {} after pad neg_context_tokens current_length-1 {}".format(
-                    rank, neg_context_tokens[current_length-1]), flush=True)
-                print("rank {} after pad neg_context_tokens current_length {}".format(
-                    rank, neg_context_tokens[current_length]), flush=True)
-                print("rank {} after pad neg_context_tokens max_length-1 {}".format(
-                    rank, neg_context_tokens[max_length-1]), flush=True)
-            if rank == 0:
-                print("rank {} other pad neg_context_tokens current_length-1 {}".format(
-                    rank, tensor_list[5][451]), flush=True)
-                print("rank {} other pad neg_context_tokens current_length {}".format(
-                    rank, tensor_list[5][452]), flush=True)
-                print("rank {} other pad neg_context_tokens max_length-1 {}".format(
-                    rank, tensor_list[5][max_length-1]), flush=True)
-            torch.set_printoptions(profile="default")
-            exit()
+        if neg_context_tokens is not None:
+            neg_context_tokens = check_and_append_tensor_for_gather(
+                group, rank, world_size, neg_context_tokens)
+            neg_context_mask = check_and_append_tensor_for_gather(
+                group, rank, world_size, neg_context_mask)
+            neg_context_types = check_and_append_tensor_for_gather(
+                group, rank, world_size, neg_context_types)
 
         if neg_context_tokens is not None:
             context_tokens = torch.cat([context_tokens, neg_context_tokens])
             context_mask = torch.cat([context_mask, neg_context_mask])
             context_types = torch.cat([context_types, neg_context_types])
 
         # Forward model.
         output_tensor = model(query_tokens, query_mask,
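Two details of the new helper are worth spelling out. First, `F.pad` presumes `torch.nn.functional` is imported as `F`; the import line in the hunk above is added on that assumption. Second, the `padding` tuple is easy to misread: `F.pad` consumes pad widths starting from the *last* dimension, in (left, right) pairs, so padding only the first dimension on the right takes `2 * input_.dim() - 1` zeros followed by the extra length. Here is a minimal sketch of that rule, separate from the PR (the `pad_first_dim` name is ours, and it passes a plain int where the patch passes the 0-dim tensor `max_length - current_length`):

```python
import torch
import torch.nn.functional as F

def pad_first_dim(x: torch.Tensor, target_length: int) -> torch.Tensor:
    """Right-pad only dim 0 of an N-D tensor with zeros up to target_length."""
    extra = target_length - x.size(0)
    if extra <= 0:
        return x
    # F.pad's pad tuple starts at the LAST dimension, as (left, right) pairs.
    # 2*ndim - 1 zeros leave every other side untouched; the final entry is
    # the right-hand pad of dim 0.
    padding = (0,) * (x.dim() * 2 - 1) + (extra,)
    return F.pad(x, padding)

x = torch.ones(3, 5, dtype=torch.long)   # e.g. 3 contexts of 5 token IDs
y = pad_first_dim(x, 7)
print(y.shape)                           # torch.Size([7, 5]); rows 3..6 are zeros
```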
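Taken as a whole, the pattern this change introduces is: all-gather each rank's first-dimension size, pad every rank's tensor up to the global maximum, and only then all-gather the payload, because `torch.distributed.all_gather` requires identically shaped tensors on every rank. Below is a self-contained CPU sketch of that pattern under stated assumptions: it is not the PR's code, and the `gloo` backend, hard-coded address/port, and the `gather_uneven`/`worker` names are illustrative only.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F

def gather_uneven(input_, world_size):
    # 1) Agree on the largest first-dimension length across ranks.
    length = torch.tensor([input_.size(0)])
    lengths = [torch.empty_like(length) for _ in range(world_size)]
    dist.all_gather(lengths, length)
    max_length = int(torch.cat(lengths).max())
    # 2) Zero-pad this rank's tensor up to that length.
    if max_length > input_.size(0):
        padding = (0,) * (input_.dim() * 2 - 1) + (max_length - input_.size(0),)
        input_ = F.pad(input_, padding)
    # 3) Shapes now match on every rank, so all_gather is legal.
    gathered = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(gathered, input_)
    # Return the true lengths so callers can strip the padding again.
    return gathered, [int(l) for l in lengths]

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    local = torch.full((rank + 1, 4), rank)        # rank r holds r+1 rows
    gathered, lengths = gather_uneven(local, world_size)
    if rank == 0:
        print([t.shape for t in gathered], lengths)  # all padded to max rows
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

Note that after the gather the zero rows are indistinguishable from real ones unless the true lengths are kept, which is why this sketch returns `lengths` alongside the gathered tensors.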