GateFunction.cpp 17.9 KB
Newer Older
1
#include "GateFunction.hpp"
2
#include <algorithm>
3
#include <ctype.h>
4
#include <memory>
5
#include <string>
6
#include "Function.hpp"
7
8
#include "InstructionIterator.hpp"
#include "IRToGraphVisitor.hpp"
9
10
#include "IRGenerator.hpp"
#include "xacc_service.hpp"
11

12
13
#include "Graph.hpp"

14
15
16
17
18
19
20
#include "JsonVisitor.hpp"

#define RAPIDJSON_HAS_STDSTRING 1
#include "rapidjson/document.h"
#include "rapidjson/prettywriter.h"
using namespace rapidjson;

21
22
23
24
namespace xacc {
namespace quantum {

void GateFunction::mapBits(std::vector<int> bitMap) {
25
26
27
  for (auto i : instructions) {
    i->mapBits(bitMap);
  }
28
29
}

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
void GateFunction::persist(std::ostream &outStream) {
  JsonVisitor<PrettyWriter<StringBuffer>, StringBuffer> visitor(
      shared_from_this());
  outStream << visitor.write();
}

// {
//     "kernels": [
//         {
//             "function": "foo",
//             "instructions": [
//                 {
//                     "gate": "H",
//                     "enabled": true,
//                     "qubits": [
//                         1
//                     ]
//                 },
//                 {
//                     "gate": "CNOT",
//                     "enabled": true,
//                     "qubits": [
//                         0,
//                         1
//                     ]
//                 }
//             ]
//         }
//     ]
// }
void GateFunction::load(std::istream &inStream) {

  std::vector<std::string> irGeneratorNames;
  auto irgens = xacc::getRegisteredIds<xacc::IRGenerator>();
  for (auto &irg : irgens) {
    irGeneratorNames.push_back(irg);
  }

  auto provider = xacc::getService<IRProvider>("gate");
  std::string json(std::istreambuf_iterator<char>(inStream), {});
//   std::cout << "JSON: " << json << "\n";

  Document doc;
  doc.Parse(json);

  auto &kernel = doc["kernels"].GetArray()[0];
  functionName = kernel["function"].GetString();
  auto instructionsArray = kernel["instructions"].GetArray();

  for (int i = 0; i < instructionsArray.Size(); i++) {
    auto &inst = instructionsArray[i];
    auto gname = inst["gate"].GetString();

    bool isAnIRG = false;
    if (std::find(irGeneratorNames.begin(), irGeneratorNames.end(), gname) != irGeneratorNames.end()) {
        // this is an IRG
        isAnIRG = true;
    }

    std::vector<int> qbits;
    auto bitsArray = inst["qubits"].GetArray();
    for (int k = 0; k < bitsArray.Size(); k++) {
      qbits.push_back(bitsArray[k].GetInt());
    }

    std::vector<InstructionParameter> local_parameters;
    auto &paramsArray = inst["parameters"];
    for (int k = 0; k < paramsArray.Size(); k++) {
      auto &value = paramsArray[k];
      if (value.IsInt()) {
        local_parameters.push_back(InstructionParameter(value.GetInt()));
      } else if (value.IsDouble()) {
        local_parameters.push_back(InstructionParameter(value.GetDouble()));
      } else {
        local_parameters.push_back(InstructionParameter(value.GetString()));
      }
    }

    std::shared_ptr<Instruction> instToAdd;
    if (!isAnIRG) {
     instToAdd =
        provider->createInstruction(gname, qbits, local_parameters);
    } else {
        instToAdd = xacc::getService<IRGenerator>(gname);
    }

    auto &optionsObj = inst["options"];
    for (auto itr = optionsObj.MemberBegin(); itr != optionsObj.MemberEnd();
         ++itr) {
      auto &value = optionsObj[itr->name.GetString()];

      if (value.IsInt()) {
        instToAdd->setOption(itr->name.GetString(),
                             InstructionParameter(value.GetInt()));
      } else if (value.IsDouble()) {
        instToAdd->setOption(itr->name.GetString(),
                             InstructionParameter(value.GetDouble()));
      } else {
        instToAdd->setOption(itr->name.GetString(),
                             InstructionParameter(value.GetString()));
      }
    }
    if (!inst["enabled"].GetBool()) {
      instToAdd->disable();
    }

    addInstruction(instToAdd);
  }
}

140
141
void GateFunction::expandIRGenerators(
    std::map<std::string, InstructionParameter> irGenMap) {
142
    std::list<InstPtr> newinsts;
143
144
145
146
147
  for (int idx = 0; idx < nInstructions(); idx++) {
    auto inst = getInstruction(idx);
    auto irg = std::dynamic_pointer_cast<IRGenerator>(inst);
    if (irg) {
      auto evaluated = irg->generate(irGenMap);
148
149
150
151
152
153
    //   replaceInstruction(idx, evaluated);
      for (auto i : evaluated->getInstructions()) {
          newinsts.push_back(i);
      }
    } else {
        newinsts.push_back(inst);
154
155
    }
  }
156
157
158
159
160
161

  instructions.clear();

  for (auto i : newinsts) {
      addInstruction(i);
  }
162
163
164
165
166
167
168
}

bool GateFunction::hasIRGenerators() {
  for (int idx = 0; idx < nInstructions(); idx++) {
    auto inst = getInstruction(idx);
    auto irg = std::dynamic_pointer_cast<IRGenerator>(inst);
    if (irg) {
169
      return true;
170
171
172
173
174
    }
  }
  return false;
}

175
// Number of instructions currently stored in this function.
const int GateFunction::nInstructions() {
  return static_cast<int>(instructions.size());
}
176

177
// Return a copy of the instruction list.
std::list<InstPtr> GateFunction::getInstructions() {
  std::list<InstPtr> snapshot(instructions);
  return snapshot;
}
178

179
// Name of this function.
const std::string GateFunction::name() const {
  const std::string &label = functionName;
  return label;
}
180

181
// Always returns an empty vector.
const std::vector<int> GateFunction::bits() {
  return std::vector<int>();
}
182

183
184
185
186
187
188
189
190
191
// lambda functions for determining if an InstructionParameter is a
// number/double or a variable
// True when std::stoi accepts s, i.e. s starts (after optional whitespace)
// with an integer that fits in an int. Trailing characters are not
// inspected, so "3abc" counts as an int. Any exception thrown by std::stoi
// yields false.
auto isInt = [](const std::string &s) {
  try {
    (void)std::stoi(s);
  } catch (...) {
    return false;
  }
  return true;
};
193
194
195
196
197
198
199
// True when std::stod accepts s, i.e. s starts (after optional whitespace)
// with a floating-point literal. Trailing characters are not inspected, so
// "2x" counts as a double. Any exception thrown by std::stod yields false.
auto isDouble = [](const std::string &s) {
  try {
    (void)std::stod(s);
  } catch (...) {
    return false;
  }
  return true;
};
201
202
203
204
205
206
// True when s is accepted by either isInt or isDouble.
auto isNumber = [](std::string s) { return isInt(s) || isDouble(s); };
208
209
210
211
212
// Lambda function for extracting the stripped (no mathematical operators or
// numbers/doubles) parameter
auto splitParameter = [](InstructionParameter instParam) {
  std::vector<std::string> split;
  InstructionParameter rawParam;
213
  auto paramStr = mpark::get<std::string>(instParam);
214
215
216
217
218
219
220
221
  std::replace(mpark::get<std::string>(instParam).begin(),
               mpark::get<std::string>(instParam).end(), '-', ' ');
  std::replace(mpark::get<std::string>(instParam).begin(),
               mpark::get<std::string>(instParam).end(), '+', ' ');
  std::replace(mpark::get<std::string>(instParam).begin(),
               mpark::get<std::string>(instParam).end(), '*', ' ');
  std::replace(mpark::get<std::string>(instParam).begin(),
               mpark::get<std::string>(instParam).end(), '/', ' ');
222

223
  auto instParamStr = instParam.as<std::string>();
224
  split = xacc::split(instParamStr, ' ');
225
226
227
228
229
230
  for (auto s : split) {
    if (!isNumber(s) && !s.empty()) {
      rawParam = s;
    }
  }
  return rawParam;
231
232
};

233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
const int GateFunction::nLogicalBits() {
  std::set<int> local_bits;
  xacc::InstructionIterator it(shared_from_this());
  while (it.hasNext()) {
    auto nextInst = it.next();
    if (nextInst->isEnabled()) {
      for (auto &i : nextInst->bits()) {
        local_bits.insert(i);
      }
    }
  }
  return local_bits.size();
}

const int GateFunction::nPhysicalBits() {
  int maxBitIdx = 0;
  xacc::InstructionIterator it(shared_from_this());
  while (it.hasNext()) {
    auto nextInst = it.next();
    if (nextInst->isEnabled()) {
      for (auto &i : nextInst->bits()) {
        if (maxBitIdx < i) {
          maxBitIdx = i;
        }
      }
    }
  }

  maxBitIdx++;
  return maxBitIdx;
}
264
void GateFunction::removeInstruction(const int idx) {
265
266
267
  auto instruction = getInstruction(idx);
  // Check to see if instruction being removed is parameterized
  if (instruction->isParameterized() &&
268
      instruction->getParameter(0).isVariable()) {
269
270
271
272
273
274
275
276
277
278
279
280
    // Get InstructionParameter of instruction being removed
    bool dupParam = false;
    // strip the parameter of mathematical operators and numbers/doubles
    InstructionParameter strippedParam =
        splitParameter(instruction->getParameter(0));
    // check if the InstructionParameter is a duplicate
    for (auto i : instructions) {
      if (i->isParameterized() &&
          strippedParam == splitParameter(i->getParameter(0)) &&
          instruction != i) {
        dupParam = true;
      }
281
    }
282
283
284
285
286
287
288
289
290
    // If there are no parameters shared, then remove the parameter
    if (!dupParam) {
      parameters.erase(
          std::remove(parameters.begin(), parameters.end(), strippedParam),
          parameters.end());
    }
  }
  // Remove instruction
  instructions.remove(getInstruction(idx));
291
}
292
293

// Append an instruction to this function.
//
// When the instruction is parameterized with at most one parameter and that
// parameter is a variable, the stripped variable name (see splitParameter)
// is added to this function's parameter list unless a parameter with the
// same string is already present.
void GateFunction::addInstruction(InstPtr instruction) {
  const bool hasSingleParameter =
      instruction->isParameterized() && instruction->nParameters() <= 1;

  if (hasSingleParameter) {
    xacc::InstructionParameter candidate = instruction->getParameter(0);
    if (candidate.isVariable()) {
      // Variable name without operators and numeric literals.
      InstructionParameter bareName = splitParameter(candidate);

      // Every existing parameter is compared, as strings, to the new name.
      const auto matches = std::count_if(
          parameters.begin(), parameters.end(),
          [&bareName](InstructionParameter known) {
            return known.as<std::string>() == bareName.as<std::string>();
          });

      if (matches == 0) {
        parameters.push_back(bareName);
      }
    }
  }

  instructions.push_back(instruction);
}
319

320
void GateFunction::replaceInstruction(const int idx, InstPtr replacingInst) {
321
322
323
324
  auto currentInst = getInstruction(idx);
  // Check if the GateInstruction being replaced is parameterized and if
  // parameter is a string
  if (currentInst->isParameterized() &&
325
      currentInst->getParameter(0).isVariable()) {
326
327
328
329
330
331
332
333
334
335
336
337
338
339
    // strip the current InstructionParameter of mathematical operators and
    // numbers
    InstructionParameter strippedCurrent =
        splitParameter(currentInst->getParameter(0));
    bool dupParam = false;
    for (auto i : instructions) {
      // see if current instruction has a duplicate parameter
      if (strippedCurrent == splitParameter(i->getParameter(0)) &&
          currentInst != i) {
        dupParam = true;
      }
    }
    // Check if new instruction is parameterized and if parameter is a string
    if (replacingInst->isParameterized() &&
340
        replacingInst->getParameter(0).isVariable()) {
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
      InstructionParameter strippedNew =
          splitParameter(replacingInst->getParameter(0));
      // Check if current parameter is a duplicate parameter
      bool newDupParam = false;
      for (auto i : instructions) {
        // see if new param has a duplicate parameter
        if (strippedNew == splitParameter(i->getParameter(0))) {
          newDupParam = true;
        }
      }
      // Check if old GateInstruction parameter is different than new
      // GateInstruction parameter
      if (strippedCurrent != strippedNew) {
        if (!dupParam) {
          if (!newDupParam) {
            // if the current GateInstruction and the new GateInstruction both
            // do not have a parameter already in this GateFunction -> replace
            // parameter
            std::replace(parameters.begin(), parameters.end(), strippedCurrent,
                         strippedNew);
          } else {
            // If the current GateInstruction is not a duplicate but the new one
            // is  -> remove old parameter
            parameters.erase(std::remove(parameters.begin(), parameters.end(),
                                         strippedCurrent));
          }
        } else {
          // if the current GateInstruction is a duplicate but the new one is
          // not -> add new parameter
          if (!newDupParam) {
            parameters.push_back(strippedNew);
          }
        }
      }
      // if the current GateInstruction parameter is parameterized but the new
      // one is not, check if the parameter is a duplicate and erase (or not)
    } else if (!dupParam) {
      parameters.erase(
          std::remove(parameters.begin(), parameters.end(), strippedCurrent));
    }
    // if the current GateInstruction is not parameterized and the new
    // GateInstruction is, try to add parameter
  } else if (replacingInst->isParameterized() &&
384
             replacingInst->getParameter(0).isVariable()) {
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
    InstructionParameter strippedNew =
        splitParameter(replacingInst->getParameter(0));
    // Check if new parameter is a duplicate parameter
    bool newDupParam = false;
    for (auto i : instructions) {
      if (i->isParameterized() &&
          strippedNew == splitParameter(i->getParameter(0))) {
        newDupParam = true;
      }
    }
    // if new parameter is not a duplicate, add it to the parameters list
    if (!newDupParam) {
      parameters.push_back(strippedNew);
    }
  }
  // Replace old GateInstruction with new GateInstruction
  std::replace(instructions.begin(), instructions.end(), getInstruction(idx),
               replacingInst);
403
404
405
}

void GateFunction::insertInstruction(const int idx, InstPtr newInst) {
406
  // Check if new GateInstruction is parameterized with 1 parameter
407
  if (newInst->isParameterized() && !newInst->isComposite()) {
408
409
    xacc::InstructionParameter param = newInst->getParameter(0);
    // Check if new parameter is a string
410
    if (param.isVariable()) {
411
412
413
414
415
416
      // If new parameter is not already in parameter vector -> add parameter to
      // GateFunction
      bool dupParam = false;
      // strip the parameter of mathematical operators and numbers/doubles
      InstructionParameter strippedParam = splitParameter(param);
      for (auto p : parameters) {
417
418
        if (mpark::get<std::string>(p) ==
            mpark::get<std::string>(strippedParam)) {
419
          dupParam = true;
420
        }
421
422
423
424
425
426
      }
      // If new parameter is not already in parameter vector -> add parameter to
      // GateFunction
      if (!dupParam) {
        parameters.push_back(param);
      }
427
    }
428
429
430
431
  }
  // Insert new GateInstruction to instructions vector
  auto iter = std::next(instructions.begin(), idx);
  instructions.insert(iter, newInst);
432
433
434
}

// Return the instruction at position idx. When idx does not satisfy
// instructions.size() > idx the problem is reported through xacc::error and
// an empty InstPtr is the fallback return value.
InstPtr GateFunction::getInstruction(const int idx) {
  if (instructions.size() > idx) {
    return *std::next(instructions.begin(), idx);
  }

  xacc::error("GateFunction getInstruction invalid instruction index - " +
              std::to_string(idx) + ".");
  return InstPtr();
}

445
446
447
448
// Overwrite the parameter at position idx with p.
// An index outside [0, nParameters) is logged as an error and the parameter
// list is left untouched.
void GateFunction::setParameter(const int idx, InstructionParameter &p) {
  // Test the sign explicitly: in a mixed signed/unsigned comparison
  // idx == -1 would pass an "idx + 1 > size()" check.
  if (idx < 0 || static_cast<std::size_t>(idx) >= parameters.size()) {
    XACCLogger::instance()->error("Invalid Parameter requested.");
    return;
  }

  parameters[idx] = p;
}

// Return a copy of the parameter at position idx.
// An index outside [0, nParameters) is logged as an error and a
// default-constructed InstructionParameter is returned instead of reading
// past the end of the parameter list.
InstructionParameter GateFunction::getParameter(const int idx) const {
  // Test the sign explicitly: in a mixed signed/unsigned comparison
  // idx == -1 would pass an "idx + 1 > size()" check.
  if (idx < 0 || static_cast<std::size_t>(idx) >= parameters.size()) {
    XACCLogger::instance()->error("Invalid Parameter requested.");
    return InstructionParameter();
  }

  return parameters[idx];
}

// Append instParam to this function's parameter list; no duplicate check is
// performed here.
void GateFunction::addParameter(InstructionParameter instParam) {
  parameters.insert(parameters.end(), instParam);
}

std::vector<InstructionParameter> GateFunction::getParameters() {
466
  return parameters;
467
468
}

469
// A function is parameterized when it has at least one parameter.
bool GateFunction::isParameterized() {
  const int count = nParameters();
  return count > 0;
}
470

471
// Number of entries in this function's parameter list.
const int GateFunction::nParameters() {
  return static_cast<int>(parameters.size());
}
472

473
474
475
// Concatenate the string form of every enabled instruction, one per line
// (each followed by "\n"), passing bufferVarName through to each
// instruction's toString().
const std::string GateFunction::toString(const std::string &bufferVarName) {
  std::string listing = "";
  for (auto inst : instructions) {
    if (!inst->isEnabled()) {
      continue;
    }
    listing += inst->toString(bufferVarName) + "\n";
  }
  return listing;
}

483
484
// Evaluate this function at the given parameter values.
//
// params must contain one value per function parameter, in parameter order;
// a size mismatch is reported through xacc::error. Returns a new
// GateFunction named "evaled_" + name() in which
//   - composite instructions are evaluated recursively with the same values,
//   - instructions whose first parameter is a variable are re-created from
//     their name and bits, with parameter 0 set to the evaluated expression,
//   - all other instructions are shared unchanged with this function.
// NOTE(review): for re-created instructions only parameter 0 is set here;
// whether other parameters or options need carrying over should be
// confirmed against the instruction types in use.
std::shared_ptr<Function>
GateFunction::operator()(const std::vector<double> &params) {
  if (params.size() != nParameters()) {
    xacc::error("Invalid GateFunction evaluation: number "
                "of parameters don't match. " +
                std::to_string(params.size()) + ", " +
                std::to_string(nParameters()));
  }

  // Mutable copy: add_variable takes the value as a non-const lvalue
  // (presumably bound by reference - the copy must outlive the expressions).
  std::vector<double> p = params;

  symbol_table_t symbol_table;
  symbol_table.add_constants();
  for (std::size_t i = 0; i < p.size(); i++) {
    auto var = getParameter(static_cast<int>(i)).as<std::string>();
    symbol_table.add_variable(var, p[i]);
  }

  // Compile and evaluate the string expression held by a parameter against
  // the symbol table built above. A failed compilation is reported instead
  // of silently evaluating an uncompiled expression.
  auto compileExpression = [&](InstructionParameter &param) -> double {
    auto expression = mpark::get<std::string>(param);
    expression_t expr;
    expr.register_symbol_table(symbol_table);
    parser_t parser;
    if (!parser.compile(expression, expr)) {
      xacc::error("GateFunction could not compile parameter expression '" +
                  expression + "': " + parser.error());
    }
    return expr.value();
  };

  auto gateRegistry = xacc::getService<IRProvider>("gate");
  auto evaluatedFunction = std::make_shared<GateFunction>("evaled_" + name());

  // Walk the IR Tree, handle functions and instructions differently
  for (auto inst : getInstructions()) {
    if (inst->isComposite()) {
      // If a Function, call this method recursively
      auto evaled = std::dynamic_pointer_cast<Function>(inst)->operator()(p);
      evaluatedFunction->addInstruction(evaled);
    } else if (inst->isParameterized() && inst->getParameter(0).isVariable()) {
      // Concrete instruction with a string parameter: evaluate it.
      InstructionParameter variableParam = inst->getParameter(0);
      InstructionParameter pnew(compileExpression(variableParam));
      auto updatedInst =
          gateRegistry->createInstruction(inst->name(), inst->bits());
      updatedInst->setParameter(0, pnew);
      evaluatedFunction->addInstruction(updatedInst);
    } else {
      evaluatedFunction->addInstruction(inst);
    }
  }
  return evaluatedFunction;
}
541
542
const int GateFunction::depth() { return toGraph()->depth() - 2; }

543
544
545
546
547
const std::string GateFunction::persistGraph() {
  std::stringstream s;
  toGraph()->write(s);
  return s.str();
}
548

549
std::shared_ptr<Graph> GateFunction::toGraph() {
550

551
552
553
554
555
556
557
558
559
  // Compute number of qubits
  int maxBit = 0;
  InstructionIterator it(shared_from_this());
  while (it.hasNext()) {
    auto inst = it.next();
    for (auto &b : inst->bits()) {
      if (b > maxBit) {
        maxBit = b;
      }
560
    }
561
562
563
564
565
566
567
568
  }

  auto visitor = std::make_shared<IRToGraphVisitor>(maxBit + 1);
  InstructionIterator it2(shared_from_this());
  while (it2.hasNext()) {
    auto inst = it2.next();
    if (inst->isEnabled()) {
      inst->accept(visitor);
569
    }
570
  }
571

572
  return visitor->getGraph();
573
574
}

575
576
} // namespace quantum
} // namespace xacc