megatron/data/realm_dataset.py  +0 −1

```diff
@@ -65,7 +65,6 @@ class ICTDataset(Dataset):
         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
-        print(self.tokenizer.decode_token_ids(block_tokens), '\n')

         block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)

         sample = {
```
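For context on what `concat_and_pad_tokens` is producing for the sample above, here is a minimal sketch of such a padding helper. This is a hypothetical illustration, not the repo's implementation: the special-token IDs, `max_seq_length`, and the exact `[CLS]`/`[SEP]` layout are assumptions.

```python
import numpy as np

def concat_and_pad_tokens_sketch(tokens, title=None,
                                 cls_id=101, sep_id=102, pad_id=0,
                                 max_seq_length=288):
    """Hypothetical sketch: build padded token IDs plus a pad mask.

    Queries are laid out as [CLS] tokens [SEP]; blocks (which carry a
    title) as [CLS] title [SEP] tokens [SEP]. All IDs/lengths here are
    assumed values, not the repo's actual configuration.
    """
    if title is None:
        ids = [cls_id] + list(tokens) + [sep_id]
    else:
        ids = [cls_id] + list(title) + [sep_id] + list(tokens) + [sep_id]

    ids = ids[:max_seq_length]
    # 1 for real tokens, 0 for padding.
    pad_mask = [1] * len(ids) + [0] * (max_seq_length - len(ids))
    ids = ids + [pad_id] * (max_seq_length - len(ids))
    return np.array(ids, dtype=np.int64), np.array(pad_mask, dtype=np.int64)
```

With the debug `print` removed in this diff, the decoded block text is no longer dumped on every `__getitem__` call, which keeps dataloader output clean during training.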
pretrain_bert_ict.py  +7 −6

```diff
@@ -33,8 +33,11 @@ num_batches = 0


 def general_model_provider(only_query_model=False, only_block_model=False):
     """Build the model."""
     args = get_args()
-    if args.ict_head_size is None:
-        raise ValueError("Need to specify --ict-head-size to provide an ICTBertModel")
+    assert args.ict_head_size is not None, \
+        "Need to specify --ict-head-size to provide an ICTBertModel"
+    assert args.model_parallel_size == 1, \
+        "Model parallel size > 1 not supported for ICT"
+
     print_rank_0('building ICTBertModel...')
@@ -89,7 +92,6 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    # retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
     query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)

     data_parallel_size = dist.get_world_size() / args.model_parallel_size
@@ -100,11 +102,11 @@ def forward_step(data_iterator, model):
     all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
     all_block_logits = all_query_logits.clone().cuda()

-    # record this processes' data and then merge with other processes below
+    # record this processes' data
     all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
     all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits

-    # currently this assumes model parallel size == 1.
+    # merge data from all processes
     dist.all_reduce(all_query_logits)
     dist.all_reduce(all_block_logits)

@@ -153,6 +155,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):


 if __name__ == "__main__":
-
     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
```
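The zero-fill-then-all-reduce pattern in `forward_step` is worth spelling out: each data-parallel rank writes its own batch into a zero-initialized global buffer, and a sum `all_reduce` then reproduces the full set of query and block embeddings on every rank. Below is a minimal standalone sketch of that pattern; the function name, shapes, and the in-batch softmax cross-entropy at the end are illustrative assumptions (the commented-out `retrieval_scores` line in the diff hints at, but does not define, the actual loss).

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def merge_and_score_sketch(query_logits, block_logits, rank, world_size):
    """Hypothetical sketch, assuming dist.init_process_group() has run and
    every rank contributes an equal [batch_size, hidden] slice."""
    batch_size, hidden = query_logits.shape
    all_shape = (world_size * batch_size, hidden)

    all_query = torch.zeros(all_shape, dtype=query_logits.dtype,
                            device=query_logits.device)
    all_block = all_query.clone()

    # Record this rank's slice; every other rank's rows stay zero for now.
    all_query[rank * batch_size:(rank + 1) * batch_size] = query_logits
    all_block[rank * batch_size:(rank + 1) * batch_size] = block_logits

    # A sum all-reduce fills in the remaining rows, because the non-local
    # rows are exactly zero on this rank and non-zero on exactly one other.
    dist.all_reduce(all_query)
    dist.all_reduce(all_block)

    # In-batch retrieval: score every query against every block; the true
    # block for a query sits at the same row index.
    retrieval_scores = all_query.float().matmul(all_block.float().t())
    labels = torch.arange(all_shape[0], device=retrieval_scores.device)
    return F.cross_entropy(retrieval_scores, labels)
```

This also explains the new `model_parallel_size == 1` assert in the diff: the slice index uses the global `args.rank`, which only lines up with a data-parallel slot when there is no model parallelism, as the removed comment ("currently this assumes model parallel size == 1") already noted.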