Loading megatron/mpu/cross_entropy.py +13 −11 Original line number Diff line number Diff line Loading @@ -27,21 +27,13 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): @staticmethod def forward(ctx, vocab_parallel_logits, target): # Copy so the input remains unchanged. logits = vocab_parallel_logits.clone() # Maximum value along vocab dimension across all GPUs. logits_max = torch.max(logits, dim=-1)[0] logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=get_model_parallel_group()) # Subtract the maximum value. logits.sub_(logits_max.unsqueeze(dim=-1)) # Sum of exponential of logits along vocab dimension across all GPUs. exp_logits = logits.exp() sum_exp_logits = exp_logits.sum(dim=-1) torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_model_parallel_group()) vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) # Get the partition's vocab indecies get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size Loading @@ -59,11 +51,12 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): # Get predicted-logits = logits[target]. # For Simplicity, we convert logits to a 2-D tensor with size # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. logits_2d = logits.view(-1, partition_vocab_size) logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) masked_target_1d = masked_target.view(-1) arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] predicted_logits_1d = predicted_logits_1d.clone().contiguous() predicted_logits = predicted_logits_1d.view_as(target) predicted_logits[target_mask] = 0.0 # All reduce is needed to get the chunks from other GPUs. Loading @@ -71,6 +64,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): op=torch.distributed.ReduceOp.SUM, group=get_model_parallel_group()) # Sum of exponential of logits along vocab dimension across all GPUs. exp_logits = vocab_parallel_logits torch.exp(vocab_parallel_logits, out=exp_logits) sum_exp_logits = exp_logits.sum(dim=-1) torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_model_parallel_group()) # Loss = log(sum(exp(logits))) - predicted-logit. loss = torch.log(sum_exp_logits) - predicted_logits Loading Loading
megatron/mpu/cross_entropy.py +13 −11 Original line number Diff line number Diff line Loading @@ -27,21 +27,13 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): @staticmethod def forward(ctx, vocab_parallel_logits, target): # Copy so the input remains unchanged. logits = vocab_parallel_logits.clone() # Maximum value along vocab dimension across all GPUs. logits_max = torch.max(logits, dim=-1)[0] logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=get_model_parallel_group()) # Subtract the maximum value. logits.sub_(logits_max.unsqueeze(dim=-1)) # Sum of exponential of logits along vocab dimension across all GPUs. exp_logits = logits.exp() sum_exp_logits = exp_logits.sum(dim=-1) torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_model_parallel_group()) vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) # Get the partition's vocab indecies get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size Loading @@ -59,11 +51,12 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): # Get predicted-logits = logits[target]. # For Simplicity, we convert logits to a 2-D tensor with size # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. logits_2d = logits.view(-1, partition_vocab_size) logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) masked_target_1d = masked_target.view(-1) arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] predicted_logits_1d = predicted_logits_1d.clone().contiguous() predicted_logits = predicted_logits_1d.view_as(target) predicted_logits[target_mask] = 0.0 # All reduce is needed to get the chunks from other GPUs. Loading @@ -71,6 +64,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): op=torch.distributed.ReduceOp.SUM, group=get_model_parallel_group()) # Sum of exponential of logits along vocab dimension across all GPUs. exp_logits = vocab_parallel_logits torch.exp(vocab_parallel_logits, out=exp_logits) sum_exp_logits = exp_logits.sum(dim=-1) torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_model_parallel_group()) # Loss = log(sum(exp(logits))) - predicted-logit. loss = torch.log(sum_exp_logits) - predicted_logits Loading