Loading megatron/mpu/layers.py +5 −5 Original line number Diff line number Diff line Loading @@ -205,11 +205,11 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function): execution in backprop. """ @staticmethod def forward(ctx, input, weight, bias, use_bias): def forward(ctx, input, weight, bias): ctx.save_for_backward(input, weight) ctx.use_bias = use_bias ctx.use_bias = bias is not None output = torch.matmul(input, weight.t()) if use_bias: if bias is not None: output = output + bias return output Loading @@ -227,7 +227,7 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function): grad_weight = grad_output.t().matmul(input) grad_bias = grad_output.sum(dim=0) if use_bias else None handle.wait() return grad_input, grad_weight, grad_bias, None return grad_input, grad_weight, grad_bias class ColumnParallelLinear(torch.nn.Module): Loading Loading @@ -318,7 +318,7 @@ class ColumnParallelLinear(torch.nn.Module): input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2]) # Maxtrix multiply with asynchronouse all-reduce execution output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply( input_, self.weight, bias, bias is not None) input_, self.weight, bias) output_parallel = output_parallel.view( input_shape[0], input_shape[1], output_parallel.shape[1]) else: Loading Loading
megatron/mpu/layers.py +5 −5 Original line number Diff line number Diff line Loading @@ -205,11 +205,11 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function): execution in backprop. """ @staticmethod def forward(ctx, input, weight, bias, use_bias): def forward(ctx, input, weight, bias): ctx.save_for_backward(input, weight) ctx.use_bias = use_bias ctx.use_bias = bias is not None output = torch.matmul(input, weight.t()) if use_bias: if bias is not None: output = output + bias return output Loading @@ -227,7 +227,7 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function): grad_weight = grad_output.t().matmul(input) grad_bias = grad_output.sum(dim=0) if use_bias else None handle.wait() return grad_input, grad_weight, grad_bias, None return grad_input, grad_weight, grad_bias class ColumnParallelLinear(torch.nn.Module): Loading Loading @@ -318,7 +318,7 @@ class ColumnParallelLinear(torch.nn.Module): input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2]) # Maxtrix multiply with asynchronouse all-reduce execution output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply( input_, self.weight, bias, bias is not None) input_, self.weight, bias) output_parallel = output_parallel.view( input_shape[0], input_shape[1], output_parallel.shape[1]) else: Loading