Commit a96ad63f authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

fixed alignment

parent 651cc107
Loading
Loading
Loading
Loading
+13 −7
Original line number Diff line number Diff line
@@ -644,14 +644,18 @@ def _add_biencoder_args(parser):

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
                       help='Size of block embeddings to be used in ICT and '
                        'REALM (paper default: 128)')
    group.add_argument('--projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper default: 128)')
                       help='Size of projection head used in biencoder (paper'
                        ' default: 128)')
    group.add_argument('--shared-query-context-model', action='store_true',
                        help='Whether to share the parameters of the query and context models or not')
                        help='Whether to share the parameters of the query '
                        'and context models or not')
    group.add_argument('--pool-type', type=str, default='cls-token',
                       choices=['avg', 'cls-token', 'max'],
                       help='different options are: avg | cls-token | max, default=cls-token')
                       help='different options are: avg | cls-token | max, '
                        'default=cls-token')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
@@ -670,10 +674,12 @@ def _add_biencoder_args(parser):
                       help='Whether to use one sentence documents in ICT')

    # training
    group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[],
                       help="Which top-k accuracies to report (e.g. '1 5 20')")
    group.add_argument('--report-topk-accuracies', nargs='+', type=int, 
                        default=[], help="Which top-k accuracies to report '
                        '(e.g. '1 5 20')")
    group.add_argument('--retriever-score-scaling', action='store_true',
                       help="Whether to scale retriever scores by inverse square root of hidden size")
                       help="Whether to scale retriever scores by inverse '
                        'square root of hidden size")

    # faiss index
    group.add_argument('--faiss-use-gpu', action='store_true',
+5 −3
Original line number Diff line number Diff line
@@ -34,9 +34,11 @@ from megatron.data.biencoder_dataset_utils import get_ict_batch

def pretrain_ict_model_provider():
    args = get_args()
    model = biencoder_model_provider(only_context_model=False,
    model = biencoder_model_provider(
                only_context_model=False,
                only_query_model=False,
                                     shared_query_context_model=args.shared_query_context_model)
                shared_query_context_model=args.shared_query_context_model
                )
    return model

def get_group_world_size_rank():