Commit 4df8b7a2 authored by slym's avatar slym
Browse files

reflect feedback

parent 3f652469
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -205,11 +205,11 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
    execution in backprop.
    """
    @staticmethod
    def forward(ctx, input, weight, bias, use_bias):
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = use_bias
        ctx.use_bias = bias is not None
        output = torch.matmul(input, weight.t())
        if use_bias:
        if bias is not None:
            output = output + bias
        return output

@@ -227,7 +227,7 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None
        handle.wait()
        return grad_input, grad_weight, grad_bias, None
        return grad_input, grad_weight, grad_bias


class ColumnParallelLinear(torch.nn.Module):
@@ -318,7 +318,7 @@ class ColumnParallelLinear(torch.nn.Module):
            input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
            # Matrix multiply with asynchronous all-reduce execution
            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
                    input_, self.weight, bias, bias is not None)
                    input_, self.weight, bias)
            output_parallel = output_parallel.view(
                    input_shape[0], input_shape[1], output_parallel.shape[1])
        else: