Commit 7476e569 authored by River Riddle's avatar River Riddle
Browse files

[mlir][Pass] Enable printing pass options as part of `-help`.

Summary:
This revision adds support for printing pass options as part of the normal help description. This also moves registered passes and pipelines into different sections of the help.

Example:
```
  Compiler passes to run
    --pass-pipeline                                     -   ...
    Passes:
      --affine-data-copy-generate                       -   ...
      --convert-gpu-to-spirv                            -   ...
        --workgroup-size=<long>                         - ...
      --test-options-pass                               -   ...
        --list=<int>                                    - ...
        --string=<string>                               - ...
        --string-list=<string>                          - ...
    Pass Pipelines:
      --test-options-pass-pipeline                      -   ...
        --list=<int>                                    - ...
        --string=<string>                               - ...
        --string-list=<string>                          - ...
```

Differential Revision: https://reviews.llvm.org/D74246
parent ae391054
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -121,6 +121,7 @@ public:
protected:
  explicit Pass(const PassID *passID, Optional<StringRef> opName = llvm::None)
      : passID(passID), opName(opName) {}
  Pass(const Pass &other) : Pass(other.passID, other.opName) {}

  /// Returns the current pass state.
  detail::PassExecutionState &getPassState() {
@@ -178,6 +179,9 @@ private:

  /// Allow access to 'clone' and 'run'.
  friend class OpPassManager;

  /// Allow access to 'passOptions'.
  friend class PassInfo;
};

//===----------------------------------------------------------------------===//
+10 −0
Original line number Diff line number Diff line
@@ -182,6 +182,9 @@ public:
  };

  PassOptions() = default;
  /// Delete the copy constructor to avoid copying the internal options map.
  PassOptions(const PassOptions &) = delete;
  PassOptions(PassOptions &&) = delete;

  /// Copy the option values from 'other' into 'this', where 'other' has the
  /// same options as 'this'.
@@ -196,6 +199,13 @@ public:
  /// 'parseFromString'.
  void print(raw_ostream &os);

  /// Print the help string for the options held by this struct. `descIndent` is
  /// the indent that the descriptions should be aligned.
  void printHelp(size_t indent, size_t descIndent) const;

  /// Return the maximum width required when printing the help string.
  size_t getOptionWidth() const;

private:
  /// A list of all of the opaque options.
  std::vector<OptionBase *> options;
+55 −27
Original line number Diff line number Diff line
@@ -21,6 +21,10 @@ namespace mlir {
class OpPassManager;
class Pass;

namespace detail {
class PassOptions;
} // end namespace detail

/// A registry function that adds passes to the given pass manager. This should
/// also parse options and return success() if parsing succeeded.
using PassRegistryFunction =
@@ -55,28 +59,45 @@ public:
  /// Returns a description for the pass, this never returns null.
  StringRef getPassDescription() const { return description; }

  /// Print the help information for this pass. This includes the argument,
  /// description, and any pass options. `descIndent` is the indent that the
  /// descriptions should be aligned.
  void printHelpStr(size_t indent, size_t descIndent) const;

  /// Return the maximum width required when printing the options of this entry.
  size_t getOptionWidth() const;

protected:
  PassRegistryEntry(StringRef arg, StringRef description,
                    const PassRegistryFunction &builder)
      : arg(arg), description(description), builder(builder) {}
  PassRegistryEntry(
      StringRef arg, StringRef description, const PassRegistryFunction &builder,
      std::function<void(function_ref<void(const detail::PassOptions &)>)>
          optHandler)
      : arg(arg), description(description), builder(builder),
        optHandler(optHandler) {}

private:
  // The argument with which to invoke the pass via mlir-opt.
  /// The argument with which to invoke the pass via mlir-opt.
  StringRef arg;

  // Description of the pass.
  /// Description of the pass.
  StringRef description;

  // Function to register this entry to a pass manager pipeline.
  /// Function to register this entry to a pass manager pipeline.
  PassRegistryFunction builder;

  /// Function to invoke a handler for a pass options instance.
  std::function<void(function_ref<void(const detail::PassOptions &)>)>
      optHandler;
};

