Commit 0b8cb87e authored by Max Kudryavtsev's avatar Max Kudryavtsev Committed by Stella Stamenova
Browse files

[MLIR][STD] Add safe scalar constant propagation for FPTruncOp

Perform scalar constant propagation for FPTruncOp only if the resulting value can be represented without precision loss or rounding.

Example:
%cst = constant 1.000000e+00 : f32
%0 = fptrunc %cst : f32 to bf16
-->
%cst = constant 1.000000e+00 : bf16

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D107518
parent 1854db74
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -1220,6 +1220,8 @@ def FPTruncOp : ArithmeticCastOp<"fptrunc"> {
    If the value cannot be exactly represented, it is rounded using the default
    rounding mode. When operating on vectors, casts elementwise.
  }];

  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
+21 −0
Original line number Diff line number Diff line
@@ -1414,6 +1414,27 @@ bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
}

/// Perform safe const propagation for fptrunc, i.e. only propagate
/// if FP value can be represented without precision loss or rounding.
OpFoldResult FPTruncOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "unary operation takes one operand");

  auto constOperand = operands.front();
  if (!constOperand || !constOperand.isa<FloatAttr>())
    return {};

  // Convert to target type via 'double'.
  double sourceValue =
      constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
  auto targetAttr = FloatAttr::get(getType(), sourceValue);

  // Propagate if constant's value does not change after truncation.
  if (sourceValue == targetAttr.getValue().convertToDouble())
    return targetAttr;

  return {};
}

//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
+2 −3
Original line number Diff line number Diff line
@@ -196,11 +196,10 @@ func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x
}

// CHECK-LABEL: @generalize_soft_plus_2d_f32
//      CHECK: %[[C1:.+]] = constant 1.000000e+00 : f64
//      CHECK: %[[C1:.+]] = constant 1.000000e+00 : f32
//      CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
// CHECK-NEXT:   %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
// CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
// CHECK-NEXT:   %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
// CHECK-NEXT:   %[[SUM:.+]] = addf %[[C1]], %[[EXP]] : f32
// CHECK-NEXT:   %[[LOG:.+]] = math.log %[[SUM]] : f32
// CHECK-NEXT:   linalg.yield %[[LOG]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>
+19 −0
Original line number Diff line number Diff line
@@ -80,6 +80,25 @@ func @truncConstant(%arg0: i8) -> i16 {
  return %tr : i16
}

// CHECK-LABEL: @truncFPConstant
//       CHECK:   %[[cres:.+]] = constant 1.000000e+00 : bf16
//       CHECK:   return %[[cres]]
func @truncFPConstant() -> bf16 {
  %cst = constant 1.000000e+00 : f32
  %0 = fptrunc %cst : f32 to bf16
  return %0 : bf16
}

// Test that cases with rounding are NOT propagated
// CHECK-LABEL: @truncFPConstantRounding
//       CHECK:   constant 1.444000e+25 : f32
//       CHECK:   fptrunc
func @truncFPConstantRounding() -> bf16 {
  %cst = constant 1.444000e+25 : f32
  %0 = fptrunc %cst : f32 to bf16
  return %0 : bf16
}

// CHECK-LABEL: @tripleAddAdd
//       CHECK:   %[[cres:.+]] = constant 59 : index 
//       CHECK:   %[[add:.+]] = addi %arg0, %[[cres]] : index