Unverified Commit 82867439 authored by Valentin Clement (バレンタイン クレメン)'s avatar Valentin Clement (バレンタイン クレメン) Committed by GitHub
Browse files

[flang][openacc] Allow acc routine at the top level (#69936)

Some compilers allow the `$acc routine(<name>)` to be placed at the
program unit level. To be compatible, this patch enables the use of acc
routine at this level. These acc routine directives must have a name.
parent 93f8e52d
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -23,3 +23,4 @@ local:
  warning instead of an error as other compiler accepts it.
* The `if` clause accepts scalar integer expression in addition to scalar
  logical expression.
* `!$acc routine` directive can be placed at the top level. 
+6 −0
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ namespace Fortran {
namespace parser {
struct OpenACCConstruct;
struct OpenACCDeclarativeConstruct;
struct OpenACCRoutineConstruct;
} // namespace parser

namespace semantics {
@@ -71,6 +72,11 @@ void genOpenACCDeclarativeConstruct(AbstractConverter &,
                                    StatementContext &,
                                    const parser::OpenACCDeclarativeConstruct &,
                                    AccRoutineInfoMappingList &);
void genOpenACCRoutineConstruct(AbstractConverter &,
                                Fortran::semantics::SemanticsContext &,
                                mlir::ModuleOp &,
                                const parser::OpenACCRoutineConstruct &,
                                AccRoutineInfoMappingList &);

void finalizeOpenACCRoutineAttachment(mlir::ModuleOp &,
                                      AccRoutineInfoMappingList &);
+14 −2
Original line number Diff line number Diff line
@@ -135,6 +135,7 @@ using Constructs =

using Directives =
    std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
               parser::OpenACCRoutineConstruct,
               parser::OpenACCDeclarativeConstruct, parser::OpenMPConstruct,
               parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective>;

@@ -360,7 +361,8 @@ using ProgramVariant =
    ReferenceVariant<parser::MainProgram, parser::FunctionSubprogram,
                     parser::SubroutineSubprogram, parser::Module,
                     parser::Submodule, parser::SeparateModuleSubprogram,
                     parser::BlockData, parser::CompilerDirective>;
                     parser::BlockData, parser::CompilerDirective,
                     parser::OpenACCRoutineConstruct>;
/// A program is a list of program units.
/// These units can be function like, module like, or block data.
struct ProgramUnit : ProgramVariant {
@@ -763,10 +765,20 @@ struct CompilerDirectiveUnit : public ProgramUnit {
  CompilerDirectiveUnit(const CompilerDirectiveUnit &) = delete;
};

// Top level OpenACC routine directives
struct OpenACCDirectiveUnit : public ProgramUnit {
  OpenACCDirectiveUnit(const parser::OpenACCRoutineConstruct &directive,
                       const PftNode &parent)
      : ProgramUnit{directive, parent}, routine{directive} {};
  OpenACCDirectiveUnit(OpenACCDirectiveUnit &&) = default;
  OpenACCDirectiveUnit(const OpenACCDirectiveUnit &) = delete;
  const parser::OpenACCRoutineConstruct &routine;
};

/// A Program is the top-level root of the PFT.
struct Program {
  using Units = std::variant<FunctionLikeUnit, ModuleLikeUnit, BlockDataUnit,
                             CompilerDirectiveUnit>;
                             CompilerDirectiveUnit, OpenACCDirectiveUnit>;

  Program(semantics::CommonBlockList &&commonBlocks)
      : commonBlocks{std::move(commonBlocks)} {}
+3 −1
Original line number Diff line number Diff line
@@ -262,6 +262,7 @@ struct PauseStmt;
struct OpenACCConstruct;
struct AccEndCombinedDirective;
struct OpenACCDeclarativeConstruct;
struct OpenACCRoutineConstruct;
struct OpenMPConstruct;
struct OpenMPDeclarativeConstruct;
struct OmpEndLoopDirective;
@@ -558,7 +559,8 @@ struct ProgramUnit {
      common::Indirection<FunctionSubprogram>,
      common::Indirection<SubroutineSubprogram>, common::Indirection<Module>,
      common::Indirection<Submodule>, common::Indirection<BlockData>,
      common::Indirection<CompilerDirective>>
      common::Indirection<CompilerDirective>,
      common::Indirection<OpenACCRoutineConstruct>>
      u;
};

+13 −0
Original line number Diff line number Diff line
@@ -316,6 +316,7 @@ public:
                         globalOmpRequiresSymbol = b.symTab.symbol();
                     },
                     [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
                     [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
                 },
                 u);
    }
@@ -328,6 +329,14 @@ public:
              [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
              [&](Fortran::lower::pft::BlockDataUnit &b) {},
              [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
              [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
                builder = new fir::FirOpBuilder(bridge.getModule(),
                                                bridge.getKindMap());
                Fortran::lower::genOpenACCRoutineConstruct(
                    *this, bridge.getSemanticsContext(), bridge.getModule(),
                    d.routine, accRoutineInfos);
                builder = nullptr;
              },
          },
          u);
    }
@@ -2362,6 +2371,10 @@ private:
      genFIR(e);
  }

  void genFIR(const Fortran::parser::OpenACCRoutineConstruct &acc) {
    // Handled by genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &)
  }

  void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
    mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
    localSymbols.pushScope();
Loading