Commit da77a836 authored by Jared Casper
Browse files

Merge branch 'main' into api_change

parents 397714f5 b31e1296
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh>
#include "ATen/cuda/DeviceUtils.cuh"

#include <cuda.h>
#include <cuda_runtime.h>
+8 −4
Original line number Diff line number Diff line
@@ -392,14 +392,18 @@ def get_bias_dropout_add(training):


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)


+4 −3
Original line number Diff line number Diff line
@@ -70,9 +70,10 @@ class MegatronGenerate(Resource):
        temperature = args.temperature
        if "temperature" in request.get_json():
            temperature = request.get_json()["temperature"]
            if not isinstance(temperature, float) or not \
               0.0 < temperature <= 100.0:
                return "temperature must be a positive float less than or equal to 100.0"
            if not (type(temperature) == int or type(temperature) == float):
                return "temperature must be a positive number less than or equal to 100.0"
            if not (0.0 < temperature <= 100.0):
                return "temperature must be a positive number less than or equal to 100.0"
        
        top_k = args.top_k
        if "top_k" in request.get_json():