This GitLab instance is undergoing maintenance and is operating in read-only mode.

You are on a read-only GitLab instance.
pyxasm_visitor.hpp 24 KB
Newer Older
1
2
3
4
5
#pragma once
#include <cassert>
#include <cctype>
#include <map>
#include <memory>
#include <regex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "IRProvider.hpp"
#include "pyxasmBaseVisitor.h"
#include "qcor_utils.hpp"
#include "qrt.hpp"
#include "xacc.hpp"

// Pull the pyxasm namespace into scope for the rest of this header.
// Presumably this is where the ANTLR-generated pyxasmBaseVisitor and
// pyxasmParser types used below are declared (TODO confirm against the
// generated headers).
using namespace pyxasm;

// Aliases accepted in kernel source, mapped to the instruction name that is
// looked up in the IR provider (e.g. "CX" is looked up as "CNOT").
// Declared `inline` because this header defines the variable at namespace
// scope: without it, every translation unit including the header would emit
// its own definition and the program would violate the one-definition rule.
inline std::map<std::string, std::string> common_name_map{
    {"CX", "CNOT"}, {"qcor::exp", "exp_i_theta"}, {"exp", "exp_i_theta"}};

// Outcome of visiting one statement:
//   first  - generated C++ source text (set when the statement is rewritten
//            as source, e.g. kernel calls, loops, assignments);
//   second - the instruction object (set when the statement is an intrinsic,
//            non-composite instruction).
using pyxasm_result_type =
    std::pair<std::string, std::shared_ptr<xacc::Instruction>>;

class pyxasm_visitor : public pyxasmBaseVisitor {
19
 protected:
20
  std::shared_ptr<xacc::IRProvider> provider;
21
22
  // List of buffers in the *context* of this XASM visitor
  std::vector<std::string> bufferNames;
23
24
  // List of *declared* variables
  std::vector<std::string> declared_var_names;
25

26
 public:
27
28
  pyxasm_visitor(const std::vector<std::string> &buffers = {},
                 const std::vector<std::string> &local_var_names = {})
29
30
      : provider(xacc::getIRProvider("quantum")),
        bufferNames(buffers),
31
        declared_var_names(local_var_names) {}
32
  pyxasm_result_type result;
33
34
  // New var declared (auto type) after visiting this node.
  std::string new_var;
35
  bool in_for_loop = false;
36
37
38
  // Var to keep track of sub-node rewrite:
  // e.g., traverse down the AST recursively.
  std::stringstream sub_node_translation;
39
  bool is_processing_sub_expr = false;
40

41
42
  antlrcpp::Any visitAtom_expr(
      pyxasmParser::Atom_exprContext *context) override {
43
    // std::cout << "Atom_exprContext: " << context->getText() << "\n";
44
45
46
47
48
49
50
51
52
53
54
55
56
    // Strategy:
    // At the top level, we analyze the trailer to determine the 
    // list of function call arguments.
    // Then, traverse down the arg. node to see if there is a potential rewrite rules
    // e.g. for arrays (as testlist_comp nodes)
    // Otherwise, just get the argument text as is.
    /*
    atom_expr: (AWAIT)? atom trailer*;
    atom: ('(' (yield_expr|testlist_comp)? ')' |
       '[' (testlist_comp)? ']' |
       '{' (dictorsetmaker)? '}' |
       NAME | NUMBER | STRING+ | '...' | 'None' | 'True' | 'False');
    */
57
58
59
    // Only processes these for sub-expressesions, 
    // e.g. re-entries to this function
    if (is_processing_sub_expr) {
60
61
      if (context->atom() && context->atom()->OPEN_BRACK() &&
          context->atom()->CLOSE_BRACK() && context->atom()->testlist_comp()) {
62
        // Array type expression:
63
64
        // std::cout << "Array atom expression: "
        //           << context->atom()->testlist_comp()->getText() << "\n";
65
66
67
68
        // Use braces
        sub_node_translation << "{";
        bool firstElProcessed = false;
        for (auto &testNode : context->atom()->testlist_comp()->test()) {
69
          // std::cout << "Array elem: " << testNode->getText() << "\n";
70
71
72
73
74
75
          // Add comma if needed (there is a previous element)
          if (firstElProcessed) {
            sub_node_translation << ", ";
          }
          sub_node_translation << testNode->getText();
          firstElProcessed = true;
76
        }
77
78
        sub_node_translation << "}";
        return 0;
79
80
      }

81
82
83
      // We don't have a re-write rule for this one (py::dict)
      if (context->atom() && context->atom()->OPEN_BRACE() &&
          context->atom()->CLOSE_BRACE() && context->atom()->dictorsetmaker()) {
84
        // Dict:
85
86
        // std::cout << "Dict atom expression: "
        //           << context->atom()->dictorsetmaker()->getText() << "\n";
87
88
89
        // TODO:
        return 0;
      }
90

91
92
93
94
95
96
97
98
99
100
      if (context->atom() && !context->atom()->STRING().empty()) {
        // Strings:
        for (auto &strNode : context->atom()->STRING()) {
          std::string cppStrLiteral = strNode->getText();
          // Handle Python single-quotes
          if (cppStrLiteral.front() == '\'' && cppStrLiteral.back() == '\'') {
            cppStrLiteral.front() = '"';
            cppStrLiteral.back() = '"';
          }
          sub_node_translation << cppStrLiteral;
101
102
          // std::cout << "String expression: " << strNode->getText() << " --> "
          //           << cppStrLiteral << "\n";
103
        }
104
        return 0;
105
106
      }

107
108
109
110
111
112
113
114
115
116
      const auto isSliceOp =
          [](pyxasmParser::Atom_exprContext *atom_expr_context) -> bool {
        if (atom_expr_context->trailer().size() == 1) {
          auto subscriptlist = atom_expr_context->trailer(0)->subscriptlist();
          if (subscriptlist && subscriptlist->subscript().size() == 1) {
            auto subscript = subscriptlist->subscript(0);
            const auto nbTestTerms = subscript->test().size();
            // Multiple test terms (separated by ':')
            return (nbTestTerms > 1);
          }
117
118
        }

119
120
        return false;
      };
121

122
123
124
125
126
      // Handle slicing operations (multiple array subscriptions separated by
      // ':') on a qreg.
      if (context->atom() &&
          xacc::container::contains(bufferNames, context->atom()->getText()) &&
          isSliceOp(context)) {
127
        // std::cout << "Slice op: " << context->getText() << "\n";
128
129
130
131
132
133
134
135
        sub_node_translation << context->atom()->getText()
                             << ".extract_range({";
        auto subscripts =
            context->trailer(0)->subscriptlist()->subscript(0)->test();
        assert(subscripts.size() > 1);
        std::vector<std::string> subscriptTerms;
        for (auto &test : subscripts) {
          subscriptTerms.emplace_back(test->getText());
136
137
        }

138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
        auto sliceOp =
            context->trailer(0)->subscriptlist()->subscript(0)->sliceop();
        if (sliceOp && sliceOp->test()) {
          subscriptTerms.emplace_back(sliceOp->test()->getText());
        }
        assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3);

        for (int i = 0; i < subscriptTerms.size(); ++i) {
          // Need to cast to prevent compiler errors,
          // e.g. when using q.size() which returns an int.
          sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i]
                               << ")";
          if (i != subscriptTerms.size() - 1) {
            sub_node_translation << ", ";
          }
        }

        sub_node_translation << "})";

        // convert the slice op to initializer list:
158
159
        // std::cout << "Slice Convert: " << context->getText() << " --> "
        //           << sub_node_translation.str() << "\n";
160
161
        return 0;
      }
162
163
164
165

      return 0;
    }

166
    // Handle kernel::ctrl(...), kernel::adjoint(...)
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
167
168
169
    if (!context->trailer().empty() &&
        (context->trailer()[0]->getText() == ".ctrl" ||
         context->trailer()[0]->getText() == ".adjoint")) {
170
171
172
      // std::cout << "HELLO: " << context->getText() << "\n";
      // std::cout << context->trailer()[0]->getText() << "\n";
      // std::cout << context->atom()->getText() << "\n";
173

174
175
      // std::cout << context->trailer()[1]->getText() << "\n";
      // std::cout << context->trailer()[1]->arglist() << "\n";
176
177
178
      auto arg_list = context->trailer()[1]->arglist();

      std::stringstream ss;
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
179
180
      // Remove the first '.' character
      const std::string methodName = context->trailer()[0]->getText().substr(1);
181
182
183
184
185
186
187
188
189
      // If this is a *variable*, then using '.' for control/adjoint.
      // Otherwise, use '::' (global scope kernel names)
      const std::string separator =
          (xacc::container::contains(declared_var_names,
                                     context->atom()->getText()))
              ? "."
              : "::";

      ss << context->atom()->getText() << separator << methodName
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
190
         << "(parent_kernel";
191
      for (int i = 0; i < arg_list->argument().size(); i++) {
192
        ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i)));
193
194
195
      }
      ss << ");\n";

196
      // std::cout << "HELLO SS: " << ss.str() << "\n";
197
198
199
      result.first = ss.str();
      return 0;
    }
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
    if (context->atom()->NAME() != nullptr) {
      auto inst_name = context->atom()->NAME()->getText();

      if (common_name_map.count(inst_name)) {
        inst_name = common_name_map[inst_name];
      }

      if (xacc::container::contains(provider->getInstructions(), inst_name)) {
        // Create an instance of the Instruction with the given name
        auto inst = provider->createInstruction(inst_name, 0);

        // If it is not composite, look for its bit expressions
        // and parameter expressions
        if (!inst->isComposite()) {
          // Get the number of required bits and parameters
          auto required_bits = inst->nRequiredBits();
          auto required_params = inst->getParameters().size();

          if (!context->trailer().empty()) {
            auto atom_n_args =
                context->trailer()[0]->arglist()->argument().size();

            if (required_bits + required_params != atom_n_args &&
                inst_name != "Measure") {
              std::stringstream xx;
              xx << "Invalid quantum instruction expression. " << inst_name
                 << " requires " << required_bits << " qubit args and "
                 << required_params << " parameter args.";
              xacc::error(xx.str());
            }

            // Get the qubit expresssions
            std::vector<std::string> buffer_names;
            for (int i = 0; i < required_bits; i++) {
              auto bit_expr = context->trailer()[0]->arglist()->argument()[i];
235
              auto bit_expr_str = rewriteFunctionArgument(*bit_expr);
236
237
238
239

              auto found_bracket = bit_expr_str.find_first_of("[");
              if (found_bracket != std::string::npos) {
                auto buffer_name = bit_expr_str.substr(0, found_bracket);
240
241
242
                auto bit_idx_expr = bit_expr_str.substr(
                    found_bracket + 1,
                    bit_expr_str.length() - found_bracket - 2);
243
244
245
                buffer_names.push_back(buffer_name);
                inst->setBitExpression(i, bit_idx_expr);
              } else {
246
247
248
                // Indicate this is a qubit(-1) or a qreg(-2)
                inst->setBitExpression(-1, bit_expr_str);
                buffer_names.push_back(bit_expr_str);
249
250
251
252
253
254
255
              }
            }
            inst->setBufferNames(buffer_names);

            // Get the parameter expressions
            int counter = 0;
            for (int i = required_bits; i < atom_n_args; i++) {
256
257
258
259
260
              inst->setParameter(counter,
                                 replacePythonConstants(context->trailer()[0]
                                                            ->arglist()
                                                            ->argument()[i]
                                                            ->getText()));
261
262
263
              counter++;
            }
          }
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
          result.second = inst;
        } else {
          // Composite instructions, e.g. exp_i_theta
          if (inst_name == "exp_i_theta") {
            // Expected 3 params:
            if (context->trailer()[0]->arglist()->argument().size() != 3) {
              xacc::error(
                  "Invalid number of arguments for the 'exp_i_theta' "
                  "instruction. Expected 3, got " +
                  std::to_string(
                      context->trailer()[0]->arglist()->argument().size()) +
                  ". Please check your input.");
            }

            std::stringstream ss;
            // Delegate to the QRT call directly.
            ss << "quantum::exp("
               << context->trailer()[0]->arglist()->argument(0)->getText()
               << ", "
               << context->trailer()[0]->arglist()->argument(1)->getText()
               << ", "
               << context->trailer()[0]->arglist()->argument(2)->getText()
               << ");\n";
            result.first = ss.str();
288
289
290
291
292
293
          }
          // Handle potential name collision: user-defined kernel having the
          // same name as an XACC circuit: e.g. common names such as qft, iqft
          // Note: these circuits (except exp_i_theta) don't have QRT
          // equivalents.
          // Condition: first argument is a qubit register
294
295
296
297
298
299
300
301
          else if (xacc::container::contains(
                       ::quantum::kernels_in_translation_unit, inst_name) ||
                   !context->trailer()[0]->arglist()->argument().empty() &&
                       xacc::container::contains(bufferNames,
                                                 context->trailer()[0]
                                                     ->arglist()
                                                     ->argument(0)
                                                     ->getText())) {
302
303
304
305
306
307
308
309
310
311
312
313
            std::stringstream ss;
            // Use the kernel call with a parent kernel arg.
            ss << inst_name << "(parent_kernel, ";
            const auto &argList = context->trailer()[0]->arglist()->argument();
            for (size_t i = 0; i < argList.size(); ++i) {
              ss << argList[i]->getText();
              if (i != argList.size() - 1) {
                ss << ", ";
              }
            }
            ss << ");\n";
            result.first = ss.str();
314
315
316
317
          } else {
            xacc::error("Composite instruction '" + inst_name +
                        "' is not currently supported.");
          }
318
        }
319
      } else {
320
321
322
        // This kernel *callable* is not an intrinsic instruction, just
        // reassemble the call:
        // Check that the *first* argument is a *qreg* in the current context of
323
324
325
326
327
328
329
330
        // *this* kernel or the function name is a kernel in translation unit.
        if (xacc::container::contains(::quantum::kernels_in_translation_unit,
                                      inst_name) ||
            (!context->trailer().empty() && context->trailer()[0]->arglist() &&
             !context->trailer()[0]->arglist()->argument().empty() &&
             xacc::container::contains(
                 bufferNames,
                 context->trailer()[0]->arglist()->argument(0)->getText()))) {
331
          std::stringstream ss;
332
          // Use the kernel call with a parent kernel arg.
333
334
335
336
337
338
339
340
341
342
343
344
          ss << inst_name << "(parent_kernel, ";
          // TODO: We potentially need to handle *inline* expressions in the
          // function call.
          const auto &argList = context->trailer()[0]->arglist()->argument();
          for (size_t i = 0; i < argList.size(); ++i) {
            ss << argList[i]->getText();
            if (i != argList.size() - 1) {
              ss << ", ";
            }
          }
          ss << ");\n";
          result.first = ss.str();
345
        } else {
346
347
348
349
          if (!context->trailer().empty()) {
            // A classical call-like expression: i.e. not a kernel call:
            // Just output it *as-is* to the C++ stream.
            // We can hook more sophisticated code-gen here if required.
350
            // std::cout << "Callable: " << context->getText() << "\n";
351
            std::stringstream ss;
352
353
354
355
356
357
358

            if (context->trailer()[0]->arglist() &&
                !context->trailer()[0]->arglist()->argument().empty()) {
              const auto &argList =
                  context->trailer()[0]->arglist()->argument();
              ss << inst_name << "(";
              for (size_t i = 0; i < argList.size(); ++i) {                
359
                ss << rewriteFunctionArgument(*(argList[i]));                
360
361
362
363
364
365
366
367
                if (i != argList.size() - 1) {
                  ss << ", ";
                }
              }
              ss << ");\n";
            } else {
              ss << context->getText() << ";\n";
            }
368
369
            result.first = ss.str();
          }
370
        }
371
372
373
374
375
376
      }
    }
    return 0;
  }

  antlrcpp::Any visitFor_stmt(pyxasmParser::For_stmtContext *context) override {
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
377
378
    // Rewrite:
    // Python: "for <var> in <expr>:"
379
    // C++: for (auto var: <expr>) {}
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
380
    // Note: we add range(int) as a C++ function to support this common pattern.
381
382
383
384
385
    // or
    // Python: "for <idx>,<var> in enumerate(<listvar>):"
    // C++: for (auto [idx, var] : enumerate(listvar))
    auto iter_container = context->testlist()->test()[0]->getText();
    std::string counter_expr = context->exprlist()->expr()[0]->getText();
386
387
    // Add the for loop variable to the tracking list as well.
    new_var = counter_expr;
388
389
390
391
392
393
394
395
    if (context->exprlist()->expr().size() > 1) {
      counter_expr = "[" + counter_expr;
      for (int i = 1; i < context->exprlist()->expr().size(); i++) {
        counter_expr += ", " + context->exprlist()->expr()[i]->getText();
      }
      counter_expr += "]";
    }

Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
396
    std::stringstream ss;
397
    ss << "for (auto " << counter_expr << " : " << iter_container << ") {\n";
Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
398
399
    result.first = ss.str();
    in_for_loop = true;
400
401
    return 0;
  }
