Commit 9dec5374 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

added evaluation logic; finalized flag levels

parent 3bd2e973
Loading
Loading
Loading
Loading
+0 −4
Original line number Diff line number Diff line
@@ -601,10 +601,6 @@ def _add_distributed_args(parser):
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU' )
    # group.add_argument('--empty-unused-memory-each-iter', action='store_true',
    #                    help='Call torch.cuda.empty_cache() each iteration '
    #                    '(training and eval), to reduce fragmentation',
    #                    default=False)
    group.add_argument('--empty-unused-memory-each-iter', default=0, type=int,
                       choices=[0, 1, 2],
                       help='Call torch.cuda.empty_cache() each iteration '
+6 −3
Original line number Diff line number Diff line
@@ -363,8 +363,7 @@ def train_step(forward_step_func, data_iterator,
        optimizer, timers, forward_only=False)

    # Empty unused memory
    if args.empty_unused_memory_each_iter >= 2:
        raise Exception("hi.")
    if args.empty_unused_memory_each_iter >= 1:
        torch.cuda.empty_cache()

    # All-reduce if needed.
@@ -414,7 +413,7 @@ def train_step(forward_step_func, data_iterator,
        skipped_iter = 1

    # Empty unused memory
    if args.empty_unused_memory_each_iter >= 1:
    if args.empty_unused_memory_each_iter >= 2:
        torch.cuda.empty_cache()

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
@@ -725,6 +724,10 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Empty unused memory
            if args.empty_unused_memory_each_iter >= 1:
                torch.cuda.empty_cache()

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes.
                for loss_dict in loss_dicts: