Commit 97ba5c0e authored by mohammad

load and save state dicts added

parent 0888a3e1
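
This commit adds state_dict()/load_state_dict() pairs to the grad scalers and optimizer wrappers below so that optimizer state can be checkpointed. A minimal sketch of the intended round trip, assuming `optimizer` is the wrapper returned by get_megatron_optimizer(model) and 'ckpt.pt' is a hypothetical path (not part of this commit):

    import torch

    # Save: for the FP16 wrapper this captures the inner optimizer state,
    # the grad scaler state, and the fp32 master copies of the fp16 params.
    torch.save({'optimizer': optimizer.state_dict()}, 'ckpt.pt')

    # Resume: load onto CPU first, then let load_state_dict move tensors
    # (e.g. the loss scale) back onto the current device.
    state = torch.load('ckpt.pt', map_location='cpu')
    optimizer.load_state_dict(state['optimizer'])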
+1 −3
@@ -25,7 +25,6 @@ def _get_params_for_weight_decay_optimization(module):
    """Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """

    args = get_args()
    LayerNorm = import_layernorm(args.fp32_residual_connection)

@@ -48,7 +47,6 @@ def _get_params_for_weight_decay_optimization(module):


def get_megatron_optimizer(model):

    args = get_args()

    # Base optimizer.
@@ -77,4 +75,4 @@ def get_megatron_optimizer(model):
                                           args.clip_grad)

    # FP32.
-    return FP32Optimizer(optimizer, model, args.clip_grad)
+    return FP32Optimizer(optimizer, args.clip_grad)
+22 −2
@@ -40,7 +40,6 @@ class MegatronGradScaler(ABC):
    def update(self, found_inf):
        pass

-    '''
    @abstractmethod
    def state_dict(self):
        pass
@@ -48,7 +47,7 @@ class MegatronGradScaler(ABC):
    @abstractmethod
    def load_state_dict(self, state_dict):
        pass
-    '''



class ConstantGradScaler(MegatronGradScaler):
@@ -56,6 +55,13 @@ class ConstantGradScaler(MegatronGradScaler):
    def update(self, found_inf):
        pass

    def state_dict(self):
        return dict()

    def load_state_dict(self, state_dict):
        pass



class DynamicGradScaler(MegatronGradScaler):

@@ -111,3 +117,17 @@ class DynamicGradScaler(MegatronGradScaler):
                self._hysteresis_tracker = self.hysteresis
                # and scale up the loss scale.
                self._scale = self._scale * self.growth_factor


    def state_dict(self):
        state_dict = {}
        state_dict['scale'] = self._scale
        state_dict['growth_tracker'] = self._growth_tracker
        state_dict['hysteresis_tracker'] = self._hysteresis_tracker
        return state_dict


    def load_state_dict(self, state_dict):
        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
        self._growth_tracker = state_dict['growth_tracker']
        self._hysteresis_tracker = state_dict['hysteresis_tracker']
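
A small round-trip sketch for the scaler state above, assuming `scaler` is a constructed DynamicGradScaler and CUDA is available; 'grad_scaler.pt' is a hypothetical filename. Note that load_state_dict moves the saved 'scale' tensor onto the current device, while the two trackers are plain Python values:

    import torch

    sd = scaler.state_dict()            # {'scale': tensor, 'growth_tracker': ..., 'hysteresis_tracker': ...}
    torch.save(sd, 'grad_scaler.pt')
    loaded = torch.load('grad_scaler.pt')
    scaler.load_state_dict(loaded)      # restores the scale on the current CUDA device plus both trackers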
+22 −5
@@ -145,7 +145,6 @@ class MegatronOptimizer(ABC):
    def step(self):
        pass

-    '''
    @abstractmethod
    def state_dict(self):
        pass
@@ -153,7 +152,6 @@ class MegatronOptimizer(ABC):
    @abstractmethod
    def load_state_dict(self, state_dict):
        pass
-    '''

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
@@ -180,7 +178,6 @@ class MegatronOptimizer(ABC):

class FP16OptimizerWithFP16Params(MegatronOptimizer):


    def __init__(self, optimizer, grad_scaler, clip_grad):
        super(FP16OptimizerWithFP16Params, self).__init__(optimizer)

@@ -369,12 +366,32 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        return True


    def state_dict(self):
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
        return state_dict


    def load_state_dict(self, state_dict):
        # Defer to the class to load.
        self.optimizer.load_state_dict(state_dict['optimizer'])
        self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
        # Copy data for the master params.
        for current_group, saved_group in zip(
                self.fp32_from_fp16_groups,
                state_dict['fp32_from_fp16_params']):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)
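
Using copy_ on the existing master parameters (rather than rebinding fp32_from_fp16_groups to the saved lists) keeps the same tensor objects alive, so the inner optimizer's param_groups and per-parameter state still reference them after a load. A small illustration of the aliasing behaviour being relied on, not Megatron code:

    import torch

    master = torch.nn.Parameter(torch.zeros(2))
    opt = torch.optim.SGD([master], lr=0.1)                # optimizer holds a reference to `master`
    saved = torch.ones(2)                                  # stands in for a checkpointed fp32 master param
    master.data.copy_(saved)                               # in-place copy: new values, same tensor object
    assert opt.param_groups[0]['params'][0] is master      # the optimizer still sees the updated tensor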



class FP32Optimizer(MegatronOptimizer):

-    def __init__(self, optimizer, model, clip_grad):
+    def __init__(self, optimizer, clip_grad):

        super(FP32Optimizer, self).__init__(optimizer)
-        self.model = model
        self.clip_grad = clip_grad
        self._scale = torch.cuda.FloatTensor([1.0])