402
403
404
405
406
407

  antlrcpp::Any visitExpr_stmt(pyxasmParser::Expr_stmtContext *ctx) override {
    if (ctx->ASSIGN().size() == 1 && ctx->testlist_star_expr().size() == 2) {
      // Handle simple assignment: a = expr
      std::stringstream ss;
      const std::string lhs = ctx->testlist_star_expr(0)->getText();
408
      std::string rhs = replacePythonConstants(
409
          replaceMeasureAssignment(ctx->testlist_star_expr(1)->getText()));
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425

      if (lhs.find(",") != std::string::npos) {
        // this is
        // var1, var2, ... = some_tuple_thing
        // We only support var1, var2 = ... for now
        // where ... is a pair-like object
        std::vector<std::string> suffix{".first", ".second"};
        auto vars = xacc::split(lhs, ',');
        for (auto [i, var] : qcor::enumerate(vars)) {
          if (xacc::container::contains(declared_var_names, var)) {
            ss << var << " = " << rhs << suffix[i] << ";\n";
          } else {
            ss << "auto " << var << " = " << rhs << suffix[i] << ";\n";
            new_var = lhs;
          }
        }
426
      } else {
427
428
        // Strategy: try to traverse the rhs to see if there is a possible rewrite;
        // Otherwise, use the text as is.
429
        is_processing_sub_expr = true;
430
431
432
433
434
435
436
437
438
439
440
441
442
        // clear the sub_node_translation  
        sub_node_translation.str(std::string());

        // visit arg sub-node:
        visitChildren(ctx->testlist_star_expr(1));

        // Check if there is a rewrite:
        if (!sub_node_translation.str().empty()) {
          // Update RHS
          rhs = replacePythonConstants(
              replaceMeasureAssignment(sub_node_translation.str()));
        }

443
        if (xacc::container::contains(declared_var_names, lhs)) {
444
          ss << lhs << " = " << rhs << "; \n";
445
446
        } else {
          // New variable: need to add *auto*
447
          ss << "auto " << lhs << " = " << rhs << "; \n";
448
449
          new_var = lhs;
        }
450
      }
451

452
      result.first = ss.str();
453
454
455
456
457
458
      if (rhs.find("**") != std::string::npos) {
        // keep processing
        return visitChildren(ctx);
      } else {
        return 0;
      }
459
    } else {
460
461
462
463
464
465
466
467
468
469
470
      // Visit child node:
      auto child_result = visitChildren(ctx);
      const auto translated_src = sub_node_translation.str();
      sub_node_translation.str(std::string());
      // If no child nodes, perform the codegen (result.first is not set)
      // but just appending the incremental translation collector;
      // return the collected C++ statement.
      if (result.first.empty() && !translated_src.empty()) {
        result.first = translated_src + ";\n";
      }
      return child_result;
471
472
    }
  }
