pyxasm_visitor.hpp 9.46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#pragma once
#include <cctype>
#include <cstddef>
#include <map>
#include <memory>
#include <regex>
#include <sstream>
#include <string>
#include <vector>

#include "IRProvider.hpp"
#include "pyxasmBaseVisitor.h"
#include "qrt.hpp"
#include "xacc.hpp"

using namespace pyxasm;

// Maps Python-side spellings to the XACC instruction names they denote
// (e.g. "CX" -> "CNOT"; both "exp" and "qcor::exp" -> "exp_i_theta").
// `inline` (C++17) keeps this a single program-wide definition even though it
// lives in a header that may be included by several translation units.
inline std::map<std::string, std::string> common_name_map{
    {"CX", "CNOT"}, {"qcor::exp", "exp_i_theta"}, {"exp", "exp_i_theta"}};

// Result slot filled by pyxasm_visitor: generated C++ source text paired with
// a translated intrinsic instruction.
using pyxasm_result_type = std::pair<std::string, std::shared_ptr<xacc::Instruction>>;

class pyxasm_visitor : public pyxasmBaseVisitor {
18
 protected:
19
  std::shared_ptr<xacc::IRProvider> provider;
20
21
  // List of buffers in the *context* of this XASM visitor
  std::vector<std::string> bufferNames;
22

23
 public:
24
25
  pyxasm_visitor(const std::vector<std::string> &buffers = {})
      : provider(xacc::getIRProvider("quantum")), bufferNames(buffers) {}
26
27
28
29
  pyxasm_result_type result;

  bool in_for_loop = false;

30
31
  antlrcpp::Any visitAtom_expr(
      pyxasmParser::Atom_exprContext *context) override {
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
      
    // Handle kernel::ctrl(...), kernel::adjoint(...)
    if (!context->trailer().empty() && context->trailer()[0]->getText() == ".ctrl") {
      std::cout << "HELLO: " << context->getText() << "\n";
      std::cout << context->trailer()[0]->getText() << "\n";
      std::cout << context->atom()->getText() << "\n";

      std::cout << context->trailer()[1]->getText() << "\n";
      std::cout << context->trailer()[1]->arglist() << "\n";
      auto arg_list = context->trailer()[1]->arglist();

      std::stringstream ss;
      ss << context->atom()->getText() << "::ctrl(parent_kernel";
      for (int i = 0; i < arg_list->argument().size(); i++) {
        ss << ", " << arg_list->argument(i)->getText();
      }
      ss << ");\n";

      std::cout << "HELLO SS: " << ss.str() << "\n";
      result.first = ss.str();
      return 0;

    }
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
    if (context->atom()->NAME() != nullptr) {
      auto inst_name = context->atom()->NAME()->getText();

      if (common_name_map.count(inst_name)) {
        inst_name = common_name_map[inst_name];
      }

      if (xacc::container::contains(provider->getInstructions(), inst_name)) {
        // Create an instance of the Instruction with the given name
        auto inst = provider->createInstruction(inst_name, 0);

        // If it is not composite, look for its bit expressions
        // and parameter expressions
        if (!inst->isComposite()) {
          // Get the number of required bits and parameters
          auto required_bits = inst->nRequiredBits();
          auto required_params = inst->getParameters().size();

          if (!context->trailer().empty()) {
            auto atom_n_args =
                context->trailer()[0]->arglist()->argument().size();

            if (required_bits + required_params != atom_n_args &&
                inst_name != "Measure") {
              std::stringstream xx;
              xx << "Invalid quantum instruction expression. " << inst_name
                 << " requires " << required_bits << " qubit args and "
                 << required_params << " parameter args.";
              xacc::error(xx.str());
            }

            // Get the qubit expresssions
            std::vector<std::string> buffer_names;
            for (int i = 0; i < required_bits; i++) {
              auto bit_expr = context->trailer()[0]->arglist()->argument()[i];
              auto bit_expr_str = bit_expr->getText();

              auto found_bracket = bit_expr_str.find_first_of("[");
              if (found_bracket != std::string::npos) {
                auto buffer_name = bit_expr_str.substr(0, found_bracket);
95
96
97
                auto bit_idx_expr = bit_expr_str.substr(
                    found_bracket + 1,
                    bit_expr_str.length() - found_bracket - 2);
98
99
100
101
102
103
104
105
106
107
108
                buffer_names.push_back(buffer_name);
                inst->setBitExpression(i, bit_idx_expr);
              } else {
                xacc::error("Must provide qreg[IDX] and not just qreg.");
              }
            }
            inst->setBufferNames(buffer_names);

            // Get the parameter expressions
            int counter = 0;
            for (int i = required_bits; i < atom_n_args; i++) {
109
110
111
112
113
              inst->setParameter(counter,
                                 replacePythonConstants(context->trailer()[0]
                                                            ->arglist()
                                                            ->argument()[i]
                                                            ->getText()));
114
115
116
              counter++;
            }
          }
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
          result.second = inst;
        } else {
          // Composite instructions, e.g. exp_i_theta
          if (inst_name == "exp_i_theta") {
            // Expected 3 params:
            if (context->trailer()[0]->arglist()->argument().size() != 3) {
              xacc::error(
                  "Invalid number of arguments for the 'exp_i_theta' "
                  "instruction. Expected 3, got " +
                  std::to_string(
                      context->trailer()[0]->arglist()->argument().size()) +
                  ". Please check your input.");
            }

            std::stringstream ss;
            // Delegate to the QRT call directly.
            ss << "quantum::exp("
               << context->trailer()[0]->arglist()->argument(0)->getText()
               << ", "
               << context->trailer()[0]->arglist()->argument(1)->getText()
               << ", "
               << context->trailer()[0]->arglist()->argument(2)->getText()
               << ");\n";
            result.first = ss.str();
          } else {
            xacc::error("Composite instruction '" + inst_name +
                        "' is not currently supported.");
          }
145
        }
146
      } else {
147
148
149
150
151
152
153
154
155
        // This kernel *callable* is not an intrinsic instruction, just
        // reassemble the call:
        // Check that the *first* argument is a *qreg* in the current context of
        // *this* kernel.
        if (!context->trailer().empty() &&
            !context->trailer()[0]->arglist()->argument().empty() &&
            xacc::container::contains(
                bufferNames,
                context->trailer()[0]->arglist()->argument(0)->getText())) {
156
          std::stringstream ss;
157
          // Use the kernel call with a parent kernel arg.
158
159
160
161
162
163
164
165
166
167
168
169
170
          ss << inst_name << "(parent_kernel, ";
          // TODO: We potentially need to handle *inline* expressions in the
          // function call.
          const auto &argList = context->trailer()[0]->arglist()->argument();
          for (size_t i = 0; i < argList.size(); ++i) {
            ss << argList[i]->getText();
            if (i != argList.size() - 1) {
              ss << ", ";
            }
          }
          ss << ");\n";
          result.first = ss.str();
        }
171
172
173
174
175
176
177
      }
    }
    return 0;
  }

