Commit 14c85e64 authored by Rewon Child's avatar Rewon Child
Browse files

Merge branch 'main' into rc-debug-underflow

parents 4e77e7c6 c1faa9fe
Loading
Loading
Loading
Loading
+19 −3
Original line number Diff line number Diff line
@@ -203,6 +203,22 @@ def parse_args(extra_args_provider=None, defaults={},
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'

    # custom kernel constraints check
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size

    # constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0

    if args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion:
        print('WARNING: constraints for invoking optimized'
            ' fused softmax kernel are not met. We default back to unfused'
            ' kernel invocations.')

    # Load scaled_masked_softmax_fusion_kernels
    if args.masked_softmax_fusion:
        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
@@ -480,9 +496,9 @@ def _add_checkpointing_args(parser):
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=None,
                       help='Number of iterations between checkpoint saves.')
    group.add_argument('--no-save-optim', action='store_true',
    group.add_argument('--no-save-optim', action='store_true', default=None,
                       help='Do not save current optimizer.')
    group.add_argument('--no-save-rng', action='store_true',
    group.add_argument('--no-save-rng', action='store_true', default=None,
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
+5 −2
Original line number Diff line number Diff line
@@ -343,12 +343,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
            # Check for empty states array
            if not state_dict['rng_tracker_states']:
                raise KeyError
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

+10 −5
Original line number Diff line number Diff line
@@ -116,15 +116,20 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
 
    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4
        data_size = input.size()
        query_seq_len = data_size[-2]
        key_seq_len = data_size[-1]
        assert input.dim() == 4
        attn_batch_size = data_size[0] * data_size[1]

        # invoke custom kernel
        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
           query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
        # constraints on various tensor dimensions to enable warp based
        # optimization and upper triangular optimization (for causal mask)
        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

        # invoke custom kernel
        if self.input_in_fp16 and mask is not None and \
           custom_kernel_constraint and self.scaled_masked_softmax_fusion:
            scale = self.scale if self.scale is not None else 1.0

            if self.attn_mask_type == AttnMaskType.causal:
+2 −0
Original line number Diff line number Diff line
@@ -351,6 +351,8 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
    reqs = torch.distributed.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
    # Temporary workaround for batch_isend_irecv() race condition.
    torch.cuda.synchronize()

    return tensor_recv_prev, tensor_recv_next

+3 −1
Original line number Diff line number Diff line
@@ -92,7 +92,9 @@ def main():
    """Main program."""

    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    # Set up model and load checkpoint.
    model = get_model(model_provider)
Loading