Commit db88a27b authored by mohammad

addressed Jared's and Deepak's comments

parent 512337f5
+3 −0
@@ -83,6 +83,9 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    else:
        if norm_type == 2.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
+            # Use apex's multi-tensor applier for efficiency reasons.
+            # Multi-tensor applier takes a function and a list of lists
+            # and performs the operation on that list all in one kernel.
            grad_norm, _ = multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
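
For reference, the fused amp_C.multi_tensor_l2norm kernel described by the new comment returns the global L2 norm over a list of gradient tensors in a single launch. A minimal unfused sketch of the same quantity (the helper name and the grads argument are illustrative, not part of this diff):

import torch

def l2_norm_unfused(grads):
    # Same value the fused kernel computes: the L2 norm over all gradient
    # elements, built here from one per-tensor norm at a time.
    per_tensor_norms = [torch.norm(g.detach(), 2.0) for g in grads]
    return torch.norm(torch.stack(per_tensor_norms), 2.0)
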
+6 −16
@@ -78,6 +78,7 @@ class MegatronOptimizer(ABC):

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
@@ -90,6 +91,11 @@ class MegatronOptimizer(ABC):

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint  without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass
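
To illustrate the contract the new reload_model_params docstring describes, a hypothetical fp16-optimizer-with-main-params implementation might refresh its fp32 copies roughly as follows (the float16_groups / fp32_from_float16_groups attribute names are assumptions made for this sketch, not taken from the diff):

    def reload_model_params(self):
        # Sketch: copy the current (possibly externally modified) fp16 model
        # parameters into the fp32 main parameters the optimizer steps on.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                main_param.data.copy_(model_param.data)
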

    @abstractmethod
@@ -289,54 +295,38 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):

        timers = get_timers()

-        # ==================================================
        # Copy gradients from model params to main params.
-        # ==================================================
        timers('optimizer-copy-to-main-grad').start()
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

-        # ==============================
        # Unscale and check for inf/nan.
-        # ==============================
        timers('optimizer-unscale-and-check-inf').start()
        found_inf_flag = self._unscale_main_grads_and_check_for_nan()
        timers('optimizer-unscale-and-check-inf').stop()

-        # ==================================
        # We are done with scaling gradients
        # so we can update the loss scale.
-        # ==================================
        self.grad_scaler.update(found_inf_flag)

-        # =====================================
        # If we found inf/nan, skip the update.
-        # =====================================
        if found_inf_flag:
            return False

-        # ==========================
        # Clip the main gradients.
-        # ==========================
        timers('optimizer-clip-main-grad').start()
        self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

-        # ===================
        # Step the optimizer.
-        # ===================
        self.optimizer.step()

-        # =================================
        # Update params from main params.
-        # =================================
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

-        # ==================
        # Successful update.
-        # ==================
        return True
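
The self.grad_scaler.update(found_inf_flag) call above is where dynamic loss scaling reacts to the overflow check: back off when an inf/nan was found, grow again after a run of clean steps. A minimal sketch of that policy, with illustrative class and hyperparameter names rather than the scaler actually used here:

class DynamicGradScalerSketch:
    def __init__(self, initial_scale=2.0**32, min_scale=1.0,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=1000):
        self.scale = initial_scale
        self.min_scale = min_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def update(self, found_inf_flag):
        if found_inf_flag:
            # Overflow this iteration: shrink the scale and restart the window.
            self.scale = max(self.scale * self.backoff_factor, self.min_scale)
            self._clean_steps = 0
        else:
            # No overflow: grow the scale after growth_interval clean steps.
            self._clean_steps += 1
            if self._clean_steps % self.growth_interval == 0:
                self.scale *= self.growth_factor

Note that when found_inf_flag is set, step() returns False before touching the parameters, so the iteration is counted as skipped by the caller and only the loss scale changes.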


+5 −9
@@ -703,7 +703,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
            writer.add_scalar(key , loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
-        if args.fp16:
+        writer.add_scalar('loss-scale', loss_scale, iteration)
+        writer.add_scalar('loss-scale vs samples', loss_scale,
+                          args.consumed_train_samples)
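
The args.fp16 guard around these lines can go because get_loss_scale() is now part of the MegatronOptimizer interface (see the docstring added above), so loss_scale is always available to log. For an optimizer that does no scaling, a constant unit tensor is enough; a minimal sketch under that assumption, with an illustrative class name:

import torch

class FP32OptimizerSketch:
    def __init__(self, optimizer):
        self.optimizer = optimizer
        # Constant scale of 1.0 keeps the logging path uniform.
        self._scale = torch.cuda.FloatTensor([1.0])

    def get_loss_scale(self):
        # A cuda tensor of size 1, as the new docstring requires.
        return self._scale
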
@@ -732,7 +731,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
-        if args.fp16:
+        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
@@ -797,8 +795,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                       get_num_microbatches()

        # Logging.
-        loss_scale = None
-        if args.fp16:
+        loss_scale = optimizer.get_loss_scale().item()
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],