megatron/arguments.py  +17 −6

@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_validation_args(parser)
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
-    parser = _add_realm_args(parser)
+    parser = _add_biencoder_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:

@@ -310,6 +310,8 @@ def _add_training_args(parser):
     group.add_argument('--checkpoint-activations', action='store_true',
                        help='Checkpoint activation to allow for training '
                        'with larger models, sequences, and batch sizes.')
+    group.add_argument('--override-checkpoint-version', type=float, default=None,
+                       help='Override checkpoint version')
     group.add_argument('--distribute-checkpointed-activations',
                        action='store_true',
                        help='If set, distribute checkpointed activations '

@@ -567,12 +569,19 @@ def _add_autoresume_args(parser):
     return parser


-def _add_realm_args(parser):
-    group = parser.add_argument_group(title='realm')
+def _add_biencoder_args(parser):
+    group = parser.add_argument_group(title='biencoder')

     # network size
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
+    group.add_argument('--projection-dim', type=int, default=0,
+                       help='Size of projection head used in biencoder (paper default: 128)')
+    group.add_argument('--shared-query-context-model', action='store_true',
+                       help='Whether to share the parameters of the query and context models or not')
+    group.add_argument('--pool-type', type=str, default='cls-token',
+                       choices=['avg', 'cls-token', 'max'],
+                       help='different options are: avg | cls-token | max, default=cls-token')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,

@@ -589,14 +598,16 @@ def _add_realm_args(parser):
                        help='Whether to use one sentence documents in ICT')

     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
+    group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[],
                        help="Which top-k accuracies to report (e.g. '1 5 20')")
+    group.add_argument('--retriever-score-scaling', action='store_true',
+                       help="Whether to scale retriever scores by inverse square root of hidden size")

     # faiss index
     group.add_argument('--faiss-use-gpu', action='store_true',
                        help='Whether create the FaissMIPSIndex on GPU')
-    group.add_argument('--block-data-path', type=str, default=None,
-                       help='Where to save/load BlockData to/from')
+    #group.add_argument('--block-data-path', type=str, default=None,
+    #                   help='Where to save/load BlockData to/from')

     # indexer
     group.add_argument('--indexer-batch-size', type=int, default=128,
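Note on the --report-topk-accuracies change: with nargs='+' but no type, argparse returns a list of strings, so adding type=int makes the values arrive as integers. A minimal standalone sketch of the new biencoder flags and that coercion (plain argparse, not the full Megatron parser; the example values are made up):

import argparse

# Mirrors the new biencoder options outside of Megatron's parser.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='biencoder')
group.add_argument('--projection-dim', type=int, default=0)
group.add_argument('--shared-query-context-model', action='store_true')
group.add_argument('--pool-type', type=str, default='cls-token',
                   choices=['avg', 'cls-token', 'max'])
# With type=int the top-k values are parsed as integers instead of strings.
group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[])

args = parser.parse_args(
    '--projection-dim 128 --pool-type avg --report-topk-accuracies 1 5 20'.split())
assert args.report_topk_accuracies == [1, 5, 20]  # ints, not ['1', '5', '20']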
megatron/data/ict_dataset.py  +19 −3

@@ -9,6 +9,16 @@
 from megatron import get_args
 from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping

+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask

 def get_ict_dataset(use_titles=True, query_in_block_prob=1):
     """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())

@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
         block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
-        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
+
+        query_mask = make_attention_mask(query_tokens, query_tokens)
+        context_mask = make_attention_mask(context_tokens, context_tokens)

         block_data = sample_data.as_array()

         sample = {
             'query_tokens': query_tokens,
+            'query_mask': query_mask,
             'query_pad_mask': query_pad_mask,
-            'block_tokens': block_tokens,
-            'block_pad_mask': block_pad_mask,
+            'context_tokens': context_tokens,
+            'context_mask': context_mask,
+            'context_pad_mask': context_pad_mask,
             'block_data': block_data,
         }
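The new make_attention_mask helper builds a (source_length, target_length) 0/1 mask in which a position pair is attendable only when both token ids are non-pad; the "token >= 1" test assumes the pad id is 0. A small standalone check (hypothetical token ids, numpy only):

import numpy as np

# Same logic as the new helper; assumes pad token id is 0 so "token >= 1" means non-pad.
def make_attention_mask(source_block, target_block):
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    return mask.astype(np.int64)  # (source_length, target_length)

query_tokens = np.array([101, 2023, 102, 0, 0])  # hypothetical ids with 2 trailing pads
mask = make_attention_mask(query_tokens, query_tokens)
print(mask.shape)  # (5, 5)
print(mask[0])     # [1 1 1 0 0] -- attention to the padded positions is masked out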
megatron/learning_rates.py  +7 −1

@@ -59,6 +59,12 @@ class AnnealingLR(object):
         """Learning rate decay functions from:
               https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

+        #print_rank_0(self.warmup_steps)
+        #print_rank_0(self.num_steps)
+        #print_rank_0(self.warmup_steps)
+        #print_rank_0(self.max_lr)
+        #print_rank_0(self.max_lr * float(self.num_steps) / float(self.warmup_steps))
+
         # Use linear warmup for the initial part.
         if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
             return self.max_lr * float(self.num_steps) / \

@@ -97,7 +103,7 @@ class AnnealingLR(object):
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr
-
+        #print_rank_0(new_lr)

     def state_dict(self):
         state_dict = {
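The quantities behind the commented-out debug prints all come from the linear warmup branch visible in the hunk: while num_steps <= warmup_steps the learning rate ramps linearly from 0 to max_lr. A small standalone check of that formula (plain Python, not the AnnealingLR class; decay after warmup is omitted):

# Linear warmup as in the hunk above: lr ramps from 0 to max_lr over warmup_steps.
def warmup_lr(max_lr, num_steps, warmup_steps):
    if warmup_steps > 0 and num_steps <= warmup_steps:
        return max_lr * float(num_steps) / float(warmup_steps)
    return max_lr  # decay handling omitted in this sketch

print(warmup_lr(1.5e-4, 100, 2000))   # 7.5e-06, 5% of the way through warmup
print(warmup_lr(1.5e-4, 2000, 2000))  # 0.00015, warmup complete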
megatron/model/language_model.py  +13 −0

@@ -374,6 +374,19 @@ class TransformerLanguageModelBase(MegatronModule):
         # Transformer.
         if self._transformer_key in state_dict:
             state_dict_ = state_dict[self._transformer_key]
+        # for compatiability with t5 architecture
+        # this is temporary unless t5_main is merged
+        elif 'encoder' in state_dict:
+            state_dict_ = state_dict['encoder']
+            # for forward compatibility for t5 architecture
+            state_dict_attention = {}
+            for key in state_dict_.keys():
+                if '.self_attention.' in key:
+                    state_dict_attention[key.replace(".self_attention.", ".attention.")] = \
+                        state_dict_[key]
+                else:
+                    state_dict_attention[key] = state_dict_[key]
+            state_dict_ = state_dict_attention
         else:
             # for backward compatibility.
             state_dict_ = {}
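The new elif branch lets a checkpoint saved in the T5-style layout load here: it takes the 'encoder' sub-dict and renames '.self_attention.' keys to the '.attention.' names this module expects. A dictionary-only sketch of that renaming (hypothetical key names, no real checkpoint involved):

# Hypothetical T5-style encoder state dict with '.self_attention.' in its keys.
encoder_state = {
    'layers.0.self_attention.query_key_value.weight': 'tensor-0',
    'layers.0.self_attention.dense.weight': 'tensor-1',
    'layers.0.mlp.dense_h_to_4h.weight': 'tensor-2',
}

# Same renaming as the new elif branch: map '.self_attention.' -> '.attention.'.
renamed = {}
for key, value in encoder_state.items():
    if '.self_attention.' in key:
        renamed[key.replace('.self_attention.', '.attention.')] = value
    else:
        renamed[key] = value

print(sorted(renamed))
# ['layers.0.attention.dense.weight',
#  'layers.0.attention.query_key_value.weight',
#  'layers.0.mlp.dense_h_to_4h.weight']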
megatron/model/transformer.py  +4 −1

@@ -214,6 +214,9 @@ class ParallelSelfAttention(MegatronModule):
         mixed_x_layer, _ = self.query_key_value(hidden_states)

         checkpoint_version = get_checkpoint_version()
+        if get_args().override_checkpoint_version is not None:
+            checkpoint_version = get_args().override_checkpoint_version
+
         if checkpoint_version is not None:
             if checkpoint_version == 0:
                 # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
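Together with the new --override-checkpoint-version flag added in arguments.py, this gives the command-line value precedence over whatever get_checkpoint_version() reads from the checkpoint. A minimal sketch of that precedence rule (stand-in arguments rather than the real Megatron globals):

# Stand-in for the override logic above; only the precedence rule is taken from the diff.
def resolve_checkpoint_version(stored_version, override_version):
    """Return the CLI override when given, otherwise the version stored in the checkpoint."""
    if override_version is not None:
        return override_version
    return stored_version

print(resolve_checkpoint_version(stored_version=0, override_version=None))  # 0
print(resolve_checkpoint_version(stored_version=0, override_version=1.0))   # 1.0 (CLI wins)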