Commit 2fe9a342 authored by Adrian Kuegel's avatar Adrian Kuegel
Browse files

[mlir][SCF] Use getResult() instead of static_cast<Value>().

parent 0b3f6ff3
Loading
Loading
Loading
Loading
Loading
+11 −14
Original line number Diff line number Diff line
@@ -68,10 +68,9 @@ TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) {
  checkUnidimensional(forOp.get());

  OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
      loc, ArrayRef<OpFoldResult>(static_cast<Value>(lb.get())),
      ArrayRef<OpFoldResult>(static_cast<Value>(ub.get())),
      ArrayRef<OpFoldResult>(static_cast<Value>(step.get())), ValueRange(),
      std::nullopt);
      loc, ArrayRef<OpFoldResult>(lb->getResult()),
      ArrayRef<OpFoldResult>(ub->getResult()),
      ArrayRef<OpFoldResult>(step->getResult()), ValueRange(), std::nullopt);
  checkUnidimensional(forallOp.get());

  OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
@@ -87,19 +86,17 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
      b.create<arith::ConstantIndexOp>(loc, 10);
  OwningOpRef<arith::ConstantIndexOp> step =
      b.create<arith::ConstantIndexOp>(loc, 2);
  auto lbValue = static_cast<Value>(lb.get());
  auto ubValue = static_cast<Value>(ub.get());
  auto stepValue = static_cast<Value>(step.get());

  OwningOpRef<scf::ForallOp> forallOp =
      b.create<scf::ForallOp>(loc, ArrayRef<OpFoldResult>({lbValue, lbValue}),
                              ArrayRef<OpFoldResult>({ubValue, ubValue}),
                              ArrayRef<OpFoldResult>({stepValue, stepValue}),
  OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
      loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
      ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
      ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
      ValueRange(), std::nullopt);
  checkMultidimensional(forallOp.get());

  OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
      loc, ValueRange({lbValue, lbValue}), ValueRange({ubValue, ubValue}),
      ValueRange({stepValue, stepValue}), ValueRange());
      loc, ValueRange({lb->getResult(), lb->getResult()}),
      ValueRange({ub->getResult(), ub->getResult()}),
      ValueRange({step->getResult(), step->getResult()}), ValueRange());
  checkMultidimensional(parallelOp.get());
}