Commit 3b29f0e7 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

minor fixes

parent 7a77abd9
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -164,7 +164,7 @@ class MegatronPretrainingRandomSampler:
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        if isinstance(dataset, RandomSeedDataset):
        if isinstance(self.dataset, RandomSeedDataset):
            self.dataset.set_epoch(self.epoch)

        # data sharding and random sampling
+1 −1
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ class ClassificationTransform():
                T.ColorJitter(0.4, 0.4, 0.4, 0.1),
                ImageNetPolicy(),
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                T.ConvertImageDtype(self.data_type)
            ])
        else:
+0 −2
Original line number Diff line number Diff line
@@ -19,8 +19,6 @@ import torch
from megatron import get_args
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3_avg
from megatron.model.vision.utils import trunc_normal_
from megatron.model.module import MegatronModule

class VitClassificationModel(MegatronModule):