Commit 8ceed7c7 authored by Mohammad's avatar Mohammad
Browse files

Changed GPT-2 attention masking to use a binary (boolean) mask applied via masked_fill_, and dropped the fp16 mask-conversion argument.

parent c0a59a66
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal


def gpt2_attention_mask_func(attention_scores, ltor_mask):
    """Apply a left-to-right (causal) mask to raw attention scores, in place.

    Args:
        attention_scores: tensor of raw (pre-softmax) attention logits.
        ltor_mask: boolean tensor broadcastable to ``attention_scores``;
            True marks positions that must NOT be attended to (per this
            commit, the mask is binary: ``attention_mask < 0.5``).

    Returns:
        ``attention_scores`` with masked positions set to -10000.0 so they
        vanish after softmax. The input tensor is modified in place.
    """
    # NOTE: the old multiplicative form
    #   attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
    # was removed by this commit; with the new boolean mask convention it
    # would zero the *unmasked* scores. Only the masked_fill_ call remains.
    attention_scores.masked_fill_(ltor_mask, -10000.0)
    return attention_scores


+1 −2
Original line number Diff line number Diff line
@@ -42,8 +42,7 @@ def get_batch(context_tokens):
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        args.fp16)
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids

+3 −5
Original line number Diff line number Diff line
@@ -119,8 +119,7 @@ def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss,
                                    fp16):
                                    eod_mask_loss):
    """Build masks and position id for left to right model."""

    # Extract batch size and sequence length.
@@ -170,8 +169,7 @@ def get_ltor_masks_and_position_ids(data,
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    # Convert
    if fp16:
        attention_mask = attention_mask.half()
    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids
+1 −2
Original line number Diff line number Diff line
@@ -65,8 +65,7 @@ def get_batch(data_iterator):
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        args.fp16)
        args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids

+1 −2
Original line number Diff line number Diff line
@@ -71,8 +71,7 @@ def process_batch(batch):
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        args.fp16)
        args.eod_mask_loss)

    return tokens, labels, attention_mask, position_ids, loss_mask