Commit a31833ce authored by mshoeybi's avatar mshoeybi Committed by Deepak Narayanan
Browse files

Fix some bugs, add exit-duration capability

parent 51315905
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ from .package_info import (
)

from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
+2 −0
Original line number Diff line number Diff line
@@ -326,6 +326,8 @@ def _add_training_args(parser):
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
+17 −5
Original line number Diff line number Diff line
@@ -418,10 +418,22 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, dataset_type=dataset_type)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
+15 −6
Original line number Diff line number Diff line
@@ -55,13 +55,22 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            seq_length, seed, skip_warmup)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
+7 −2
Original line number Diff line number Diff line
@@ -43,8 +43,13 @@ def get_num_microbatches():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


# NOTE(review): update_num_microbatches is defined a second time further down
# with an additional consistency_check parameter; in Python the later
# definition is the one that takes effect. This two-line version looks like
# the removed side of the diff hunk -- confirm before relying on it.
def update_num_microbatches(consumed_samples):
    """Pass consumed_samples to the global calculator's update()."""
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples)
def get_current_global_batch_size():
    """Return what the global microbatch calculator reports as the
    current global batch size.

    This is a module-level convenience wrapper: it delegates to
    ``get_current_global_batch_size()`` on the module-wide
    ``_GLOBAL_NUM_MICROBATCHES_CALCULATOR`` and returns its result
    unchanged.
    """
    calculator = _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    return calculator.get_current_global_batch_size()


def update_num_microbatches(consumed_samples, consistency_check=True):
    """Forward an update request to the global microbatch calculator.

    Both arguments are handed positionally, in order, to ``update()`` on
    the module-wide ``_GLOBAL_NUM_MICROBATCHES_CALCULATOR``. Nothing is
    returned.

    Args:
        consumed_samples: Passed through as the first argument of the
            calculator's ``update()``.
        consistency_check: Passed through as the second argument of the
            calculator's ``update()``. Defaults to True.
    """
    calculator = _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    calculator.update(consumed_samples, consistency_check)


def get_tokenizer():
Loading