Commit e617f1bf authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

Got custom CompositeInstructions working with lambda xasm language

parent 29845df3
Loading
Loading
Loading
Loading
+98 −21
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@
#include "Utils.hpp"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTImporter.h"
#include "clang/ASTMatchers/ASTMatchers.h"

#include "qcor_clang_utils.hpp"

@@ -21,10 +22,13 @@ void FuzzyParsingExternalSemaSource::initialize() {
  validInstructions = provider->getInstructions();
  validInstructions.push_back("CX");

  std::vector<std::shared_ptr<xacc::Instruction>> composites;
  std::string totalTempSource = "";
  for (auto &instructionName : validInstructions) {
    std::string tmpSource = "void " + instructionName + "(";
    auto tmpInst = provider->createInstruction(
        instructionName == "CX" ? "CNOT" : instructionName, {});
    if (!tmpInst->isComposite()) {
      int nRequiredBits = tmpInst->nRequiredBits();
      tmpSource += "int q0";
      for (int i = 1; i < nRequiredBits; i++) {
@@ -38,31 +42,104 @@ void FuzzyParsingExternalSemaSource::initialize() {
        }
      }
      tmpSource += "){return;}";
    quantumInstructionASTs.insert(
        {instructionName + "__qcor_instruction",
         tooling::buildASTFromCodeWithArgs(tmpSource, {"-std=c++11"})});
      quantumInstruction2src.insert(
          {instructionName + "__qcor_instruction", tmpSource});
    } else {
      compositeInstructions.push_back(instructionName + "__qcor_instruction");
    }
  }

  hast = tooling::buildASTFromCodeWithArgs(
      "#include \"heterogeneous.hpp\"\nvoid f(xacc::HeterogeneousMap&& "
      "m, std::vector<double>& x){return;}",
      {"-std=c++14", "-I/home/cades/.xacc/include/xacc"});
  hMapRValue = FirstDeclMatcher<ParmVarDecl>().match(
      hast->getASTContext().getTranslationUnitDecl(), namedDecl(hasName("m")));
  stdVector = FirstDeclMatcher<ParmVarDecl>().match(
      hast->getASTContext().getTranslationUnitDecl(), namedDecl(hasName("x")));
}

