Unverified Commit c5771226 authored by Sang Ik Lee's avatar Sang Ik Lee Committed by GitHub
Browse files

[MLIR][XeVM] XeVM to LLVM: Update xevm.truncf handling (#194491)

Add support for more src/dst type combinations.
xevm.truncf no longer expects destination vector size to match source
vector size.
XeVM target: Turn on extension SPV_KHR_bfloat16
parent 1bbfade9
Loading
Loading
Loading
Loading
+155 −41
Original line number Diff line number Diff line
@@ -1129,51 +1129,165 @@ class TruncfToOCLPattern : public OpConversionPattern<TruncfOp> {
    // Supported source and result types are resticted for now.
    auto srcEtype = op.getSrcEtype().getEtype();
    auto dstEtype = op.getDstEtype().getEtype();
    if (auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType())) {
      if (vecSrcTy.getNumElements() != 16)
        return rewriter.notifyMatchFailure(
            op, "Only vector src of 16 elements is supported");
    } else {
    // Currently only 16 input elements are supported as
    //  - Any vector beyond 16 elements not a valid OpenCL vector.
    //  - 2D block load can only load up to 16 16bit elements per lane.
    //      Widest load is 8x16xi32 with 16 lanes, which is 16 16bit
    //      elements per lane.
    //  - mma_mx A and B operands need more than 16 elements per lane
    //
    // Conversion is done in batches depending on the dst type.
    // batch_size =
    //   16 if dst type == fp8
    //   8  if dst type == fp4
    // For num_elem > batch_size
    //   convert batch of batch_size
    //   cast batch to i32 elem type vector
    //   concat batches by shufflevector
    // For num_elem = batch_size
    //   use API for conversion
    // Scalar case is not supported until usage case become clear.
    auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType());
    if (!vecSrcTy) {
      return rewriter.notifyMatchFailure(op, "Scalar src is not supported.");
    }
    if (auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType())) {
      if (vecDstTy.getNumElements() != 16)
    if (vecSrcTy.getNumElements() != 16)
      return rewriter.notifyMatchFailure(
            op, "Only vector dst of 16 elements is supported");
    } else {
          op, "Only vector src of 16 elements is supported");
    auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType());
    if (!vecDstTy)
      return rewriter.notifyMatchFailure(op, "Scalar dst is not supported.");
    Value src = op.getSrc();
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::NoModRef,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
        /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
    funcAttrs.memEffectsAttr = memAttr;

    // Handle the case where dst type is fp4 first.
    if (dstEtype == TruncfDstElemTypes::E2M1) {
      // Convert 8 elements at a time.
      // To convert 8 elements, vector<8xf16>:
      // Use:
      // uint __builtin_IB_dnscl_hf16(uint, uint, 1, 0)
      // uint __builtin_IB_dnscl_hf16(uint, uint, 1, 3)
      // llvm.or
      Value cast = LLVM::BitcastOp::create(
          rewriter, op.getLoc(), VectorType::get(8, rewriter.getI32Type()),
          src);

      std::string fnName = "__builtin_IB_dnscl_";
      fnName += (srcEtype == TruncfSrcElemTypes::F16) ? "hf16" : "bf16";
      auto genDnscl = [&](Value input, Value idx0, Value idx1, Value dstTy,
                          Value mode) -> Value {
        Value arg1 =
            LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx0)
                ->getResult(0);
        Value arg2 =
            LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx1)
                ->getResult(0);
        SmallVector<Type> argTypes{arg1.getType(), arg2.getType(),
                                   dstTy.getType(), mode.getType()};
        SmallVector<Value> args{arg1, arg2, dstTy, mode};
        Value dnscl = createDeviceFunctionCall(
                          rewriter, fnName, rewriter.getI32Type(), argTypes,
                          args, {}, funcAttrs, op.getOperation())
                          ->getResult(0);
        return dnscl;
      };

      Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(),
                                            rewriter.getI32Type(), 0);
      Value one = LLVM::ConstantOp::create(rewriter, op.getLoc(),
                                           rewriter.getI32Type(), 1);
      Value two = LLVM::ConstantOp::create(rewriter, op.getLoc(),
                                           rewriter.getI32Type(), 2);
      Value three = LLVM::ConstantOp::create(rewriter, op.getLoc(),
                                             rewriter.getI32Type(), 3);
      Value even = genDnscl(cast, zero, two, one, zero);
      Value odd = genDnscl(cast, one, three, one, two);
      Value firstHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
      Value four = LLVM::ConstantOp::create(rewriter, op.getLoc(),
                                            rewriter.getI32Type(), 4);
      Value five = LLVM::ConstantOp::create(rewriter, op.getLoc(),
                                            rewriter.getI32Type(), 5);
      Value six = LLVM::ConstantOp::create(rewriter, op.getLoc(),
                                           rewriter.getI32Type(), 6);
      Value seven = LLVM::ConstantOp::create(rewriter, op.getLoc(),
                                             rewriter.getI32Type(), 7);
      even = genDnscl(cast, four, six, one, zero);
      odd = genDnscl(cast, five, seven, one, two);
      Value secondHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
      // Create vector<2xi32> from two i32 values and then bitcast to
      // vector<8xi8> to match the dst type.
      Value combined = LLVM::UndefOp::create(
          rewriter, op.getLoc(), VectorType::get(2, rewriter.getI32Type()));
      combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
                                               firstHalf, zero)
                     ->getResult(0);
      combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
                                               secondHalf, one)
                     ->getResult(0);
      Value result =
          LLVM::BitcastOp::create(rewriter, op.getLoc(), vecDstTy, combined);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Handle the case where dst type is fp8.
    // BF16 type needs some preprocessing before conversion,
    // First extended to F32 and then truncated to F16.
    if (srcEtype == TruncfSrcElemTypes::BF16) {
      // Step 1: Extend to F32
      // Use float16 __builtin_IB_bftof_16(short16)
      src = LLVM::BitcastOp::create(
          rewriter, op.getLoc(),
          VectorType::get(vecSrcTy.getShape(), rewriter.getI16Type()), src);
      std::string fnName = "__builtin_IB_bftof_16";
      SmallVector<Type> argTypes{src.getType()};
      SmallVector<Value> args{src};
      Type resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF32Type());
      src = createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args,
                                     {}, funcAttrs, op.getOperation())
                ->getResult(0);
      // Step 2: Truncf to F16
      // Use half16 convert_half16(float16)
      std::string truncFnName = "convert_half16";
      SmallVector<Type> truncArgTypes{src.getType()};
      SmallVector<Value> truncArgs{src};
      truncFnName = mangle(truncFnName, truncArgTypes);
      resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF16Type());
      src =
          createDeviceFunctionCall(rewriter, truncFnName, resTy, truncArgTypes,
                                   truncArgs, {}, funcAttrs, op.getOperation())
              ->getResult(0);
    }
    if (srcEtype == TruncfSrcElemTypes::F16 &&
        dstEtype == TruncfDstElemTypes::BF8) {
      // BF8 is just F16 with lower 8 bits of mantessa discard.
      //     Signbit Exponent Mantessa
      // BF8 1       5        2
      // F16 1       5        10
      // Xe arch is Little Endian so BF8 is just the second byte of the two
      // byte representation used for F16
      auto firstHalf =
          LLVM::ShuffleVectorOp::create(rewriter, op.getLoc(), op.getSrc(),
                                        op.getSrc(), {0, 1, 2, 3, 4, 5, 6, 7});
      auto secondHalf = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), op.getSrc(), op.getSrc(),
          {8, 9, 10, 11, 12, 13, 14, 15});
      auto firstHalfCasted = LLVM::BitcastOp::create(
          rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
          firstHalf);
      auto secondHalfCasted = LLVM::BitcastOp::create(
          rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
          secondHalf);
      // Gather just the second bytes from every two byte F16 values
      auto resFirstHalf = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), firstHalfCasted, firstHalfCasted,
          {1, 3, 5, 7, 9, 11, 13, 15});
      auto resSecondHalf = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), secondHalfCasted, secondHalfCasted,
          {1, 3, 5, 7, 9, 11, 13, 15});
      auto res = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), resFirstHalf, resSecondHalf,
          {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
      rewriter.replaceOp(op, res);
    if (dstEtype == TruncfDstElemTypes::BF8) { // Float8E5M2Type
      // Use char16 __builtin_IB_hftobf8_16(half16)
      std::string fnName = "__builtin_IB_hftobf8_16";
      SmallVector<Type> argTypes{src.getType()};
      SmallVector<Value> args{src};
      Value result =
          createDeviceFunctionCall(rewriter, fnName, vecDstTy, argTypes, args,
                                   {}, funcAttrs, op.getOperation())
              ->getResult(0);

      rewriter.replaceOp(op, result);
    } else if (dstEtype == TruncfDstElemTypes::F8) { // Float8E4M3FNType
      // Use char16 __builtin_IB_hftohf8_16(half16)
      std::string fnName = "__builtin_IB_hftohf8_16";
      SmallVector<Type> argTypes{src.getType()};
      SmallVector<Value> args{src};
      Value result =
          createDeviceFunctionCall(rewriter, fnName, vecDstTy, argTypes, args,
                                   {}, funcAttrs, op.getOperation())
              ->getResult(0);

      rewriter.replaceOp(op, result);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Unsupported src, dst element type pair.");
+0 −3
Original line number Diff line number Diff line
@@ -366,9 +366,6 @@ LogicalResult TruncfOp::verify() {
  if (isa<VectorType>(srcTy)) {
    VectorType srcVecTy = dyn_cast<VectorType>(srcTy);
    VectorType dstVecTy = dyn_cast<VectorType>(dstTy);
    if (srcVecTy.getNumElements() != dstVecTy.getNumElements())
      return emitOpError(
          "src and dst vector types should have the same number of elements");
    if (srcVecTy.getElementTypeBitWidth() <= dstVecTy.getElementTypeBitWidth())
      return emitError(
          "dst element bitwidth should be less than src element bitwidth");
+1 −0
Original line number Diff line number Diff line
@@ -462,6 +462,7 @@ void SPIRVSerializer::init() {
#if LLVM_HAS_SPIRV_TARGET
static const std::vector<std::string> getDefaultSPIRVExtensions() {
  return {
      "SPV_KHR_bfloat16",
      "SPV_EXT_relaxed_printf_string_address_space",
      "SPV_INTEL_cache_controls",
      "SPV_INTEL_variable_length_array",
+154 −8
Original line number Diff line number Diff line
// RUN: mlir-opt --convert-xevm-to-llvm --split-input-file %s | FileCheck %s

// CHECK: llvm.func spir_funccc @__builtin_IB_hftobf8_16(vector<16xf16>) -> vector<16xi8>
// CHECK-SAME: attributes {convergent, memory_effects = #llvm.memory_effects<other = none,
// CHECK-SAME:   argMem = none, inaccessibleMem = none, errnoMem = none,
// CHECK-SAME:   targetMem0 = none, targetMem1 = none>, no_unwind, will_return}
// CHECK-LABEL: llvm.func @truncf_f16_to_bf8
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf16>
llvm.func @truncf_f16_to_bf8(%src: vector<16xf16>) -> vector<16xi8> {
  // CHECK:  %[[VAR0:.*]] = llvm.shufflevector %[[ARG0]], %[[ARG0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<16xf16>
  // CHECK:  %[[VAR1:.*]] = llvm.shufflevector %[[ARG0]], %[[ARG0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>
  // CHECK:  %[[VAR2:.*]] = llvm.bitcast %[[VAR0]] : vector<8xf16> to vector<16xi8>
  // CHECK:  %[[VAR3:.*]] = llvm.bitcast %[[VAR1]] : vector<8xf16> to vector<16xi8>
  // CHECK:  %[[VAR4:.*]] = llvm.shufflevector %[[VAR2]], %[[VAR2]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi8>
  // CHECK:  %[[VAR5:.*]] = llvm.shufflevector %[[VAR3]], %[[VAR3]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi8>
  // CHECK:  %[[VAR6:.*]] = llvm.shufflevector %4, %5 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
  // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @__builtin_IB_hftobf8_16(%[[ARG0]])
  // CHECK-SAME: {convergent, function_type = !llvm.func<vector<16xi8> (vector<16xf16>)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, memory_effects = #llvm.memory_effects<other = none,
  // CHECK-SAME:   argMem = none, inaccessibleMem = none, errnoMem = none,
  // CHECK-SAME:   targetMem0 = none, targetMem1 = none>,
  // CHECK-SAME: no_unwind, sym_name = "__builtin_IB_hftobf8_16",
  // CHECK-SAME: visibility_ = 0 : i64, will_return} :
  // CHECK-SAME: (vector<16xf16>) -> vector<16xi8>
  %dst = xevm.truncf %src { src_etype = f16, dst_etype = bf8 } : (vector<16xf16>) -> vector<16xi8>
  llvm.return %dst : vector<16xi8>
}

// -----

// CHECK-LABEL: llvm.func spir_funccc @__builtin_IB_sub_group16_bdpas_f_f_bf8_bf8_8_8
// CHECK: llvm.func spir_funccc @__builtin_IB_hftohf8_16(vector<16xf16>) -> vector<16xi8>
// CHECK-SAME: attributes {convergent, memory_effects = #llvm.memory_effects<other = none,
// CHECK-SAME:   argMem = none, inaccessibleMem = none, errnoMem = none,
// CHECK-SAME:   targetMem0 = none, targetMem1 = none>, no_unwind, will_return}
// CHECK-LABEL: llvm.func @truncf_f16_to_hf8
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf16>
llvm.func @truncf_f16_to_hf8(%src: vector<16xf16>) -> vector<16xi8> {
  // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @__builtin_IB_hftohf8_16(%[[ARG0]])
  // CHECK-SAME: {convergent, function_type = !llvm.func<vector<16xi8> (vector<16xf16>)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, memory_effects = #llvm.memory_effects<other = none,
  // CHECK-SAME:   argMem = none, inaccessibleMem = none, errnoMem = none,
  // CHECK-SAME:   targetMem0 = none, targetMem1 = none>,
  // CHECK-SAME: no_unwind, sym_name = "__builtin_IB_hftohf8_16",
  // CHECK-SAME: visibility_ = 0 : i64, will_return} :
  // CHECK-SAME: (vector<16xf16>) -> vector<16xi8>
  %dst = xevm.truncf %src { src_etype = f16, dst_etype = f8 } : (vector<16xf16>) -> vector<16xi8>
  llvm.return %dst : vector<16xi8>
}

// -----

// CHECK: llvm.func spir_funccc @__builtin_IB_hftobf8_16(vector<16xf16>) -> vector<16xi8>
// CHECK: llvm.func spir_funccc @_Z14convert_half16Dv16_f(vector<16xf32>) -> vector<16xf16>
// CHECK: llvm.func spir_funccc @__builtin_IB_bftof_16(vector<16xi16>) -> vector<16xf32>
// CHECK-LABEL: llvm.func @truncf_bf16_to_bf8
// CHECK-SAME: %[[ARG0:.*]]: vector<16xbf16>
llvm.func @truncf_bf16_to_bf8(%src: vector<16xbf16>) -> vector<16xi8> {
  // CHECK: %[[VAR0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xbf16> to vector<16xi16>
  // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @__builtin_IB_bftof_16(%[[VAR0]])
  // CHECK-SAME: : (vector<16xi16>) -> vector<16xf32>
  // CHECK: %[[VAR2:.*]] = llvm.call spir_funccc @_Z14convert_half16Dv16_f(%[[VAR1]])
  // CHECK-SAME: : (vector<16xf32>) -> vector<16xf16>
  // CHECK: %[[VAR3:.*]] = llvm.call spir_funccc @__builtin_IB_hftobf8_16(%[[VAR2]])
  // CHECK-SAME: : (vector<16xf16>) -> vector<16xi8>
  %dst = xevm.truncf %src { src_etype = bf16, dst_etype = bf8 } : (vector<16xbf16>) -> vector<16xi8>
  llvm.return %dst : vector<16xi8>
}

// -----

// CHECK: llvm.func spir_funccc @__builtin_IB_hftohf8_16(vector<16xf16>) -> vector<16xi8>
// CHECK: llvm.func spir_funccc @_Z14convert_half16Dv16_f(vector<16xf32>) -> vector<16xf16>
// CHECK: llvm.func spir_funccc @__builtin_IB_bftof_16(vector<16xi16>) -> vector<16xf32>
// CHECK-LABEL: llvm.func @truncf_bf16_to_hf8
// CHECK-SAME: %[[ARG0:.*]]: vector<16xbf16>
llvm.func @truncf_bf16_to_hf8(%src: vector<16xbf16>) -> vector<16xi8> {
  // CHECK: %[[VAR0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xbf16> to vector<16xi16>
  // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @__builtin_IB_bftof_16(%[[VAR0]])
  // CHECK-SAME: : (vector<16xi16>) -> vector<16xf32>
  // CHECK: %[[VAR2:.*]] = llvm.call spir_funccc @_Z14convert_half16Dv16_f(%[[VAR1]])
  // CHECK-SAME: : (vector<16xf32>) -> vector<16xf16>
  // CHECK: %[[VAR3:.*]] = llvm.call spir_funccc @__builtin_IB_hftohf8_16(%[[VAR2]])
  // CHECK-SAME: : (vector<16xf16>) -> vector<16xi8>
  %dst = xevm.truncf %src { src_etype = bf16, dst_etype = f8 } : (vector<16xbf16>) -> vector<16xi8>
  llvm.return %dst : vector<16xi8>
}

// -----

// CHECK: llvm.func spir_funccc @__builtin_IB_dnscl_hf16(i32, i32, i32, i32) -> i32
// CHECK-LABEL: llvm.func @truncf_f16_to_e2m1
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf16>
llvm.func @truncf_f16_to_e2m1(%src: vector<16xf16>) -> vector<8xi8> {
  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[C7:.*]] = llvm.mlir.constant(7 : i32) : i32
  // CHECK: %[[C6:.*]] = llvm.mlir.constant(6 : i32) : i32
  // CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : i32) : i32
  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
  // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i32) : i32
  // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[BC:.*]] = llvm.bitcast %[[ARG0]] : vector<16xf16> to vector<8xi32>
  // CHECK: %[[E0:.*]] = llvm.extractelement %[[BC]][%[[C0]] : i32] : vector<8xi32>
  // CHECK: %[[E2:.*]] = llvm.extractelement %[[BC]][%[[C2]] : i32] : vector<8xi32>
  // CHECK: %[[CALL0:.*]] = llvm.call spir_funccc @__builtin_IB_dnscl_hf16(%[[E0]], %[[E2]], %[[C1]], %[[C0]])
  // CHECK-SAME: : (i32, i32, i32, i32) -> i32
  // CHECK: %[[E1:.*]] = llvm.extractelement %[[BC]][%[[C1]] : i32] : vector<8xi32>
  // CHECK: %[[E3:.*]] = llvm.extractelement %[[BC]][%[[C3]] : i32] : vector<8xi32>
  // CHECK: %[[CALL1:.*]] = llvm.call spir_funccc @__builtin_IB_dnscl_hf16(%[[E1]], %[[E3]], %[[C1]], %[[C2]])
  // CHECK-SAME: : (i32, i32, i32, i32) -> i32
  // CHECK: %[[OR0:.*]] = llvm.or %[[CALL0]], %[[CALL1]] : i32
  // CHECK: %[[E4:.*]] = llvm.extractelement %[[BC]][%[[C4]] : i32] : vector<8xi32>
  // CHECK: %[[E6:.*]] = llvm.extractelement %[[BC]][%[[C6]] : i32] : vector<8xi32>
  // CHECK: %[[CALL2:.*]] = llvm.call spir_funccc @__builtin_IB_dnscl_hf16(%[[E4]], %[[E6]], %[[C1]], %[[C0]])
  // CHECK-SAME: : (i32, i32, i32, i32) -> i32
  // CHECK: %[[E5:.*]] = llvm.extractelement %[[BC]][%[[C5]] : i32] : vector<8xi32>
  // CHECK: %[[E7:.*]] = llvm.extractelement %[[BC]][%[[C7]] : i32] : vector<8xi32>
  // CHECK: %[[CALL3:.*]] = llvm.call spir_funccc @__builtin_IB_dnscl_hf16(%[[E5]], %[[E7]], %[[C1]], %[[C2]])
  // CHECK-SAME: : (i32, i32, i32, i32) -> i32
  // CHECK: %[[OR1:.*]] = llvm.or %[[CALL2]], %[[CALL3]] : i32
  // CHECK: %[[INS0:.*]] = llvm.insertelement %[[OR0]], %[[UNDEF]][%[[C0]] : i32] : vector<2xi32>
  // CHECK: %[[INS1:.*]] = llvm.insertelement %[[OR1]], %[[INS0]][%[[C1]] : i32] : vector<2xi32>
  // CHECK: %[[RES:.*]] = llvm.bitcast %[[INS1]] : vector<2xi32> to vector<8xi8>
  %dst = xevm.truncf %src { src_etype = f16, dst_etype = e2m1 } : (vector<16xf16>) -> vector<8xi8>
  llvm.return %dst : vector<8xi8>
}

// -----

// CHECK: llvm.func spir_funccc @__builtin_IB_dnscl_bf16(i32, i32, i32, i32) -> i32
// CHECK-LABEL: llvm.func @truncf_bf16_to_e2m1
// CHECK-SAME: %[[ARG0:.*]]: vector<16xbf16>
llvm.func @truncf_bf16_to_e2m1(%src: vector<16xbf16>) -> vector<8xi8> {
  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[C7:.*]] = llvm.mlir.constant(7 : i32) : i32
  // CHECK: %[[C6:.*]] = llvm.mlir.constant(6 : i32) : i32
  // CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : i32) : i32
  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
  // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i32) : i32
  // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[BC:.*]] = llvm.bitcast %[[ARG0]] : vector<16xbf16> to vector<8xi32>
  // CHECK: %[[E0:.*]] = llvm.extractelement %[[BC]][%[[C0]] : i32] : vector<8xi32>
  // CHECK: %[[E2:.*]] = llvm.extractelement %[[BC]][%[[C2]] : i32] : vector<8xi32>
  // CHECK: %[[CALL0:.*]] = llvm.call spir_funccc @__builtin_IB_dnscl_bf16(%[[E0]], %[[E2]], %[[C1]], %[[C0]])
  // CHECK-SAME: : (i32, i32, i32, i32) -> i32
  // CHECK: %[[E1:.*]] = llvm.extractelement %[[BC]][%[[C1]] : i32] : vector<8xi32>
  // CHECK: %[[E3:.*]] = llvm.extractelement %[[BC]][%[[C3]] : i32] : vector<8xi32>
  // CHECK: %[[CALL1:.*]] = llvm.call spir_funccc @__builtin_IB_dnscl_bf16(%[[E1]], %[[E3]], %[[C1]], %[[C2]])
  // CHECK-SAME: : (i32, i32, i32, i32) -> i32
  // CHECK: %[[OR0:.*]] = llvm.or %[[CALL0]], %[[CALL1]] : i32
  // CHECK: %[[E4:.*]] = llvm.extractelement %[[BC]][%[[C4]] : i32] : vector<8xi32>
  // CHECK: %[[E6:.*]] = llvm.extractelement %[[BC]][%[[C6]] : i32] : vector<8xi32>
  // CHECK: %[[CALL2:.*]] = llvm.call spir_funccc @__builtin_IB_dnscl_bf16(%[[E4]], %[[E6]], %[[C1]], %[[C0]])
  // CHECK-SAME: : (i32, i32, i32, i32) -> i32
  // CHECK: %[[E5:.*]] = llvm.extractelement %[[BC]][%[[C5]] : i32] : vector<8xi32>
  // CHECK: %[[E7:.*]] = llvm.extractelement %[[BC]][%[[C7]] : i32] : vector<8xi32>
  // CHECK: %[[CALL3:.*]] = llvm.call spir_funccc @__builtin_IB_dnscl_bf16(%[[E5]], %[[E7]], %[[C1]], %[[C2]])
  // CHECK-SAME: : (i32, i32, i32, i32) -> i32
  // CHECK: %[[OR1:.*]] = llvm.or %[[CALL2]], %[[CALL3]] : i32
  // CHECK: %[[INS0:.*]] = llvm.insertelement %[[OR0]], %[[UNDEF]][%[[C0]] : i32] : vector<2xi32>
  // CHECK: %[[INS1:.*]] = llvm.insertelement %[[OR1]], %[[INS0]][%[[C1]] : i32] : vector<2xi32>
  // CHECK: %[[RES:.*]] = llvm.bitcast %[[INS1]] : vector<2xi32> to vector<8xi8>
  %dst = xevm.truncf %src { src_etype = bf16, dst_etype = e2m1 } : (vector<16xbf16>) -> vector<8xi8>
  llvm.return %dst : vector<8xi8>
}

// -----

// CHECK: llvm.func spir_funccc @__builtin_IB_sub_group16_bdpas_f_f_bf8_bf8_8_8
// CHECK-SAME: (vector<8xf32>, vector<8xi16>, vector<8xi32>, i8, i8) -> vector<8xf32>
// CHECK-SAME:   attributes {convergent, memory_effects = #llvm.memory_effects<other = none,
// CHECK-SAME:   argMem = none, inaccessibleMem = none, errnoMem = none,
+0 −8
Original line number Diff line number Diff line
@@ -2027,14 +2027,6 @@ llvm.func @invalid_xevm_truncf_1(%arg0: vector<8xf16>) {

// -----

llvm.func @invalid_xevm_truncf_1(%arg0: vector<8xf16>) {
  // expected-error@+1 {{op src and dst vector types should have the same number of elements}}
  %0 = xevm.truncf %arg0 { src_etype = f16, dst_etype = bf8 } : (vector<8xf16>) -> vector<4xi8>
  llvm.return
}

// -----

llvm.func @invalid_xevm_mma_mx(%loaded_c_casted: vector<4xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>, %scale_a: vector<2xi8>, %scale_b: vector<2xi8>) -> vector<8xf32> {
  // expected-error@+1 {{op type of C operand must match result type}}
  %c_result = xevm.mma_mx %loaded_a, %loaded_b_casted, %scale_a, %scale_b, %loaded_c_casted { shape=<m=8, n=16, k=64>,
Loading