Loading megatron/fused_kernels/layer_norm_cuda_kernel.cu +1 −1 Original line number Diff line number Diff line Loading @@ -21,7 +21,7 @@ #include "ATen/ATen.h" #include "ATen/AccumulateType.h" #include "ATen/cuda/CUDAContext.h" #include <THC/THCDeviceUtils.cuh> #include "ATen/cuda/DeviceUtils.cuh" #include <cuda.h> #include <cuda_runtime.h> Loading megatron/model/transformer.py +8 −4 Original line number Diff line number Diff line Loading @@ -392,14 +392,18 @@ def get_bias_dropout_add(training): @torch.jit.script def bias_dropout_add_fused_train(x, bias, residual, prob): # type: (Tensor, Tensor, Tensor, float) -> Tensor def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script def bias_dropout_add_fused_inference(x, bias, residual, prob): # type: (Tensor, Tensor, Tensor, float) -> Tensor def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) Loading megatron/text_generation_server.py +4 −3 Original line number Diff line number Diff line Loading @@ -70,9 +70,10 @@ class MegatronGenerate(Resource): temperature = args.temperature if "temperature" in request.get_json(): temperature = request.get_json()["temperature"] if not isinstance(temperature, float) or not \ 0.0 < temperature <= 100.0: return "temperature must be a positive float less than or equal to 100.0" if not (type(temperature) == int or type(temperature) == float): return "temperature must be a positive number less than or equal to 100.0" if not (0.0 < temperature <= 100.0): return "temperature must be a positive number less than or equal to 100.0" top_k = args.top_k if "top_k" in request.get_json(): Loading Loading
megatron/fused_kernels/layer_norm_cuda_kernel.cu +1 −1 Original line number Diff line number Diff line Loading @@ -21,7 +21,7 @@ #include "ATen/ATen.h" #include "ATen/AccumulateType.h" #include "ATen/cuda/CUDAContext.h" #include <THC/THCDeviceUtils.cuh> #include "ATen/cuda/DeviceUtils.cuh" #include <cuda.h> #include <cuda_runtime.h> Loading
megatron/model/transformer.py +8 −4 Original line number Diff line number Diff line Loading @@ -392,14 +392,18 @@ def get_bias_dropout_add(training): @torch.jit.script def bias_dropout_add_fused_train(x, bias, residual, prob): # type: (Tensor, Tensor, Tensor, float) -> Tensor def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script def bias_dropout_add_fused_inference(x, bias, residual, prob): # type: (Tensor, Tensor, Tensor, float) -> Tensor def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) Loading
megatron/text_generation_server.py +4 −3 Original line number Diff line number Diff line Loading @@ -70,9 +70,10 @@ class MegatronGenerate(Resource): temperature = args.temperature if "temperature" in request.get_json(): temperature = request.get_json()["temperature"] if not isinstance(temperature, float) or not \ 0.0 < temperature <= 100.0: return "temperature must be a positive float less than or equal to 100.0" if not (type(temperature) == int or type(temperature) == float): return "temperature must be a positive number less than or equal to 100.0" if not (0.0 < temperature <= 100.0): return "temperature must be a positive number less than or equal to 100.0" top_k = args.top_k if "top_k" in request.get_json(): Loading