megatron/checkpointing.py  +1 −1

@@ -60,8 +60,8 @@ def check_checkpoint_args(checkpoint_args):
     _compare('num_layers')
     _compare('hidden_size')
     _compare('num_attention_heads')
-    _compare('max_position_embeddings')
     if args.vocab_file:
+        _compare('max_position_embeddings')
         _compare('make_vocab_size_divisible_by')
         _compare('padded_vocab_size')
         _compare('tokenizer_type')

megatron/model/vit_model.py  +76 −52

@@ -50,11 +50,11 @@ class VitMlpHead(MegatronModule):
     def forward(self, hidden_states, sequence_index=0):
         # hidden_states: [b, s, h]
         # sequence_index: index of the token to pool.
-        x = hidden_states[:, sequence_index, :]
-        x = self.dense_in(x)
-        x = torch.tanh(x)
-        x = self.dense_out(x)
-        return x
+        hidden_state = hidden_states[:, sequence_index, :]
+        dense_in_result = self.dense_in(hidden_state)
+        tanh_result = torch.tanh(dense_in_result)
+        dense_out_result = self.dense_out(tanh_result)
+        return dense_out_result


 def twod_interpolate_position_embeddings_hook(

@@ -122,8 +122,12 @@ def twod_interpolate_position_embeddings_hook(
 class VitModel(MegatronModule):
     """Vision Transformer Model."""

-    def __init__(self, num_classes, finetune=False):
-        super(VitModel, self).__init__()
+    def __init__(self, num_classes, finetune=False,
+                 pre_process=True, post_process=True):
+        super(VitModel, self).__init__(share_word_embeddings=False)
         args = get_args()

         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

@@ -136,6 +140,8 @@ class VitModel(MegatronModule):
             args.init_method_std, args.num_layers
         )

+        self.pre_process = pre_process
+        self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.num_classes = num_classes
         self.patch_dim = args.patch_dim

@@ -148,8 +154,11 @@ class VitModel(MegatronModule):
         self.seq_length = self.num_patches + 1
         self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels

+        if self.pre_process:
             # cls_token
-        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
+            self.cls_token = torch.nn.Parameter(
+                torch.randn(1, 1, self.hidden_size)
+            )
             torch.nn.init.zeros_(self.cls_token)

             # Linear encoder

@@ -174,9 +183,13 @@ class VitModel(MegatronModule):
         # Transformer
         self.transformer = ParallelTransformer(
-            self.init_method, self.scaled_init_method
+            self.init_method, self.scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process
         )

+        if self.post_process:
             # MLP head
             if not self.finetune:
                 self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)

@@ -185,26 +198,37 @@ class VitModel(MegatronModule):
                 self.hidden_size, num_classes, torch.nn.init.zeros_
             )

-    def forward(self, x):
-        x = einops.rearrange(
-            x,
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.transformer.set_input_tensor(input_tensor)
+
+    def forward(self, input):
+        if self.pre_process:
+            rearranged_input = einops.rearrange(
+                input,
                 "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                 p1=self.patch_dim,
                 p2=self.patch_dim,
             )

-        assert x.dtype == torch.half
-        x = self.linear_encoder(x)
-        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
+            assert rearranged_input.dtype == torch.half
+            encoder_output = self.linear_encoder(rearranged_input)
+            cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
+            concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
+
+            token_embeddings = concatenated_tokens + \
+                self.position_embeddings(self.position_ids)
+            hidden_states = self.embedding_dropout(token_embeddings)
+        else:
+            hidden_states = input

-        x = x + self.position_embeddings(self.position_ids)
-        x = self.embedding_dropout(x)
-        x = self.transformer(x, None)
+        hidden_states = self.transformer(hidden_states, None)

+        if self.post_process:
             if not self.finetune:
-            x = self.mlp_head(x)
+                hidden_states = self.mlp_head(hidden_states)
             else:
-            x = self.class_head(x[:, 0, :])
+                hidden_states = self.class_head(hidden_states[:, 0, :])

-        return x
+        return hidden_states
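Note: the vit_model.py changes follow Megatron's pipeline-parallelism convention: a stage only builds the layers it owns (pre_process gates the patch embedding and cls_token, post_process gates the head), and set_input_tensor lets the schedule hand a non-first stage the activations produced by the previous stage. A minimal sketch of that pattern, with a hypothetical class name (ToyPipelineStage) and no Megatron dependencies; it simplifies the activation plumbing that VitModel delegates to its transformer:

import torch

class ToyPipelineStage(torch.nn.Module):
    # Hypothetical toy module illustrating the pre_process/post_process
    # split used by VitModel above; not Megatron code.
    def __init__(self, hidden_size, num_classes,
                 pre_process=True, post_process=True):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        if pre_process:
            # Only the first pipeline stage owns the input embedding.
            self.embed = torch.nn.Linear(hidden_size, hidden_size)
        self.body = torch.nn.Linear(hidden_size, hidden_size)
        if post_process:
            # Only the last pipeline stage owns the classification head.
            self.head = torch.nn.Linear(hidden_size, num_classes)

    def set_input_tensor(self, input_tensor):
        # Called by the pipeline schedule to hand this stage the
        # activations received from the previous stage.
        self.input_tensor = input_tensor

    def forward(self, inp):
        if self.pre_process:
            hidden = self.embed(inp)      # first stage: embed raw input
        else:
            hidden = self.input_tensor    # later stages: consume activations
        hidden = self.body(hidden)
        if self.post_process:
            hidden = self.head(hidden)    # last stage: produce logits
        return hidden

Building only the owned submodules keeps parameter allocation per stage minimal, which is why VitModel now guards cls_token, the linear encoder, and the heads behind these flags.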
pretrain_vit.py  +20 −14

@@ -17,19 +17,22 @@
 import torch
 import torch.nn.functional as F
+from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vit_model import VitModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group


-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0("building VIT model ...")
     args = get_args()

-    model = VitModel(num_classes=args.num_classes)
+    model = VitModel(num_classes=args.num_classes,
+                     pre_process=pre_process,
+                     post_process=post_process)
     return model


 def get_batch(data_iterator):

@@ -42,10 +45,21 @@ def get_batch(data_iterator):
     return images, labels


-def forward_step(data_iterator, model, input_tensor):
+def loss_func(labels, output_tensor):
+    logits = output_tensor.contiguous().float()
+    loss = F.cross_entropy(logits, labels)
+
+    outputs = torch.argmax(logits, -1)
+    correct = (outputs == labels).float()
+    accuracy = torch.mean(correct)
+
+    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
+
+    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+
+
+def forward_step(data_iterator, model):
     """Forward step."""
     timers = get_timers()

-    assert input_tensor is None
-
     # Get the batch.
     timers("batch-generator").start()

@@ -56,17 +70,9 @@ def forward_step(data_iterator, model, input_tensor):
     timers("batch-generator").stop()

     # Forward model. lm_labels
-    logits = model(images).contiguous().float()
-    loss = F.cross_entropy(logits, labels)
-
-    outputs = torch.argmax(logits, -1)
-    correct = (outputs == labels).float()
-    accuracy = torch.mean(correct)
-
-    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
-
-    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+    output_tensor = model(images)
+
+    return output_tensor, partial(loss_func, labels)


 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid, and test datasets."""
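Note: this forward_step refactor is the contract pipeline schedules need: the step returns the raw output tensor plus a loss callable, and only the stage that holds the logits invokes the callable. partial(loss_func, labels) binds the labels without computing anything yet. A self-contained sketch of that contract; the toy model and driver below are assumptions standing in for the real schedule, and the data-parallel averaging is dropped for brevity:

from functools import partial

import torch
import torch.nn.functional as F

def loss_func(labels, output_tensor):
    # Same shape as the diff's loss_func, minus
    # average_losses_across_data_parallel_group.
    logits = output_tensor.contiguous().float()
    loss = F.cross_entropy(logits, labels)
    accuracy = (torch.argmax(logits, -1) == labels).float().mean()
    return loss, {"loss": loss.item(), "accuracy": accuracy.item()}

def forward_step(batch, model):
    images, labels = batch
    output_tensor = model(images)
    # Bind labels now; the schedule decides later whether to call this.
    return output_tensor, partial(loss_func, labels)

# Toy driver standing in for the pipeline schedule:
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
batch = (torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,)))
output_tensor, loss_callable = forward_step(batch, model)
loss, metrics = loss_callable(output_tensor)  # only the last stage runs this

Separating the loss this way means intermediate stages can run forward_step and simply forward activations; they never see labels or logits.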
tasks/vision/classification.py  +3 −2

@@ -34,13 +34,14 @@ def classification():
         )
         return train_ds, valid_ds

-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""

         args = get_args()

         print_rank_0("building classification model for ImageNet ...")

-        return VitModel(num_classes=args.num_classes, finetune=True)
+        return VitModel(num_classes=args.num_classes, finetune=True,
+                        pre_process=pre_process, post_process=post_process)

     """Finetune/evaluate."""
     finetune(

tasks/vision/eval_utils.py  +56 −20

@@ -16,10 +16,14 @@
 """Evaluation utilities."""

 import os
+from functools import partial
+
 import torch
+
 from megatron import get_args
-from megatron import print_rank_0
+from megatron import print_rank_0, print_rank_last
 from megatron import mpu
+from megatron.schedules import get_forward_backward_func
 from tasks.vision.finetune_utils import build_data_loader
 from tasks.vision.finetune_utils import process_batch
 from torchvision import datasets, transforms

@@ -56,7 +60,7 @@ def accuracy_func_provider():
         print_rank_0("calculating metrics ...")
         correct, total = calculate_correct_answers(model, dataloader, epoch)
         percent = float(correct) * 100.0 / float(total)
-        print_rank_0(
+        print_rank_last(
             " >> |epoch: {}| overall: correct / total = {} / {} = "
             "{:.4f} %".format(epoch, correct, total, percent)
         )

@@ -67,27 +71,59 @@ def accuracy_func_provider():
 def calculate_correct_answers(model, dataloader, epoch):
     """Calculate correct over total answers"""

-    model.eval()
+    args = get_args()
+    forward_backward_func = get_forward_backward_func()
+    for m in model:
+        m.eval()
+
+    def loss_func(labels, output_tensor):
+        logits = output_tensor
+
+        loss_dict = {}
+        # Compute the correct answers.
+        predicted = torch.argmax(logits, dim=-1)
+        corrects = (predicted == labels).float()
+        # Add to the counters.
+        loss_dict['total'] = labels.size(0)
+        loss_dict['correct'] = corrects.sum().item()
+
+        return 0, loss_dict
+
+    # defined inside to capture output_predictions
+    def correct_answers_forward_step(batch, model):
+        try:
+            batch_ = next(batch)
+        except BaseException:
+            batch_ = batch
+        images, labels = process_batch(batch_)
+
+        # Forward model.
+        args = get_args()
+        output_tensor = model(images)
+
+        return output_tensor, partial(loss_func, labels)
+
     with torch.no_grad():
         # For all the batches in the dataset.
         total = 0
         correct = 0
         for _, batch in enumerate(dataloader):
-            # Run the model forward.
-            images, labels = process_batch(batch)
-            logits = model(images).contiguous().float()
-            # Add output predictions.
-            # Compute the correct answers.
-            predicted = torch.argmax(logits, dim=-1)
-            corrects = (predicted == labels).float()
-            # Add to the counters.
-            total += labels.size(0)
-            correct += corrects.sum().item()
-    model.train()
+            loss_dicts = forward_backward_func(correct_answers_forward_step,
+                                               batch, model, optimizer=None,
+                                               timers=None, forward_only=True)
+            for loss_dict in loss_dicts:
+                total += loss_dict['total']
+                correct += loss_dict['correct']
+
+    for m in model:
+        m.train()

     # Reduce.
     if mpu.is_pipeline_last_stage():
         unreduced = torch.cuda.LongTensor([correct, total])
-        torch.distributed.all_reduce(unreduced, group=mpu.get_data_parallel_group())
+        torch.distributed.all_reduce(unreduced,
+                                     group=mpu.get_data_parallel_group())

         # Print on screen.
         correct_ans = unreduced[0].item()
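Note: evaluation now goes through get_forward_backward_func with forward_only=True, so the accuracy counters come back as per-microbatch dicts, are summed locally, and are then all-reduced over the data-parallel group on the last pipeline stage. Packing both counters into one tensor means a single all_reduce aggregates correct and total at once (SUM is the default reduction op). A hedged sketch of that step; reduce_accuracy is a hypothetical helper, and it assumes torch.distributed is already initialized with a CUDA device available:

import torch
import torch.distributed as dist

def reduce_accuracy(correct, total, group=None):
    # Hypothetical helper, not Megatron code. One tensor carries both
    # counters so a single all_reduce sums them across all replicas.
    unreduced = torch.tensor([correct, total], dtype=torch.long, device="cuda")
    dist.all_reduce(unreduced, group=group)
    correct_ans = unreduced[0].item()
    total_count = unreduced[1].item()
    percent = float(correct_ans) * 100.0 / float(total_count)
    return correct_ans, total_count, percent

Guarding the reduce with mpu.is_pipeline_last_stage() matters because only the last stage ran the loss callable and holds nonzero counters; that is also why the summary line now prints via print_rank_last instead of print_rank_0.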