  antlrcpp::Any visitFor_stmt(pyxasmParser::For_stmtContext *context) override {
    auto counter_expr = context->exprlist()->expr()[0];
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
178
179
180
181
182
183
184
185
186
187
    auto iter_container = context->testlist()->test()[0]->getText();
    // Rewrite:
    // Python: "for <var> in <expr>:"
    // C++: for (auto& var: <expr>) {}
    // Note: we add range(int) as a C++ function to support this common pattern.
    std::stringstream ss;
    ss << "for (auto &" << counter_expr->getText() << " : " << iter_container
       << ") {\n";
    result.first = ss.str();
    in_for_loop = true;
188
189
    return 0;
  }
190
191
192
193
194
195
196
197
198

  antlrcpp::Any visitExpr_stmt(pyxasmParser::Expr_stmtContext *ctx) override {
    if (ctx->ASSIGN().size() == 1 && ctx->testlist_star_expr().size() == 2) {
      // Handle simple assignment: a = expr
      std::stringstream ss;
      const std::string lhs = ctx->testlist_star_expr(0)->getText();
      const std::string rhs = ctx->testlist_star_expr(1)->getText();
      ss << "auto " << lhs << " = " << rhs << "; \n";
      result.first = ss.str();
199
200
201
202
203
204
      if (rhs.find("**") != std::string::npos) {
        // keep processing
        return visitChildren(ctx);
      } else {
        return 0;
      }
205
206
207
208
    } else {
      return visitChildren(ctx);
    }
  }
209

210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
  antlrcpp::Any visitPower(pyxasmParser::PowerContext *context) override {
    if (context->getText().find("**") != std::string::npos &&
        context->factor() != nullptr) {
      // Here we handle x**y from parent assignment expression
      auto replaceAll = [](std::string &s, const std::string &search,
                           const std::string &replace) {
        for (std::size_t pos = 0;; pos += replace.length()) {
          // Locate the substring to replace
          pos = s.find(search, pos);
          if (pos == std::string::npos) break;
          // Replace by erasing and inserting
          s.erase(pos, search.length());
          s.insert(pos, replace);
        }
      };
      auto factor = context->factor();
      auto atom_expr = context->atom_expr();
      std::string s =
          "std::pow(" + atom_expr->getText() + ", " + factor->getText() + ")";
      replaceAll(result.first, context->getText(), s);
      return 0;
    }
    return visitChildren(context);
  }

 private:
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
  // Replaces common Python constants, e.g. 'math.pi' or 'numpy.pi'.
  // Note: the library names have been resolved to their original names.
  std::string replacePythonConstants(const std::string &in_pyExpr) const {
    // List of all keywords to be replaced
    const std::map<std::string, std::string> REPLACE_MAP{{"math.pi", "M_PI"},
                                                         {"numpy.pi", "M_PI"}};
    std::string newSrc = in_pyExpr;
    for (const auto &[key, value] : REPLACE_MAP) {
      const auto pos = newSrc.find(key);
      if (pos != std::string::npos) {
        newSrc.replace(pos, key.length(), value);
      }
    }
    return newSrc;
  }
251
};