megatron/model/transformer.py (+1 −1)

```diff
@@ -269,7 +269,7 @@ class ParallelSelfAttention(MegatronModule):
         # Attention probabilities. [b, np, s, s]
         if self.apply_query_key_layer_scaling:
             attention_scores = attention_scores * self.layer_number
-        attention_probs = torch.nn.Softmax(dim=-1)(attention_probs)
+        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
         with mpu.get_cuda_rng_tracker().fork():
```
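The fix applies the softmax to `attention_scores` (the scaled query-key logits) rather than to `attention_probs`, which has not been assigned yet at that point in the method. Below is a minimal standalone sketch of the corrected step; the shapes, the `layer_number` value, and the flag variable are illustrative assumptions, not values from the actual Megatron module or config.

```python
import torch

# Illustrative shapes: batch, attention heads per partition, sequence length.
b, np_, s = 2, 4, 8
layer_number = 3                      # stand-in for self.layer_number (assumption)
apply_query_key_layer_scaling = True  # stand-in for the module flag (assumption)

# [b, np, s, s] raw attention logits (QK^T, assumed already scaled upstream).
attention_scores = torch.randn(b, np_, s, s)

if apply_query_key_layer_scaling:
    # Re-apply the layer-number factor to the logits, as in the diff context.
    attention_scores = attention_scores * layer_number

# The corrected line: softmax over the scores, not the undefined attention_probs.
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

# Each row of the last dimension is now a valid probability distribution.
assert torch.allclose(attention_probs.sum(dim=-1), torch.ones(b, np_, s))
```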