Commit 05042081 authored by Lawrence McAfee
Browse files

jan 11 alpha

parent 270d6412
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -329,6 +329,7 @@ class CheckpointFunction(torch.autograd.Function):
        # the chunk corresponding to the current rank.
        if distribute_checkpointed_activations:
            # >>>
            # raise Exception("distrib.")
            # from lutil import data_leak_ctx
            # with data_leak_ctx(args[0]):
            # <<<
+28 −0
Original line number Diff line number Diff line
@@ -115,7 +115,17 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
        unwrap_output_tensor = True

    unwrapped_model.set_input_tensor(input_tensor)
    # >>>
    # if input_tensor[0] is not None:
    #     from lutil import pax, tp
    #     pax({"input_tensor": tp(input_tensor)})
    # <<<
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    # >>>
    # if input_tensor[0] is not None:
    #     from lutil import pax, tp
    #     pax({"input_tensor": tp(input_tensor)})
    # <<<
    if mpu.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
@@ -626,11 +636,29 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
        # >>>
        # if input_tensor[0] is not None:
        #     from lutil import pax
        #     pax({"input_tensor": input_tensor})
        # <<<
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        # >>>
        # if True or input_tensor[0] is not None:
        #     from lutil import pax
        #     pax({"input_tensor": input_tensor})
        # <<<
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            # >>>
            # if input_tensor[0] is not None:
            #     from lutil import pax
            #     pax({"input_tensor": input_tensor})
            # if output_tensor[0] is not None:
            #     from lutil import pax
            #     pax(0, {"output_tensor / 0": output_tensor[0]})
            # <<<
            assert_viewless_tensor(input_tensor)
            assert_viewless_tensor(output_tensor)
            input_tensors.append(input_tensor)