473

474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
  antlrcpp::Any visitPower(pyxasmParser::PowerContext *context) override {
    if (context->getText().find("**") != std::string::npos &&
        context->factor() != nullptr) {
      // Here we handle x**y from parent assignment expression
      auto replaceAll = [](std::string &s, const std::string &search,
                           const std::string &replace) {
        for (std::size_t pos = 0;; pos += replace.length()) {
          // Locate the substring to replace
          pos = s.find(search, pos);
          if (pos == std::string::npos) break;
          // Replace by erasing and inserting
          s.erase(pos, search.length());
          s.insert(pos, replace);
        }
      };
      auto factor = context->factor();
      auto atom_expr = context->atom_expr();
      std::string s =
          "std::pow(" + atom_expr->getText() + ", " + factor->getText() + ")";
      replaceAll(result.first, context->getText(), s);
      return 0;
    }
    return visitChildren(context);
  }

499
500
  virtual antlrcpp::Any visitIf_stmt(
      pyxasmParser::If_stmtContext *ctx) override {
501
502
503
    // Only support single clause atm
    if (ctx->test().size() == 1) {
      std::stringstream ss;
504
505
506
507
      ss << "if ("
         << replacePythonConstants(
                replaceMeasureAssignment(ctx->test(0)->getText()))
         << ") {\n";
508
509
510
511
512
513
      result.first = ss.str();
      return 0;
    }
    return visitChildren(ctx);
  }

