Commit 2a3b445d authored by Neel Kant

Cosmetic changes

parent ac967fa0
+0 −1
@@ -65,7 +65,6 @@ class ICTDataset(Dataset):
 
         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
-        print(self.tokenizer.decode_token_ids(block_tokens), '\n')
         block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
 
         sample = {
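
The deletion above drops a stray debug print of the decoded block. For context, the two `concat_and_pad_tokens` call sites produce fixed-length token/mask pairs for the query and the title-prefixed block. A minimal sketch of what such a helper plausibly looks like (a hypothetical reconstruction, not part of this diff; it assumes the dataset carries `cls_id`, `sep_id`, `pad_id`, and `max_seq_length`):

```python
import numpy as np

def concat_and_pad_tokens(self, tokens, title=None):
    """Wrap tokens (and optional title) in [CLS]/[SEP], then pad."""
    if title is None:
        tokens = [self.cls_id] + list(tokens) + [self.sep_id]
    else:
        # block text gets its title prepended, mirroring the call above
        tokens = [self.cls_id] + list(title) + [self.sep_id] \
                 + list(tokens) + [self.sep_id]
    assert len(tokens) <= self.max_seq_length

    # pad to a fixed length; mask is 1 over real tokens, 0 over padding
    num_pad = self.max_seq_length - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad
    tokens = tokens + [self.pad_id] * num_pad
    return np.array(tokens, dtype=np.int64), np.array(pad_mask, dtype=np.int64)
```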
+7 −6
@@ -33,8 +33,11 @@ num_batches = 0
 def general_model_provider(only_query_model=False, only_block_model=False):
     """Build the model."""
     args = get_args()
-    if args.ict_head_size is None:
-        raise ValueError("Need to specify --ict-head-size to provide an ICTBertModel")
+    assert args.ict_head_size is not None, \
+        "Need to specify --ict-head-size to provide an ICTBertModel"
+
+    assert args.model_parallel_size == 1, \
+        "Model parallel size > 1 not supported for ICT"
 
     print_rank_0('building ICTBertModel...')
 
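One behavioral note on the raise-to-assert swap above: unlike an explicit `raise`, `assert` statements are stripped when Python runs with `-O`, so these guards vanish under optimized execution. A self-contained illustration of the pattern (a standalone argparse stand-in for Megatron's `get_args()`, which this diff does not show):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ict-head-size', type=int, default=None)
args = parser.parse_args([])  # no flag passed, so ict_head_size is None

# guard style used after this commit: concise, but a no-op under `python -O`
assert args.ict_head_size is not None, \
    "Need to specify --ict-head-size to provide an ICTBertModel"
```

Run without the flag, this trips the AssertionError; under `-O`, execution would continue with `ict_head_size` unset.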
@@ -89,7 +92,6 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
 
     # Forward model.
-    # retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
     query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
 
     data_parallel_size = dist.get_world_size() / args.model_parallel_size
@@ -100,11 +102,11 @@ def forward_step(data_iterator, model):
     all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
     all_block_logits = all_query_logits.clone().cuda()
 
-    # record this processes' data and then merge with other processes below
+    # record this process's data
     all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
     all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
 
-    # currently this assumes model parallel size == 1.
+    # merge data from all processes
     dist.all_reduce(all_query_logits)
     dist.all_reduce(all_block_logits)
 
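The zero-fill-then-all_reduce pattern in the hunk above is effectively an all-gather: each data-parallel rank writes its logits into its own slice of a zeroed buffer, and the elementwise sum across ranks reconstructs the full batch, since every other rank contributed zeros in that slice. A distilled sketch of the idea (hypothetical helper name; assumes equal per-rank batch sizes, an initialized process group, and model parallel size 1, as the code asserts):

```python
import torch
import torch.distributed as dist

def gather_by_zero_sum(local_logits, rank, world_size):
    batch_size, hidden = local_logits.shape
    # global buffer; this rank owns rows [rank*B, (rank+1)*B)
    full = torch.zeros(world_size * batch_size, hidden,
                       dtype=local_logits.dtype,
                       device=local_logits.device)
    full[rank * batch_size:(rank + 1) * batch_size] = local_logits

    # summing across ranks fills in everyone else's slice
    dist.all_reduce(full)
    return full
```

`torch.distributed.all_gather` would reach the same result without shipping the zero padding; the sum formulation just gets there with one collective and no tensor-list bookkeeping.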
@@ -153,6 +155,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 
 
 if __name__ == "__main__":
-
     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
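
Finally, the commented-out `retrieval_scores` line removed in the second hunk hints at how the merged logits are consumed: an in-batch-negatives softmax in which query i's true block is block i. A hedged sketch of that objective (hypothetical; the actual loss code sits outside this diff):

```python
import torch
import torch.nn.functional as F

def in_batch_retrieval_loss(all_query_logits, all_block_logits):
    # scores[i, j] = similarity of query i to block j over the merged batch
    retrieval_scores = torch.matmul(all_query_logits,
                                    all_block_logits.t()).float()
    # the positive block for query i sits on the diagonal
    targets = torch.arange(retrieval_scores.size(0),
                           device=retrieval_scores.device)
    return F.cross_entropy(retrieval_scores, targets)
```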