Commit 368ad0d3 authored by Mohammad's avatar Mohammad
Browse files

made size arguments optional so they can be set from input function

parent 7d75b3b5
Loading
Loading
Loading
Loading
+14 −31
Original line number Diff line number Diff line
@@ -35,8 +35,6 @@ def parse_args(extra_args_provider=None, defaults={}):
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    # TODO: Refactor
    parser = _add_gpt2_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
@@ -54,6 +52,12 @@ def parse_args(extra_args_provider=None, defaults={}):
            'defaults can only be overwritten for args with None values.'
        setattr(args, key, defaults[key])

    # Check required arguments.
    _check_arg_is_not_none(args, 'num_layers')
    _check_arg_is_not_none(args, 'hidden_size')
    _check_arg_is_not_none(args, 'num_attention_heads')
    _check_arg_is_not_none(args, 'max_position_embeddings')

    # Distributed args.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
@@ -93,16 +97,20 @@ def _print_args(args):
        print('---------------- end of arguments ----------------', flush=True)


def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


def _add_network_size_args(parser):
    group = parser.add_argument_group(title='network size')

    group.add_argument('--num-layers', type=int, required=True,
    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--hidden-size', type=int, required=True,
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Tansformer hidden size.')
    group.add_argument('--num-attention-heads', type=int, required=True,
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--max-position-embeddings', type=int, required=True,
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
@@ -342,28 +350,3 @@ def _add_autoresume_args(parser):
                       'termination signal')

    return parser


########################################################################


def _add_gpt2_args(parser):
    group = parser.add_argument_group(title='gpt2')

    group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
                       help='The filename containing all the shards '
                       'sizes for numpy data loader')

    return parser


def add_data_args_(parser):
    """Train/valid/test data arguments."""

    group = parser.add_argument_group('data', 'data configurations')

    group.add_argument('--data-loader', type=str, default=None,
                       choices=['raw', 'lazy', 'tfrecords', 'numpy', 'binary'],
                       help='Which data loader to use. Default varies by model.')

    return parser