Commit 0ee1db2d authored by Florian Hahn's avatar Florian Hahn
Browse files

[X86] Try to avoid casts around logical vector ops recursively.

Currently PromoteMaskArithemtic only looks at a single operation to
skip casts. This means we miss cases where we combine multiple masks.

This patch updates PromoteMaskArithemtic to try to recursively promote
AND/XOR/AND nodes that terminate in truncates of the right size or
constant vectors.

Reviewers: craig.topper, RKSimon, spatel

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D72524
parent 886d2c2c
Loading
Loading
Loading
Loading
+64 −39
Original line number Diff line number Diff line
@@ -39898,33 +39898,36 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) {
  return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y);
}
// On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an XMM sized
// register. In most cases we actually compare or select YMM-sized registers
// and mixing the two types creates horrible code. This method optimizes
// some of the transition sequences.
// Even with AVX-512 this is still useful for removing casts around logical
// operations on vXi1 mask types.
static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG,
                                     const X86Subtarget &Subtarget) {
  EVT VT = N->getValueType(0);
  assert(VT.isVector() && "Expected vector type");
// Try to widen AND, OR and XOR nodes to VT in order to remove casts around
// logical operations, like in the example below.
//   or (and (truncate x, truncate y)),
//      (xor (truncate z, build_vector (constants)))
// Given a target type \p VT, we generate
//   or (and x, y), (xor z, zext(build_vector (constants)))
// given x, y and z are of type \p VT. We can do so, if operands are either
// truncates from VT types, the second operand is a vector of constants or can
// be recursively promoted.
static SDValue PromoteMaskArithmetic(SDNode *N, EVT VT, SelectionDAG &DAG,
                                     unsigned Depth) {
  // Limit recursion to avoid excessive compile times.
  if (Depth >= SelectionDAG::MaxRecursionDepth)
    return SDValue();
  assert((N->getOpcode() == ISD::ANY_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND ||
          N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
  if (N->getOpcode() != ISD::XOR && N->getOpcode() != ISD::AND &&
      N->getOpcode() != ISD::OR)
    return SDValue();
  SDValue Narrow = N->getOperand(0);
  EVT NarrowVT = Narrow.getValueType();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDLoc DL(N);
  if (Narrow->getOpcode() != ISD::XOR &&
      Narrow->getOpcode() != ISD::AND &&
      Narrow->getOpcode() != ISD::OR)
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (!TLI.isOperationLegalOrPromote(N->getOpcode(), VT))
    return SDValue();
  SDValue N0  = Narrow->getOperand(0);
  SDValue N1  = Narrow->getOperand(1);
  SDLoc DL(Narrow);
  if (SDValue NN0 = PromoteMaskArithmetic(N0.getNode(), VT, DAG, Depth + 1))
    N0 = NN0;
  else {
    // The Left side has to be a trunc.
    if (N0.getOpcode() != ISD::TRUNCATE)
      return SDValue();
@@ -39933,29 +39936,51 @@ static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG,
    if (N0.getOperand(0).getValueType() != VT)
      return SDValue();
    N0 = N0.getOperand(0);
  }
  if (SDValue NN1 = PromoteMaskArithmetic(N1.getNode(), VT, DAG, Depth + 1))
    N1 = NN1;
  else {
    // The right side has to be a 'trunc' or a constant vector.
    bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
                    N1.getOperand(0).getValueType() == VT;
  if (!RHSTrunc &&
      !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()))
    if (!RHSTrunc && !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()))
      return SDValue();
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), VT))
    return SDValue();
  // Set N0 and N1 to hold the inputs to the new wide operation.
  N0 = N0.getOperand(0);
    if (RHSTrunc)
      N1 = N1.getOperand(0);
    else
      N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N1);
  }
  return DAG.getNode(N->getOpcode(), DL, VT, N0, N1);
}
// On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an XMM sized
// register. In most cases we actually compare or select YMM-sized registers
// and mixing the two types creates horrible code. This method optimizes
// some of the transition sequences.
// Even with AVX-512 this is still useful for removing casts around logical
// operations on vXi1 mask types.
static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG,
                                     const X86Subtarget &Subtarget) {
  EVT VT = N->getValueType(0);
  assert(VT.isVector() && "Expected vector type");
  SDLoc DL(N);
  assert((N->getOpcode() == ISD::ANY_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND ||
          N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
  SDValue Narrow = N->getOperand(0);
  EVT NarrowVT = Narrow.getValueType();
  // Generate the wide operation.
  SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, VT, N0, N1);
  unsigned Opcode = N->getOpcode();
  switch (Opcode) {
  SDValue Op = PromoteMaskArithmetic(Narrow.getNode(), VT, DAG, 0);
  if (!Op)
    return SDValue();
  switch (N->getOpcode()) {
  default: llvm_unreachable("Unexpected opcode");
  case ISD::ANY_EXTEND:
    return Op;
+208 −566

File changed.

Preview size limit exceeded, changes collapsed.