Commit 26aff847 authored by peter klausler's avatar peter klausler
Browse files

[flang] Fold COUNT()

Complete folding of the intrinsic reduction function COUNT() for all
cases, including partial reductions with DIM= arguments.

Differential Revision: https://reviews.llvm.org/D109911
parent 47373f94
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ add_flang_library(FortranEvaluate
  fold-integer.cpp
  fold-logical.cpp
  fold-real.cpp
  fold-reduction.cpp
  formatting.cpp
  host.cpp
  initial-image.cpp
+4 −4
Original line number Diff line number Diff line
@@ -492,7 +492,7 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
    // Build and return constant result
    if constexpr (TR::category == TypeCategory::Character) {
      auto len{static_cast<ConstantSubscript>(
          results.size() ? results[0].length() : 0)};
          results.empty() ? 0 : results[0].length())};
      return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
    } else {
      return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
@@ -944,7 +944,7 @@ Expr<T> FoldMINorMAX(
  if (constantArgs.size() != funcRef.arguments().size()) {
    return Expr<T>(std::move(funcRef));
  }
  CHECK(constantArgs.size() > 0);
  CHECK(!constantArgs.empty());
  Expr<T> result{std::move(*constantArgs[0])};
  for (std::size_t i{1}; i < constantArgs.size(); ++i) {
    Extremum<T> extremum{order, result, Expr<T>{std::move(*constantArgs[i])}};
@@ -1075,7 +1075,7 @@ private:
    Expr<T> folded{Fold(context_, common::Clone(expr.value()))};
    if (const auto *c{UnwrapConstantValue<T>(folded)}) {
      // Copy elements in Fortran array element order
      if (c->size() > 0) {
      if (!c->empty()) {
        ConstantSubscripts index{c->lbounds()};
        do {
          elements_.emplace_back(c->At(index));
@@ -1156,7 +1156,7 @@ template <typename T>
std::optional<Expr<T>> AsFlatArrayConstructor(const Expr<T> &expr) {
  if (const auto *c{UnwrapConstantValue<T>(expr)}) {
    ArrayConstructor<T> result{expr};
    if (c->size() > 0) {
    if (!c->empty()) {
      ConstantSubscripts at{c->lbounds()};
      do {
        result.Push(Expr<T>{Constant<T>{c->At(at)}});
+30 −14
Original line number Diff line number Diff line
@@ -174,21 +174,47 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
  return Expr<T>{std::move(funcRef)};
}

// COUNT()
template <typename T>
static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
  static_assert(T::category == TypeCategory::Integer);
  ActualArguments &arg{ref.arguments()};
  if (const Constant<LogicalResult> *mask{arg.empty()
              ? nullptr
              : Folder<LogicalResult>{context}.Folding(arg[0])}) {
    std::optional<ConstantSubscript> dim;
    if (arg.size() > 1 && arg[1]) {
      dim = CheckDIM(context, arg[1], mask->Rank());
      if (!dim) {
        mask = nullptr;
      }
    }
    if (mask) {
      auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
        if (mask->At(at).IsTrue()) {
          element = element.AddSigned(Scalar<T>{1}).value;
        }
      }};
      return Expr<T>{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
    }
  }
  return Expr<T>{std::move(ref)};
}

// for IALL, IANY, & IPARITY
template <typename T>
static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
    Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
    Scalar<T> identity) {
  static_assert(T::category == TypeCategory::Integer);
  using Element = Scalar<T>;
  std::optional<ConstantSubscript> dim;
  if (std::optional<Constant<T>> array{
          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
              /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
    auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
    auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
      element = (element.*operation)(array->At(at));
    }};
    return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
    return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
  }
  return Expr<T>{std::move(ref)};
}
@@ -237,17 +263,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
          cx->u);
    }
  } else if (name == "count") {
    if (!args[1]) { // TODO: COUNT(x,DIM=d)
      if (const auto *constant{UnwrapConstantValue<LogicalResult>(args[0])}) {
        std::int64_t result{0};
        for (const auto &element : constant->values()) {
          if (element.IsTrue()) {
            ++result;
          }
        }
        return Expr<T>{result};
      }
    }
    return FoldCount<T>(context, std::move(funcRef));
  } else if (name == "digits") {
    if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
      return Expr<T>{std::visit(
+1 −1
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ static Expr<T> FoldAllAny(FoldingContext &context, FunctionRef<T> &&ref,
    auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
      element = (element.*operation)(array->At(at));
    }};
    return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
    return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
  }
  return Expr<T>{std::move(ref)};
}
+32 −0
Original line number Diff line number Diff line
//===-- lib/Evaluate/fold-reduction.cpp -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "fold-reduction.h"

namespace Fortran::evaluate {

std::optional<ConstantSubscript> CheckDIM(
    FoldingContext &context, std::optional<ActualArgument> &arg, int rank) {
  if (arg) {
    if (auto *dimConst{Folder<SubscriptInteger>{context}.Folding(arg)}) {
      if (auto dimScalar{dimConst->GetScalarValue()}) {
        auto dim{dimScalar->ToInt64()};
        if (dim >= 1 && dim <= rank) {
          return {dim};
        } else {
          context.messages().Say(
              "DIM=%jd is not valid for an array of rank %d"_err_en_US,
              static_cast<std::intmax_t>(dim), rank);
        }
      }
    }
  }
  return std::nullopt;
}

} // namespace Fortran::evaluate
Loading