/// A structure to represent the information of a registered pass pipeline.
class PassPipelineInfo : public PassRegistryEntry {
public:
  PassPipelineInfo(StringRef arg, StringRef description,
                   const PassRegistryFunction &builder)
      : PassRegistryEntry(arg, description, builder) {}
  PassPipelineInfo(
      StringRef arg, StringRef description, const PassRegistryFunction &builder,
      std::function<void(function_ref<void(const detail::PassOptions &)>)>
          optHandler)
      : PassRegistryEntry(arg, description, builder, optHandler) {}
};

/// A structure to represent the information for a derived pass class.
@@ -94,8 +115,10 @@ public:

/// Register a specific dialect pipeline registry function with the system,
/// typically used through the PassPipelineRegistration template.
void registerPassPipeline(StringRef arg, StringRef description,
                          const PassRegistryFunction &function);
void registerPassPipeline(
    StringRef arg, StringRef description, const PassRegistryFunction &function,
    std::function<void(function_ref<void(const detail::PassOptions &)>)>
        optHandler);

/// Register a specific dialect pass allocator function with the system,
/// typically used through the PassRegistration template.
@@ -113,7 +136,6 @@ void registerPass(StringRef arg, StringRef description, const PassID *passID,
///   static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");
///
template <typename ConcretePass> struct PassRegistration {

  PassRegistration(StringRef arg, StringRef description,
                   const PassAllocatorFunction &constructor) {
    registerPass(arg, description, PassID::getID<ConcretePass>(), constructor);
@@ -142,13 +164,17 @@ struct PassPipelineRegistration {
  PassPipelineRegistration(
      StringRef arg, StringRef description,
      std::function<void(OpPassManager &, const Options &options)> builder) {
    registerPassPipeline(arg, description,
    registerPassPipeline(
        arg, description,
        [builder](OpPassManager &pm, StringRef optionsStr) {
          Options options;
          if (failed(options.parseFromString(optionsStr)))
            return failure();
          builder(pm, options);
          return success();
        },
        [](function_ref<void(const detail::PassOptions &)> optHandler) {
          optHandler(Options());
        });
  }
};
@@ -158,13 +184,15 @@ struct PassPipelineRegistration {
template <> struct PassPipelineRegistration<EmptyPipelineOptions> {
  PassPipelineRegistration(StringRef arg, StringRef description,
                           std::function<void(OpPassManager &)> builder) {
    registerPassPipeline(arg, description,
    registerPassPipeline(
        arg, description,
        [builder](OpPassManager &pm, StringRef optionsStr) {
          if (!optionsStr.empty())
            return failure();
          builder(pm);
          return success();
                         });
        },
        [](function_ref<void(const detail::PassOptions &)>) {});
  }
};

+126 −20
Original line number Diff line number Diff line
@@ -35,13 +35,48 @@ buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
  };
}

/// Utility to print the help string for a specific option.
void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
                     size_t descIndent, bool isTopLevel) {
  size_t numSpaces = descIndent - indent - 4;
  llvm::outs().indent(indent)
      << "--" << llvm::left_justify(arg, numSpaces) << "-   " << desc << '\n';
}

//===----------------------------------------------------------------------===//
// PassRegistry
//===----------------------------------------------------------------------===//

/// Print the help information for this pass. This includes the argument,
/// description, and any pass options. `descIndent` is the indent that the
/// descriptions should be aligned.
void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
  printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
                  /*isTopLevel=*/true);
  // If this entry has options, print the help for those as well.
  optHandler([=](const PassOptions &options) {
    options.printHelp(indent, descIndent);
  });
}

/// Return the maximum width required when printing the options of this
/// entry.
size_t PassRegistryEntry::getOptionWidth() const {
  size_t maxLen = 0;
  optHandler([&](const PassOptions &options) mutable {
    maxLen = options.getOptionWidth() + 2;
  });
  return maxLen;
}

//===----------------------------------------------------------------------===//
// PassPipelineInfo
//===----------------------------------------------------------------------===//

void mlir::registerPassPipeline(StringRef arg, StringRef description,
                                const PassRegistryFunction &function) {
  PassPipelineInfo pipelineInfo(arg, description, function);
void mlir::registerPassPipeline(
    StringRef arg, StringRef description, const PassRegistryFunction &function,
    std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
  PassPipelineInfo pipelineInfo(arg, description, function, optHandler);
  bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
  assert(inserted && "Pass pipeline registered multiple times");
  (void)inserted;
@@ -53,7 +88,12 @@ void mlir::registerPassPipeline(StringRef arg, StringRef description,

PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID,
                   const PassAllocatorFunction &allocator)
    : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)) {}
    : PassRegistryEntry(
          arg, description, buildDefaultRegistryFn(allocator),
          // Use a temporary pass to provide an options instance.
          [=](function_ref<void(const PassOptions &)> optHandler) {
            optHandler(allocator()->passOptions);
          }) {}

