finetune_hoc.lsf (+13 −11)

```diff
@@ -1,7 +1,7 @@
 #!/bin/bash
 #BSUB -nnodes 2
-#BSUB -W 2:00
-#BSUB -P med106
+#BSUB -W 0:45
+#BSUB -P med107
 #BSUB -alloc_flags "smt4 nvme"
 #BSUB -J hoc_FULL
 #BSUB -o hoc_FULL.%J
@@ -13,10 +13,11 @@
 set +x
 #module load open-ce/1.4.0-py38-0
 module load open-ce
-conda deactivate
-conda activate /gpfs/alpine/med106/world-shared/irl1/rhel8/mytorch
+module list
+conda activate /gpfs/alpine/med106/world-shared/irl1/rhel8/myt_py1.11
+export OMP_NUM_THREADS=1
 ulimit -n 65536
 rm -f `find -name *lock`
 #export PYTHONPATH=$PYTHONPATH:/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/megatron/fused_kernels
 #export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/gpfs/alpine/med106/world-shared/irl1/rhel8/mytorch/lib/python3.8/site-packages/torch/lib
 #export PATH=$PATH:/gpfs/alpine/med106/world-shared/irl1/rhel8/mytorch/lib/python3.8/site-packages/torch/include
@@ -35,16 +36,16 @@
 VALID_DATA="/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/hocdata"
 #export VALID_DATA=picodata/dev.tsv
 export VOCAB_FILE=/gpfs/alpine/world-shared/med106/g8o/pubmed_bert-vocab.txt
-export CHECKPOINT_PATH=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/finetune-HOC_BIG
-export PRETRAINED_CHECKPOINT=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/chkptt
+export CHECKPOINT_PATH=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/finetune-pubmed_bert_1x1_b8
+export PRETRAINED_CHECKPOINT=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/pubmed_bert_1x1_b8_chckpttt
 
 jsrun --smpiargs="-disable_gpu_hooks" -n $nnodes -r 1 -g 6 -a 6 -c 42 python tasks/main.py \
        --task HOC \
-       --tensor-model-parallel-size 2 \
-       --pipeline-model-parallel-size 2 \
-       --num-layers 24 \
-       --hidden-size 1024 \
-       --num-attention-heads 16 \
+       --tensor-model-parallel-size 1 \
+       --pipeline-model-parallel-size 1 \
+       --num-layers 12 \
+       --hidden-size 768 \
+       --num-attention-heads 12 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --fp16 \
@@ -64,6 +65,7 @@
        --micro-batch-size 4 \
        --lr 0.0001 \
        --lr-warmup-fraction 0.06 \
+       --num-workers 0 \
        --distributed-backend nccl
 #--DDP-impl torch \
```
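The script now fine-tunes a BERT-base-sized model (12 layers, hidden size 768, 12 heads) instead of a BERT-large configuration (24 layers, hidden size 1024, 16 heads), which is why both tensor and pipeline model parallelism drop to 1. A hedged back-of-envelope sketch of the size difference; the 30522-entry vocabulary is an assumption (the job actually uses pubmed_bert-vocab.txt, whose size may differ), and biases and layernorms are ignored:

```cpp
// Rough transformer parameter count: ~12*H^2 weights per layer (QKV
// projections, attention output projection, and the 4H feed-forward
// up/down projections), plus token and position embedding tables.
#include <cstdio>

long long bertParams(long long layers, long long hidden,
                     long long vocab, long long seqLen) {
  long long perLayer   = 12 * hidden * hidden;        // attention + MLP weights
  long long embeddings = (vocab + seqLen) * hidden;   // token + position tables
  return layers * perLayer + embeddings;
}

int main() {
  // Assumed vocab size 30522 (standard BERT); pubmed_bert-vocab.txt may differ.
  std::printf("12x768  config: ~%lld params\n", bertParams(12,  768, 30522, 512));
  std::printf("24x1024 config: ~%lld params\n", bertParams(24, 1024, 30522, 512));
}
```

This prints roughly 110M versus 330M parameters: the smaller model fits comfortably on a single GPU in fp16, so a 1×1 model-parallel layout suffices.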
megatron/arguments.py (+1 −1)

```diff
@@ -660,7 +660,7 @@ def _add_distributed_args(parser):
     group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                        help='Number of layers per virtual pipeline stage')
     group.add_argument('--distributed-backend', default='nccl',
-                       choices=['nccl', 'gloo'],
+                       choices=['nccl', 'gloo', 'mpi'],
                        help='Which backend to use for distributed training.')
     group.add_argument('--DDP-impl', default='local',
                        choices=['local', 'torch'],
```
megatron/fused_kernels/layer_norm_cuda_kernel.cu (+6 −3)

```diff
@@ -250,12 +250,15 @@ void cuWelfordMuSigma2(
 template<typename U> U rsqrt(U v) {
   return U(1) / sqrt(v);
 }
-#if defined __HIP_PLATFORM_HCC__
-template<>
-#else
+#if defined __HIP_PLATFORM_AMD__
+__device__ float rsqrt(float v) {
+  return rsqrtf(v);
+}
+#else
+template<>
 float rsqrt(float v) {
   return rsqrtf(v);
 }
 #endif
 template<>
 double rsqrt(double v) {
```
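The rsqrt change replaces the deprecated `__HIP_PLATFORM_HCC__` guard with `__HIP_PLATFORM_AMD__` and, on ROCm, provides the float case as a plain `__device__` overload rather than a template specialization, since `rsqrtf` is a device-side intrinsic there. A minimal self-contained sketch of the pattern; the `my_rsqrt` name, demo kernel, and `main` are illustrative additions, not part of the patch, and the host code assumes a CUDA build (under ROCm the runtime calls would be hipified):

```cpp
#include <cstdio>
#include <cmath>

// Generic fallback for any floating type.
template<typename U> __device__ U my_rsqrt(U v) { return U(1) / sqrt(v); }

#if defined __HIP_PLATFORM_AMD__
// HIP: expose the float case as a plain __device__ overload of rsqrtf.
__device__ float my_rsqrt(float v) { return rsqrtf(v); }
#else
// CUDA: an explicit specialization of the template works as-is.
template<> __device__ float my_rsqrt(float v) { return rsqrtf(v); }
#endif

__global__ void demo(float* out, float x) { *out = my_rsqrt(x); }

int main() {
  float* d; float h = 0.f;
  cudaMalloc(&d, sizeof(float));
  demo<<<1, 1>>>(d, 4.0f);
  cudaMemcpy(&h, d, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("rsqrt(4) = %f\n", h);  // expect 0.5
  cudaFree(d);
}
```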
megatron/fused_kernels/scaled_masked_softmax.cpp (+4 −0)

```diff
@@ -14,7 +14,11 @@
  * limitations under the License.
  */
 
+#if defined __HIP_PLATFORM_AMD__
+#include <hip/hip_fp16.h>
+#else
 #include <cuda_fp16.h>
+#endif
 #include <torch/extension.h>
 #include <vector>
```
megatron/fused_kernels/scaled_masked_softmax.h (+8 −0)

```diff
@@ -17,11 +17,19 @@
 #pragma once
 
 #include <assert.h>
+#if defined __HIP_PLATFORM_AMD__
+#include <hip/hip_fp16.h>
+#else
 #include <cuda_fp16.h>
+#endif
 #include <cfloat>
 #include <limits>
 #include <stdint.h>
+#if defined __HIP_PLATFORM_AMD__
+#include <hip/hip_fp16.h>
+#else
 #include <cuda_fp16.h>
+#endif
 #include <c10/macros/Macros.h>
 
 namespace {
```
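Both softmax files gain the same portability guard: ROCm builds define `__HIP_PLATFORM_AMD__` and ship the half-precision types in `hip/hip_fp16.h`, while CUDA builds get them from `cuda_fp16.h` (the header happens to include it twice, so the guard is applied in both places). A minimal sketch of the pattern in use; the kernel and `main` are illustrative and assume a CUDA toolchain (under ROCm the runtime calls would be hipified):

```cpp
#include <cstdio>

// Portable half-precision include: both headers define __half and the
// __half2float / __float2half conversion helpers used below.
#if defined __HIP_PLATFORM_AMD__
#include <hip/hip_fp16.h>
#else
#include <cuda_fp16.h>
#endif

// Tiny kernel that widens fp16 inputs to fp32, much as the fused softmax
// kernels do internally before reducing.
__global__ void halfToFloat(const __half* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = __half2float(in[i]);
}

int main() {
  const int n = 4;
  __half h_in[n];
  for (int i = 0; i < n; ++i) h_in[i] = __float2half(float(i) * 0.5f);
  __half* d_in; float* d_out; float h_out[n];
  cudaMalloc(&d_in, n * sizeof(__half));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(__half), cudaMemcpyHostToDevice);
  halfToFloat<<<1, n>>>(d_in, d_out, n);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) std::printf("%g ", h_out[i]);  // 0 0.5 1 1.5
  std::printf("\n");
  cudaFree(d_in); cudaFree(d_out);
}
```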