Commit 3202f237 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'lmcafee/zerograd-fix' into 'main'

fixed zero_grad for fp32_from_float16_groups

See merge request ADLR/megatron-lm!283
parents c1075275 4e64903d
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -282,9 +282,14 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):

    def zero_grad(self, set_to_none=True):
        """Zero the gradients of model-related parameter groups.

        We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point.

        Args:
            set_to_none (bool): if True, grads are set to None rather
                than zeroed in place, allowing the backing storage to
                be freed immediately.
        """
        # NOTE(review): _zero_grad_group_helper is a module-level helper
        # defined elsewhere in this file.
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)