// quantum_types_handler.cpp
#include "qasm3_visitor.hpp"

namespace qcor {

/// Lowers an OpenQASM 3 quantum declaration to `quantum.qalloc` operations and
/// registers every declared name in the symbol table.
///
/// Handles, for example:
///   qubit q;
///   qubit q[size];
///   qubit q[size], r[size2], ...;
///   qreg q[size];
///
/// Grammar:
///   quantumDeclaration
///       : quantumType indexIdentifierList
///       ;
///   indexIdentifier
///       : Identifier rangeDefinition
///       | Identifier ( LBRACKET expressionList RBRACKET )?
///       | indexIdentifier '||' indexIdentifier
///       ;
///   indexIdentifierList
///       : ( indexIdentifier COMMA )* indexIdentifier
///       ;
///
/// The register size comes from the first bracketed expression and is
/// resolved in this order:
///   (1) `symbol_table.try_evaluate_constant_integer_expression` on its text;
///   (2) emitting the expression and checking for an integer
///       `mlir::ConstantOp`;
///   (3) `symbol_table.get_global_constant<int64_t>` on its text.
/// Without brackets the size is 1.
///
/// A `qubit` declaration of size 1 is stored as a single extracted qubit. The
/// backing size-1 array is also stored, under
/// `__qcor__mlir__single_qubit_register_<name>`, so that it can be
/// deallocated later.
///
/// @return 0; the visitor result is not used for declarations.
antlrcpp::Any qasm3_visitor::visitQuantumDeclaration(
    qasm3Parser::QuantumDeclarationContext* context) {
  auto location = get_location(builder, file_name, context);

  // Loop-invariant: every name in the list shares the same quantum type.
  const bool is_qubit_type = context->quantumType()->getText() == "qubit";
  auto integer_type = builder.getI64Type();

  for (auto idx_identifier :
       context->indexIdentifierList()->indexIdentifier()) {
    auto identifier = idx_identifier->Identifier();
    if (!identifier) {
      // The `indexIdentifier '||' indexIdentifier` alternative has no direct
      // Identifier token. Dereferencing it would crash.
      printErrorMessage(
          "Concatenated ('||') quantum register declarations are not "
          "supported.");
      continue;
    }
    auto var_name = identifier->getText();

    // Declared register size; 1 when no brackets are given.
    // NOTE(review): the `Identifier rangeDefinition` alternative also ends up
    // with size 1 here. Confirm that is intended.
    std::size_t size = 1;
    if (auto exp_list = idx_identifier->expressionList()) {
      auto size_expr = exp_list->expression(0);
      auto size_text = size_expr->getText();

      auto opt_size =
          symbol_table.try_evaluate_constant_integer_expression(size_text);
      if (opt_size.has_value()) {
        size = opt_size.value();
      } else {
        // Not foldable from its text: emit it and check whether the
        // generated value is a constant.
        qasm3_expression_generator exp_generator(builder, symbol_table,
                                                 file_name);
        exp_generator.visit(size_expr);
        auto arg = exp_generator.current_value;

        if (auto constantOp = arg.getDefiningOp<mlir::ConstantOp>()) {
          if (auto int_attr =
                  constantOp.getValue().dyn_cast<mlir::IntegerAttr>()) {
            size = int_attr.getInt();
          } else {
            printErrorMessage(
                "This variable qubit size must be a constant integer.");
          }
        } else {
          size = symbol_table.get_global_constant<int64_t>(size_text);
        }
      }
    }

    auto integer_attr = mlir::IntegerAttr::get(integer_type, size);
    auto str_attr = builder.getStringAttr(var_name);
    mlir::Value allocation = builder.create<mlir::quantum::QallocOp>(
        location, array_type, integer_attr, str_attr);

    if (is_qubit_type && size == 1) {
      // A single qubit is not stored as an array in the symbol table. Instead,
      // extract element 0 and store that.
      // NOTE(review): this also applies to an explicit `qubit q[1];`. Confirm
      // that later `q[0]` indexing handles a bare qubit symbol.
      mlir::Value pos = get_or_create_constant_integer_value(
          0, location, builder.getIntegerType(64), symbol_table, builder);

      // Also keep the backing array so that it can be deallocated later.
      symbol_table.add_symbol("__qcor__mlir__single_qubit_register_" + var_name,
                              allocation);

      allocation = builder.create<mlir::quantum::ExtractQubitOp>(
          location, qubit_type, allocation, pos);
    }

    symbol_table.add_symbol(var_name, allocation);
  }
  return 0;
}

/// Lowers an OpenQASM 3 `gate` definition to an `mlir::FuncOp`. The function
/// is appended to `m_module` and recorded with
/// `symbol_table.add_seen_function`.
///
/// Grammar:
///   quantumGateDefinition
///       : 'gate' quantumGateSignature quantumBlock
///       ;
///   quantumGateSignature
///       : ( Identifier | 'CX' | 'U') ( LPAREN identifierList? RPAREN )?
///       identifierList
///       ;
///   quantumBlock
///       : LBRACE ( quantumStatement | quantumLoop )* RBRACE
///       ;
///
/// Arguments:
///   - Classical parameters become `f64` arguments and come first.
///   - Qubit names become `qubit_type` arguments.
///
/// Results: the function's result types equal its argument types.
///   - Parameters are returned unchanged.
///   - Each qubit argument is replaced by the last value reached by following
///     its use chain through the gate body.
///
/// @return 0.
antlrcpp::Any qasm3_visitor::visitQuantumGateDefinition(
    qasm3Parser::QuantumGateDefinitionContext* context) {
  auto signature = context->quantumGateSignature();

  // The gate name is an Identifier or one of the keywords 'CX' / 'U'. For the
  // keywords Identifier() is null, so fall back to the signature's first
  // token, which is the name token in every alternative.
  auto gate_call_name = signature->Identifier()
                            ? signature->Identifier()->getText()
                            : signature->getStart()->getText();

  // Two identifier lists mean `name(params) qubits`; one means `name qubits`.
  const bool has_classical_params = signature->identifierList().size() == 2;

  std::vector<std::string> arg_names;
  std::vector<mlir::Type> func_args;
  std::size_t qubit_list_idx = 0;
  if (has_classical_params) {
    qubit_list_idx = 1;
    for (auto ident_expr : signature->identifierList(0)->Identifier()) {
      func_args.push_back(builder.getF64Type());
      arg_names.push_back(ident_expr->getText());
    }
  }
  for (auto ident_expr :
       signature->identifierList(qubit_list_idx)->Identifier()) {
    func_args.push_back(qubit_type);
    arg_names.push_back(ident_expr->getText());
  }

  auto main_block = builder.saveInsertionPoint();

  // Result types mirror argument types (see the result construction below).
  auto func_type = builder.getFunctionType(func_args, func_args);
  auto proto =
      mlir::FuncOp::create(builder.getUnknownLoc(), gate_call_name, func_type);
  mlir::FuncOp function(proto);

  auto& entryBlock = *function.addEntryBlock();
  builder.setInsertionPointToStart(&entryBlock);

  symbol_table.enter_new_scope();

  auto arguments = entryBlock.getArguments();
  for (std::size_t i = 0; i < arguments.size(); i++) {
    symbol_table.add_symbol(arg_names[i], arguments[i]);
  }

  visitChildren(context->quantumBlock());

  // Follows `value` through successive users until the last value in the
  // chain. Only the first user of each value is followed. The result index is
  // assumed to equal the operand index at which the value is consumed.
  // NOTE(review): a user with fewer results than that index would make
  // getResult(idx) invalid. Confirm that every op emitted in a gate body
  // satisfies this.
  auto trace_final_value = [](mlir::Value value) {
    for (auto users = value.getUsers(); !users.empty();
         users = value.getUsers()) {
      auto user = *users.begin();

      // Which operand of `user` is `value`?
      int idx = -1;
      for (auto operand : user->getOperands()) {
        idx++;
        if (operand == value) {
          break;
        }
      }

      value = user->getResult(idx);
    }
    return value;
  };

  std::vector<mlir::Value> result_qubit_vals;
  result_qubit_vals.reserve(arguments.size());
  for (auto arg : entryBlock.getArguments()) {
    if (arg.getType().isF64()) {
      // Gate parameter: returned as-is, no use-chain traversal.
      result_qubit_vals.push_back(arg);
    } else {
      result_qubit_vals.push_back(trace_final_value(arg));
    }
  }

  builder.create<mlir::ReturnOp>(builder.getUnknownLoc(),
                                 llvm::makeArrayRef(result_qubit_vals));

  m_module.push_back(function);

  builder.restoreInsertionPoint(main_block);

  symbol_table.exit_scope();

  symbol_table.add_seen_function(gate_call_name, function);

  return 0;
}
}  // namespace qcor