This change moves `get_masks_and_position_ids` out of `pretrain_gpt2.py` and into `megatron/utils.py`, renaming it `get_ltor_masks_and_position_ids`; `generate_samples.py` then imports the helper from `megatron.utils` instead of reaching into the training script.

generate_samples.py  (+2 −2)

@@ -28,7 +28,7 @@ from arguments import get_args
 from megatron.utils import Timers
 from megatron.utils import initialize_distributed
 from megatron.utils import set_random_seed
-from pretrain_gpt2 import get_masks_and_position_ids
+from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import load_checkpoint
 from megatron.data_utils import make_tokenizer
 from configure_data import configure_data

@@ -91,7 +91,7 @@ def get_batch(context_tokens, args):
     tokens = tokens.to(device)

     # Get the masks and position ids.
-    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
         tokens,
         args.eod_token,
         args.reset_position_ids,

megatron/utils.py  (+57 −0)

@@ -31,6 +31,63 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization

+
+def get_ltor_masks_and_position_ids(data,
+                                    eod_token,
+                                    reset_position_ids,
+                                    reset_attention_mask,
+                                    eod_mask_loss):
+    """Build masks and position ids for left-to-right models."""
+
+    # Extract batch size and sequence length.
+    batch_size, seq_length = data.size()
+
+    # Attention mask (lower triangular); one mask per sample if masks are
+    # reset at document boundaries, otherwise a single shared mask.
+    if reset_attention_mask:
+        att_mask_batch = batch_size
+    else:
+        att_mask_batch = 1
+    attention_mask = torch.tril(torch.ones(
+        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
+            att_mask_batch, 1, seq_length, seq_length)
+
+    # Loss mask.
+    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
+    if eod_mask_loss:
+        loss_mask[data == eod_token] = 0.0
+
+    # Position ids.
+    position_ids = torch.arange(seq_length, dtype=torch.long,
+                                device=data.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(data)
+    # We need to clone as the ids will be modified based on batch index.
+    if reset_position_ids:
+        position_ids = position_ids.clone()
+
+    if reset_position_ids or reset_attention_mask:
+        # Loop through the batch:
+        for b in range(batch_size):
+
+            # Find indices where the EOD token is.
+            eod_index = position_ids[b, data[b] == eod_token]
+            # Detach indices from positions if positions will be modified.
+            if reset_position_ids:
+                eod_index = eod_index.clone()
+
+            # Loop through EOD indices:
+            prev_index = 0
+            for j in range(eod_index.size()[0]):
+                i = eod_index[j]
+                # Mask attention across the EOD boundary.
+                if reset_attention_mask:
+                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
+                # Restart position ids after the EOD token.
+                if reset_position_ids:
+                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
+                    prev_index = i + 1
+
+    return attention_mask, loss_mask, position_ids
+
+
 def reduce_losses(losses):
     reduced_losses = torch.cat(
         [loss.clone().detach().view(1) for loss in losses])

pretrain_gpt2.py  (+2 −58)

@@ -21,6 +21,7 @@ from configure_data import configure_data
 from gpt2_data_loader import make_gpt2_dataloaders
 from megatron import mpu
 from megatron.model import GPT2Model
+from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding

@@ -47,63 +48,6 @@ def model_provider(args):
     return model


-def get_masks_and_position_ids(data,
-                               eod_token,
-                               reset_position_ids,
-                               reset_attention_mask,
-                               eod_mask_loss):
-    """Build masks and position id."""
[The rest of the removed body is identical, line for line, to the function added to megatron/utils.py above.]


 def get_batch(data_iterator, args, timers):
     """Generate a batch"""

@@ -126,7 +70,7 @@ def get_batch(data_iterator, args, timers):
     tokens = tokens_[:, :-1].contiguous()

     # Get the masks and position ids.
-    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
         tokens,
         args.eod_token,
         args.reset_position_ids,
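To make the relocated helper's behavior concrete, here is a small demonstration (a sketch, not part of the change). It assumes a Megatron-LM checkout is on PYTHONPATH so the import resolves, and uses a made-up EOD token id of 0. With both reset flags enabled, position ids restart after every EOD token, the loss mask zeroes out EOD positions, and the attention mask becomes block-diagonal causal, so tokens in one packed document cannot attend to earlier documents.

```python
import torch

from megatron.utils import get_ltor_masks_and_position_ids

EOD = 0  # hypothetical EOD token id, chosen only for this example
# One packed sample holding three documents separated by EOD tokens.
data = torch.tensor([[5, 7, EOD, 3, 9, EOD, 2, 4]])

attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
    data,
    EOD,
    reset_position_ids=True,
    reset_attention_mask=True,
    eod_mask_loss=True)

print(position_ids)  # tensor([[0, 1, 2, 0, 1, 2, 0, 1]]) - restarts after each EOD
print(loss_mask)     # tensor([[1., 1., 0., 1., 1., 0., 1., 1.]]) - no loss on EOD
print(attention_mask[0, 0])
# Lower triangular within each document, zero across documents: a position
# after an EOD cannot attend to the EOD or to anything before it.
```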
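A quick way to sanity-check the `reset_attention_mask` path (again a sketch, under the same assumptions as above): the documents packed in the example have lengths 3, 3, and 2, so the resulting mask should equal a block-diagonal stack of per-document causal masks.

```python
import torch

from megatron.utils import get_ltor_masks_and_position_ids

EOD = 0  # same hypothetical EOD token id as above
data = torch.tensor([[5, 7, EOD, 3, 9, EOD, 2, 4]])
attention_mask, _, _ = get_ltor_masks_and_position_ids(
    data, EOD, True, True, True)

# Build the expected mask independently: one lower-triangular (causal)
# block per packed document, assembled with torch.block_diag.
expected = torch.block_diag(
    *[torch.tril(torch.ones(n, n)) for n in (3, 3, 2)])
assert torch.equal(attention_mask[0, 0], expected)
```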