Commit 17d897e0 authored by Mostofa Patwary

WIP: main_retriver_merge

parent b69bc7ef
+17 −6
@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_realm_args(parser)
    parser = _add_biencoder_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
@@ -310,6 +310,8 @@ def _add_training_args(parser):
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--override-checkpoint-version', type=float, default=None,
                       help='If set, override the checkpoint version '
                       'reported by the loaded checkpoint')
    group.add_argument('--distribute-checkpointed-activations',
                       action='store_true',
                       help='If set, distribute checkpointed activations '
@@ -567,12 +569,19 @@ def _add_autoresume_args(parser):
    return parser


def _add_realm_args(parser):
    group = parser.add_argument_group(title='realm')
def _add_biencoder_args(parser):
    group = parser.add_argument_group(title='biencoder')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
    group.add_argument('--projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper default: 128)')
    group.add_argument('--shared-query-context-model', action='store_true',
                       help='Whether to share the parameters of the query and context models or not')
    group.add_argument('--pool-type', type=str, default='cls-token',
                       choices=['avg', 'cls-token', 'max'],
                       help='Type of pooling to use: avg | cls-token | max (default: cls-token)')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
@@ -589,14 +598,16 @@ def _add_realm_args(parser):
                       help='Whether to use one-sentence documents in ICT')

    # training
    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
    group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[],
                       help="Which top-k accuracies to report (e.g. '1 5 20')")
    group.add_argument('--retriever-score-scaling', action='store_true',
                       help="Whether to scale retriever scores by inverse square root of hidden size")

    # faiss index
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether to create the FaissMIPSIndex on GPU')
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')
    #group.add_argument('--block-data-path', type=str, default=None,
    #                   help='Where to save/load BlockData to/from')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
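
The renamed group can be sanity-checked without importing megatron by
rebuilding just the new flags with bare argparse. A minimal sketch: the
flag names, types, and defaults are copied from the diff above; the rest
is illustrative.

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='biencoder')
group.add_argument('--projection-dim', type=int, default=0)
group.add_argument('--shared-query-context-model', action='store_true')
group.add_argument('--pool-type', type=str, default='cls-token',
                   choices=['avg', 'cls-token', 'max'])
group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[])
group.add_argument('--retriever-score-scaling', action='store_true')

# the type=int change to --report-topk-accuracies means the values now
# arrive as real ints rather than strings
args = parser.parse_args('--projection-dim 128 --pool-type avg '
                         '--report-topk-accuracies 1 5 20'.split())
print(args.projection_dim, args.pool_type, args.report_topk_accuracies)
# 128 avg [1, 5, 20]
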
+19 −3
@@ -9,6 +9,16 @@ from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping

def make_attention_mask(source_block, target_block):
    """
    Returns a 2-D attention mask that is 1 only where both the source
    and target token ids are >= 1 (i.e. non-padding)
    :param source_block: 1-D array of token ids
    :param target_block: 1-D array of token ids
    """
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    mask = mask.astype(np.int64)
    # (source_length, target_length)
    return mask
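
A quick numeric illustration of make_attention_mask with made-up token
arrays; id 0 plays the role of padding, per the >= 1 test above.

import numpy as np

source = np.array([5, 9, 0])   # last position is padding
target = np.array([7, 0])
mask = (target[None, :] >= 1) * (source[:, None] >= 1)
print(mask.astype(np.int64))
# [[1 0]
#  [1 0]
#  [0 0]]   shape (source_length, target_length) == (3, 2)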

def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)

        query_mask = make_attention_mask(query_tokens, query_tokens)
        context_mask = make_attention_mask(context_tokens, context_tokens)

        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_mask': query_mask,
            'query_pad_mask': query_pad_mask,
            'block_tokens': block_tokens,
            'block_pad_mask': block_pad_mask,
            'context_tokens': context_tokens,
            'context_mask': context_mask,
            'context_pad_mask': context_pad_mask,
            'block_data': block_data,
        }

+7 −1
@@ -59,6 +59,12 @@ class AnnealingLR(object):
        """Learning rate decay functions from:
              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

        #print_rank_0(self.warmup_steps)
        #print_rank_0(self.num_steps)
        #print_rank_0(self.warmup_steps)
        #print_rank_0(self.max_lr)
        #print_rank_0(self.max_lr * float(self.num_steps) / float(self.warmup_steps))

        # Use linear warmup for the initial part.
        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
            return self.max_lr * float(self.num_steps) / \
@@ -97,7 +103,7 @@ class AnnealingLR(object):
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

        #print_rank_0(new_lr)

    def state_dict(self):
        state_dict = {
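
The warmup branch those commented-out prints were inspecting is plain
linear scaling; a self-contained sketch with made-up constants.

max_lr, warmup_steps = 1.5e-4, 200

def warmup_lr(num_steps):
    # linear ramp that reaches max_lr exactly at num_steps == warmup_steps
    return max_lr * float(num_steps) / float(warmup_steps)

print(warmup_lr(50))    # 3.75e-05, a quarter of the way up
print(warmup_lr(200))   # 0.00015, i.e. max_lr itself
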
+13 −0
@@ -374,6 +374,19 @@ class TransformerLanguageModelBase(MegatronModule):
        # Transformer.
        if self._transformer_key in state_dict:
            state_dict_ = state_dict[self._transformer_key]
        # for compatibility with the t5 architecture
        # this is temporary until t5_main is merged
        elif 'encoder' in state_dict:
            state_dict_ = state_dict['encoder']
            # for forward compatibility with the t5 architecture
            state_dict_attention = {}
            for key in state_dict_.keys():
                if '.self_attention.' in key:
                    state_dict_attention[key.replace(".self_attention.",
                        ".attention.")] = state_dict_[key]
                else:
                    state_dict_attention[key] = state_dict_[key]
            state_dict_ = state_dict_attention
        else:
            # for backward compatibility.
            state_dict_ = {}
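
The elif branch above only rewrites attention sub-module names when a
t5-style 'encoder' state dict is loaded; a toy run with invented keys
shows the mapping (the dict comprehension is equivalent to the loop).

t5_style = {
    'layers.0.self_attention.query_key_value.weight': 'w0',
    'layers.0.mlp.dense_h_to_4h.weight': 'w1',
}
renamed = {key.replace('.self_attention.', '.attention.'): value
           for key, value in t5_style.items()}
print(sorted(renamed))
# ['layers.0.attention.query_key_value.weight',
#  'layers.0.mlp.dense_h_to_4h.weight']
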
+4 −1
@@ -214,6 +214,9 @@ class ParallelSelfAttention(MegatronModule):
        mixed_x_layer, _ = self.query_key_value(hidden_states)

        checkpoint_version = get_checkpoint_version()
        if get_args().override_checkpoint_version is not None:
            checkpoint_version = get_args().override_checkpoint_version

        if checkpoint_version is not None:
           if checkpoint_version == 0:
               # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
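
Stripped of Megatron plumbing, the resolution order introduced here is:
the CLI flag, when given, beats the version the checkpoint reports. A
sketch, with stored and override standing in for get_checkpoint_version()
and args.override_checkpoint_version.

def resolve_version(stored, override):
    return override if override is not None else stored

print(resolve_version(None, 0.0))   # 0.0: an unversioned checkpoint forced
                                    # into the version-0 QKV reordering
print(resolve_version(2.0, None))   # 2.0: stored value wins with no override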