Nguyen, Thien Minh's avatar
Nguyen, Thien Minh committed
514
515
516
517
518
519
520
521
  virtual antlrcpp::Any
  visitWhile_stmt(pyxasmParser::While_stmtContext *ctx) override {
    std::stringstream ss;
    ss << "while (" << ctx->test()->getText() << ") {\n";
    result.first = ss.str();
    return 0;
  }

522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
  virtual antlrcpp::Any visitTestlist_star_expr(
      pyxasmParser::Testlist_star_exprContext *context) override {
    // std::cout << "Testlist_star_exprContext:" << context->getText() << "\n";
    const auto var_name = context->getText();
    if (xacc::container::contains(declared_var_names, var_name)) {
      sub_node_translation << var_name << " ";
      return 0;
    }
    return visitChildren(context);
  }

  virtual antlrcpp::Any
  visitAugassign(pyxasmParser::AugassignContext *context) override {
    // std::cout << "Augassign:" << context->getText() << "\n";
    sub_node_translation << context->getText() << " ";
    return 0;
  }

  virtual antlrcpp::Any
  visitTestlist(pyxasmParser::TestlistContext *context) override {
    // std::cout << "visitTestlist:" << context->getText() << "\n";
    sub_node_translation << context->getText() << " ";
    return 0;
  }

547
 private:
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
  // Replaces common Python constants, e.g. 'math.pi' or 'numpy.pi'.
  // Note: the library names have been resolved to their original names.
  std::string replacePythonConstants(const std::string &in_pyExpr) const {
    // List of all keywords to be replaced
    const std::map<std::string, std::string> REPLACE_MAP{{"math.pi", "M_PI"},
                                                         {"numpy.pi", "M_PI"}};
    std::string newSrc = in_pyExpr;
    for (const auto &[key, value] : REPLACE_MAP) {
      const auto pos = newSrc.find(key);
      if (pos != std::string::npos) {
        newSrc.replace(pos, key.length(), value);
      }
    }
    return newSrc;
  }
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591

  // Assignment of Measure results -> variable or in if conditional statements
  std::string replaceMeasureAssignment(const std::string &in_expr) const {
    if (in_expr.find("Measure") != std::string::npos) {
      // Found measure in an if statement instruction.
      const auto replaceMeasureInst = [](std::string &s,
                                         const std::string &search,
                                         const std::string &replace) {
        for (size_t pos = 0;; pos += replace.length()) {
          pos = s.find(search, pos);
          if (pos == std::string::npos) {
            break;
          }
          if (!isspace(s[pos + search.length()]) &&
              (s[pos + search.length()] != '(')) {
            continue;
          }
          s.erase(pos, search.length());
          s.insert(pos, replace);
        }
      };

      std::string result = in_expr;
      replaceMeasureInst(result, "Measure", "quantum::mz");
      return result;
    } else {
      return in_expr;
    }
  }
592
593
594
595
596
597
598
599
600
601
602
603

  // A helper to rewrite function argument by traversing the node to see
  // if there is a potential rewrite.
  // Use case: inline expressions
  // e.g. X(q[0:3])
  // slicing of the qreg 'q' then call the broadcast X op.
  // i.e., we need to rewrite the arg to q.extract_range({0, 3}).
  std::string
  rewriteFunctionArgument(pyxasmParser::ArgumentContext &in_argContext) {
    // Strategy: try to traverse the argument context to see if there is a
    // possible rewrite; i.e. it may be another atom_expression that we have a
    // handler for. Otherwise, use the text as is.
604
605
606
607
    // We need this flag to prevent parsing quantum instructions as sub-expressions.
    // e.g. QCOR operators (X, Y, Z) in an observable definition shouldn't be 
    // processed as instructions.
    is_processing_sub_expr = true;
608
609
610
611
612
613
614
615
616
617
618
619
620
621
    // clear the sub_node_translation
    sub_node_translation.str(std::string());

    // visit arg sub-node:
    visitChildren(&in_argContext);

    // Check if there is a rewrite:
    if (!sub_node_translation.str().empty()) {
      // Update RHS
      return sub_node_translation.str();
    }
    // Returns the string as is
    return in_argContext.getText();
  }
622
};