void mlir::registerPass(StringRef arg, StringRef description,
                        const PassID *passID,
@@ -137,20 +177,46 @@ void detail::PassOptions::print(raw_ostream &os) {
    return;

  // Sort the options to make the ordering deterministic.
  SmallVector<OptionBase *, 4> orderedOptions(options.begin(), options.end());
  llvm::array_pod_sort(orderedOptions.begin(), orderedOptions.end(),
                       [](OptionBase *const *lhs, OptionBase *const *rhs) {
                         return (*lhs)->getArgStr().compare(
                             (*rhs)->getArgStr());
                       });
  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
    return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
  };
  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);

  // Interleave the options with ' '.
  os << '{';
  interleave(
      orderedOptions, os, [&](OptionBase *option) { option->print(os); }, " ");
      orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
  os << '}';
}

/// Print the help string for the options held by this struct. `descIndent` is
/// the indent within the stream that the descriptions should be aligned.
void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
  // Sort the options to make the ordering deterministic.
  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
    return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
  };
  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
  for (OptionBase *option : orderedOps) {
    // TODO(riverriddle) printOptionInfo assumes a specific indent and will
    // print options with values with incorrect indentation. We should add
    // support to llvm::cl::Option for passing in a base indent to use when
    // printing.
    llvm::outs().indent(indent);
    option->getOption()->printOptionInfo(descIndent - indent);
  }
}

/// Return the maximum width required when printing the help string.
size_t detail::PassOptions::getOptionWidth() const {
  size_t max = 0;
  for (auto *option : options)
    max = std::max(max, option->getOption()->getOptionWidth());
  return max;
}

//===----------------------------------------------------------------------===//
// TextualPassPipeline Parser
//===----------------------------------------------------------------------===//
@@ -443,6 +509,7 @@ struct PassNameParser : public llvm::cl::parser<PassArgData> {
  void initialize();
  void printOptionInfo(const llvm::cl::Option &opt,
                       size_t globalWidth) const override;
  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
             PassArgData &value);
};
@@ -467,15 +534,54 @@ void PassNameParser::initialize() {
  }
}

void PassNameParser::printOptionInfo(const llvm::cl::Option &O,
                                     size_t GlobalWidth) const {
  PassNameParser *TP = const_cast<PassNameParser *>(this);
  llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
                       [](const PassNameParser::OptionInfo *VT1,
                          const PassNameParser::OptionInfo *VT2) {
                         return VT1->Name.compare(VT2->Name);
void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
                                     size_t globalWidth) const {
  // Print the information for the top-level option.
  if (opt.hasArgStr()) {
    llvm::outs() << "  --" << opt.ArgStr;
    opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
  } else {
    llvm::outs() << "  " << opt.HelpStr << '\n';
  }

  // Print the top-level pipeline argument.
  printOptionHelp(passPipelineArg,
                  "A textual description of a pass pipeline to run",
                  /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());

  // Functor used to print the ordered entries of a registration map.
  auto printOrderedEntries = [&](StringRef header, auto &map) {
    llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
    for (auto &kv : map)
      orderedEntries.push_back(&kv.second);
    llvm::array_pod_sort(
        orderedEntries.begin(), orderedEntries.end(),
        [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
          return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
        });
  llvm::cl::parser<PassArgData>::printOptionInfo(O, GlobalWidth);

    llvm::outs().indent(4) << header << ":\n";
    for (PassRegistryEntry *entry : orderedEntries)
      entry->printHelpStr(/*indent=*/6, globalWidth);
  };

  // Print the available passes.
  printOrderedEntries("Passes", *passRegistry);

  // Print the available pass pipelines.
  if (!passPipelineRegistry->empty())
    printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
}

size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
  size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;

  // Check for any wider pass or pipeline options.
  for (auto &entry : *passRegistry)
    maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
  for (auto &entry : *passPipelineRegistry)
    maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
  return maxWidth;
}

bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,