bool FuzzyParsingExternalSemaSource::LookupUnqualified(clang::LookupResult &R,
                                                       clang::Scope *S) {
  std::string unknownName = R.getLookupName().getAsString();

  // If this is a valid quantum instruction, tell Clang its
  // all gonna be ok, we got this...
  if (quantumInstructionASTs.count(unknownName + "__qcor_instruction") &&
  // If this is a valid quantum instruction, tell Clang not to error
  if (quantumInstruction2src.count(unknownName + "__qcor_instruction") &&
      S->getFlags() != 128 && S->getBlockParent() != nullptr) {

    auto Matcher = namedDecl(hasName(unknownName));

    auto ast = tooling::buildASTFromCodeWithArgs(
        quantumInstruction2src[unknownName + "__qcor_instruction"],
        {"-std=c++11"});
    FunctionDecl *D0 = FirstDeclMatcher<FunctionDecl>().match(
        quantumInstructionASTs[unknownName + "__qcor_instruction"]
            ->getASTContext()
            .getTranslationUnitDecl(),
        Matcher);
        ast->getASTContext().getTranslationUnitDecl(), Matcher);

    quantumInstructionASTs.push_back(std::move(ast));

    R.addDecl(D0);
    D0->dump();
  } else if (std::find(compositeInstructions.begin(),
                       compositeInstructions.end(),
                       unknownName + "__qcor_instruction") !=
             std::end(compositeInstructions)) {

    if (!qbit) {
      // Save pointers to xacc::qbit, xacc::HeterogeneousMap&& ParmVarDecl
      qbit = FirstDeclMatcher<ParmVarDecl>().match(
          ci.getASTContext().getTranslationUnitDecl(),
          parmVarDecl(hasType(recordDecl(matchesName("xacc::qbit")))));
    }

    // This is a Circuit Generator CompositeInstruction. We assume
    // (for now) it has must have prototype
    // f(qbit q, std::vector<double>& x, HeterogeneousMap&&)
    // or f(qbit q, HeterogeneousMap&&)

    // Create a new ParmVarDecl exactly like that one
    auto qb_copy = ParmVarDecl::Create(
        ci.getASTContext(), ci.getSema().getFunctionLevelDeclContext(),
        SourceLocation(), SourceLocation(), qbit->getIdentifier(),
        qbit->getType(), 0, SC_None, nullptr);
   auto v_copy = ParmVarDecl::Create(
        ci.getASTContext(), ci.getSema().getFunctionLevelDeclContext(),
        SourceLocation(), SourceLocation(), stdVector->getIdentifier(),
        stdVector->getType(), 0, SC_None, nullptr);
    auto h_copy = ParmVarDecl::Create(
        ci.getASTContext(), ci.getSema().getFunctionLevelDeclContext(),
        SourceLocation(), SourceLocation(), hMapRValue->getIdentifier(),
        hMapRValue->getType(), 0, SC_None, nullptr);

    // Use astContext.getFunctionType (RETURNTYPE, ARGSPARMVARS, fpi)
    // to create a new Function QualType
    std::vector<QualType> ParamTypes;
    ParamTypes.push_back(qb_copy->getType());
    ParamTypes.push_back(v_copy->getType());
    ParamTypes.push_back(h_copy->getType());
    FunctionProtoType::ExtProtoInfo fpi;
    fpi.Variadic = false;
    llvm::ArrayRef<QualType> Args(ParamTypes);
    QualType newFT = ci.getASTContext().getFunctionType(
        ci.getASTContext().VoidTy, Args, fpi);

    // Then use FunctionDecl::Create() to create a new functiondecl
    auto fdecl = FunctionDecl::Create(ci.getASTContext(),
                                      R.getSema().getFunctionLevelDeclContext(),
                                      SourceLocation(), SourceLocation(),
                                      R.getLookupName(), newFT, 0, SC_None);
    std::vector<ParmVarDecl *> params{qb_copy, v_copy, h_copy};
    llvm::ArrayRef<ParmVarDecl *> parms(params);
    fdecl->setParams(parms);
    std::vector<Stmt *> svec;
    auto rtrn = ReturnStmt::CreateEmpty(ci.getASTContext(), false);
    svec.push_back(rtrn);
    llvm::ArrayRef<Stmt *> stmts(svec);
    auto cmp = CompoundStmt::Create(ci.getASTContext(), stmts, SourceLocation(),
                                    SourceLocation());
    fdecl->setBody(cmp);
    fdecl->dump();

    R.addDecl(fdecl);
    return true;
  }
  return false;
+13 −4
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@

#include "clang/AST/ASTContext.h"
#include "clang/Frontend/ASTUnit.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Sema/ExternalSemaSource.h"
#include "clang/Sema/Lookup.h"

@@ -13,19 +14,27 @@ namespace compiler {
class FuzzyParsingExternalSemaSource : public ExternalSemaSource {
private:
  std::vector<std::string> validInstructions;
  ASTContext *m_Context;
  CompilerInstance &ci;
  ParmVarDecl *qbit;
  ParmVarDecl *hMapRValue;
  ParmVarDecl *stdVector;
  std::unique_ptr<ASTUnit> hast;
  std::vector<std::string> compositeInstructions;
//   std::vector<bool> compositeRequiresStdVector;

  // Keep a vector of ASTs for each FunctionDecl
  // representation of our quantum instructions.
  // This ExternalSemaSource should exist throughout
  // the tooling lifetime, so we should be good with
  // regards to these nodes being deleted
  std::map<std::string, std::unique_ptr<ASTUnit>> quantumInstructionASTs;
  std::vector<std::unique_ptr<ASTUnit>> quantumInstructionASTs;
  std::map<std::string, std::string> quantumInstruction2src;

public:
  FuzzyParsingExternalSemaSource() = default;
  FuzzyParsingExternalSemaSource(CompilerInstance &c) : ci(c) {}
  void initialize();
  void setASTContext(ASTContext *context) { m_Context = context; }
  //   void setASTContext(ASTContext *context) { m_Context = context; }
  //   void setFileManager(FileManager *m) { manager = m; }

  bool LookupUnqualified(clang::LookupResult &R, clang::Scope *S) override;
};
+6 −2
Original line number Diff line number Diff line
@@ -45,15 +45,19 @@ protected:
  }

  void ExecuteAction() override {

    CompilerInstance &CI = getCompilerInstance();
    CI.createSema(getTranslationUnitKind(), nullptr);
    rewriter.setSourceMgr(CI.getSourceManager(), CI.getLangOpts());

    auto fuzzyParser =
        std::make_shared<qcor::compiler::FuzzyParsingExternalSemaSource>();
        std::make_shared<qcor::compiler::FuzzyParsingExternalSemaSource>(CI);
    fuzzyParser->initialize();
    // fuzzyParser->setASTContext(&CI.getASTContext());
    // fuzzyParser->setFileManager(&CI.getFileManager());
    CI.getSema().addExternalSource(fuzzyParser.get());


    // FIXME Hook this back up
    // auto pragmaHandlers =
    // xacc::getServices<qcor::compiler::QCORPragmaHandler>(); for (auto p :
+13 −15
Original line number Diff line number Diff line
@@ -112,11 +112,11 @@ bool QCORASTVisitor::VisitLambdaExpr(LambdaExpr *LE) {
      targetAccelerator = xacc::getAccelerator();
    }

    // std::cout << "LAMBDA STR:\n" << xaccKernelLambdaStr << "\n";
    std::cout << "LAMBDA STR:\n" << xaccKernelLambdaStr << "\n";
    auto compiler = xacc::getCompiler("xasm");
    auto ir = compiler->compile(xaccKernelLambdaStr, targetAccelerator);

    auto function = ir->getComposites()[0]; //.getFunction();
    auto function = ir->getComposites()[0]; 
    for (auto &inst : function->getInstructions()) {
      if (!inst->isComposite() && inst->nParameters() > 0) {
        int counter = 0;
@@ -138,7 +138,8 @@ bool QCORASTVisitor::VisitLambdaExpr(LambdaExpr *LE) {
      }
    }

    // std::cout << "\n\nXACC IR:\n" << function->toString() << "\n";
    std::cout << "HELLO: " << function->getVariables() << "\n";
    std::cout << "\n\nXACC IR:\n" << function->toString() << "\n";

    auto sr = LE->getBody()->getSourceRange();
    if (!xacc::optionExists("accelerator")) {
@@ -203,7 +204,6 @@ bool QCORASTVisitor::VisitLambdaExpr(LambdaExpr *LE) {
    replacement += "}\n";
    rewriter.ReplaceText(sr, replacement);


    SourceLocation sll;
    QualType StrTy = ci.getASTContext().getConstantArrayType(
        ci.getASTContext().adjustStringLiteralBaseType(
@@ -230,7 +230,6 @@ bool QCORASTVisitor::VisitLambdaExpr(LambdaExpr *LE) {
    auto rtrn = ReturnStmt::Create(ci.getASTContext(), SourceLocation(),
                                   fnameSL, nullptr);


    auto cs = LE->getCallOperator()->getBody();
    for (auto it = cs->child_begin(); it != cs->child_end(); ++it) {
      svec.push_back(*it);
@@ -242,7 +241,6 @@ bool QCORASTVisitor::VisitLambdaExpr(LambdaExpr *LE) {
                                    SourceLocation());
    LE->getCallOperator()->setBody(cmp);


    llvm::ArrayRef<ParmVarDecl *> parms(params);
    LE->getCallOperator()->getAsFunction()->setParams(parms);
  }
+1 −25
Original line number Diff line number Diff line
@@ -15,7 +15,6 @@ using namespace clang::ast_matchers;
namespace qcor {
namespace compiler {
enum class DeclMatcherKind { First, Last };
enum class ExprMatcherKind { First, Last };

// Matcher class to retrieve the first/last matched node under a given AST.
template <typename NodeType, DeclMatcherKind MatcherKind>
@@ -40,33 +39,10 @@ public:
  }
};

template <typename NodeType, ExprMatcherKind MatcherKind>
class ExprMatcher : public ast_matchers::MatchFinder::MatchCallback {
  NodeType *Node = nullptr;
  void run(const MatchFinder::MatchResult &Result) override {
    if ((MatcherKind == ExprMatcherKind::First && Node == nullptr) ||
        MatcherKind == ExprMatcherKind::Last) {
      Node = const_cast<NodeType *>(Result.Nodes.getNodeAs<NodeType>(""));
    }
  }
  ASTContext& _ctx;
public:
  ExprMatcher(ASTContext& ctx) : _ctx(ctx) {}
  // Returns the first/last matched node under the tree rooted in `D`.
  template <typename MatcherType>
  NodeType *match(const Expr *D, const MatcherType &AMatcher) {
    MatchFinder Finder;
    Finder.addMatcher(AMatcher.bind(""), this);
    Finder.matchAST(_ctx);
    assert(Node);
    return Node;
  }
};

template <typename NodeType>
using FirstDeclMatcher = DeclMatcher<NodeType, DeclMatcherKind::First>;
template <typename NodeType>
using FirstExprMatcher = ExprMatcher<NodeType, ExprMatcherKind::First>;
using LastDeclMatcher = DeclMatcher<NodeType, DeclMatcherKind::Last>;
}
}
#endif
 No newline at end of file
Loading