Unverified Commit 8c8336fc authored by bjacob's avatar bjacob Committed by GitHub
Browse files

Add missing `linalg.batch_vecmat` named op (#70218)

Linalg currently has these named ops:
* `matmul`
* `matvec`
* `vecmat`
* `batch_matmul`
* `batch_matvec`

But it does not have:
* `batch_vecmat`

This PR adds that for consistency, and I have a short-term need for it
( https://github.com/openxla/iree/issues/15158 ), so not having this
would cause some contortion on my end.
parent 8e00d59d
Loading
Loading
Loading
Loading
+68 −0
Original line number Diff line number Diff line
@@ -1796,6 +1796,74 @@ structured_op: !LinalgStructuredOpConfig
                - !ScalarExpression
                  scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: batch_vecmat
  cpp_class_name: BatchVecmatOp
  doc: |-
    Performs a batched vector-matrix multiplication.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
  implements:
  - LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
  # Operand shapes in terms of the symbols (s0, s1, s2):
  #   A: (s0, s1)       one vector per batch entry
  #   B: (s0, s1, s2)   one matrix per batch entry
  #   C: (s0, s2)       one output vector per batch entry
  args:
  - !LinalgOperandDefConfig
    name: A
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
  - !LinalgOperandDefConfig
    name: B
    kind: input_tensor
    type_var: T2
    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
  - !LinalgOperandDefConfig
    name: C
    kind: output_tensor
    type_var: U
    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
  # Loop dimensions (d0, d1, d2) index the operands as A[d0, d2], B[d0, d2, d1]
  # and C[d0, d1]; d2 does not appear in the output map.
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2, d1)>
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
  # One iterator type per loop dimension, in order d0, d1, d2.
  iterator_types:
  - parallel
  - parallel
  - reduction
  # Scalar body: C = C + cast_signed(U, A) * cast_signed(U, B).
  assignments:
  - !ScalarAssign
    arg: C
    value: !ScalarExpression
      scalar_fn:
        kind: binary
        fn_name: add
        operands:
        - !ScalarExpression
          scalar_arg: C
        - !ScalarExpression
          scalar_fn:
            kind: binary
            fn_name: mul
            operands:
            - !ScalarExpression
              scalar_fn:
                kind: type
                fn_name: cast_signed
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: A
            - !ScalarExpression
              scalar_fn:
                kind: type
                fn_name: cast_signed
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: dot
  cpp_class_name: DotOp
+18 −0
Original line number Diff line number Diff line
@@ -517,6 +517,24 @@ def batch_matvec(
    )


@linalg_structured_op
def batch_vecmat(
    A=TensorDef(T1, Batch, S.K),
    B=TensorDef(T2, Batch, S.K, S.N),
    C=TensorDef(U, Batch, S.N, output=True),
):
    """Performs a batched vector-matrix multiplication.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # Operand shapes: A is (Batch, K), B is (Batch, K, N), C is (Batch, N).
    # Iteration domain is ordered (b, n, k); k does not index the output C, so
    # it is the dimension accumulated over by the `+=` below.
    domain(D.b, D.n, D.k)
    implements(ContractionOpInterface)
    # C[b, n] accumulates A[b, k] * B[b, k, n], with both inputs cast to the
    # output element type U before the multiply.
    C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
        U, B[D.b, D.k, D.n]
    )


@linalg_structured_op
def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
    """Performs a dot product of two vectors to a scalar result.
+25 −0
Original line number Diff line number Diff line
@@ -251,6 +251,31 @@ func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi

// -----

// Exercises linalg.batch_vecmat on memrefs with i8 inputs and an f32 output.
// The FileCheck expectations below look for a linalg.generic with three
// indexing maps (lhs, rhs, out), iterator types parallel/parallel/reduction,
// and a body that converts both i8 block arguments to f32 with arith.sitofp
// before the multiply (arith.mulf) and accumulate (arith.addf).
// NOTE(review): the RUN line is not visible in this hunk; presumably the file
// is run through a named-op generalization pass -- confirm against the header.
func.func @generalize_batch_vecmat(%lhs : memref<?x?xi8>, %rhs: memref<?x?x?xi8>,  %out: memref<?x?xf32>) {
  linalg.batch_vecmat ins(%lhs, %rhs: memref<?x?xi8>, memref<?x?x?xi8>)
                     outs(%out: memref<?x?xf32>)
  return
}
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK: @generalize_batch_vecmat

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?xi8>, memref<?x?x?xi8>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
// CHECK:            %[[BBARG0_F32:.+]] = arith.sitofp %[[BBARG0]] : i8 to f32
// CHECK:            %[[BBARG1_F32:.+]] = arith.sitofp %[[BBARG1]] : i8 to f32
// CHECK:            %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
// CHECK:            %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
// CHECK:            linalg.yield %[[ADD]] : f32

// -----

func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
                             outs(%out: memref<8x8xf32>)