Commit 6e83649f authored by Jared Casper's avatar Jared Casper Committed by Deepak Narayanan
Browse files

Quick fix for pipeline tasks to get learning rate correct

parent 25c07e14
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -129,10 +129,14 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
                                          args.num_workers, not args.keep_last)
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set args.micro_batch_size to
    # the actual batch size the model will see for this dataset
    # Now that we've built the data loaders, set batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    if hasattr(train_dataset, 'sample_multiplier'):
        args.micro_batch_size *= train_dataset.sample_multiplier
        args.global_batch_size *= train_dataset.sample_multiplier

    return train_dataloader, valid_dataloader