ScaffoldCompiler.cpp 10.7 KB
Newer Older
Mccaskey, Alex's avatar
Mccaskey, Alex committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
/***********************************************************************************
 * Copyright (c) 2016, UT-Battelle
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *   * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of the xacc nor the
 *     names of its contributors may be used to endorse or promote products
 *     derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * Contributors:
 *   Initial API and implementation - Alex McCaskey
 *
 **********************************************************************************/
31
32
#include "ScaffoldCompiler.hpp"

#include <cstddef>
#include <map>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>
33

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

namespace xacc {

namespace quantum {

void ScaffoldCompiler::modifySource() {

	// Here we assume we've been given just
	// the body of the quantum code, as part
	// of an xacc __qpu__ kernel function.

	// First off, replace __qpu__ with 'module '
	kernelSource.erase(kernelSource.find("__qpu__"), 7);
	kernelSource = std::string("module ") + kernelSource;

49
	std::string qubitAllocationLine, cbitAllocationLine, cbitVarName, qbitVarName;
Mccaskey, Alex's avatar
Mccaskey, Alex committed
50
	std::map<int, int> cbitToQubit;
51

52
53
54
55
56
57
58
59
60
61
62
	if (typeToVarKernelArgs.find("qbit") != typeToVarKernelArgs.end()) {
		qubitAllocationLine = "qbit " + typeToVarKernelArgs["qbit"] + ";\n";
		qbitVarName = typeToVarKernelArgs["qbit"].substr(0, typeToVarKernelArgs["qbit"].find_first_of("["));
	} else {
		std::regex qbitName("qbit\\s.*");
		qubitAllocationLine = (*std::sregex_iterator(kernelSource.begin(),
				kernelSource.end(), qbitName)).str() + "\n";
		std::vector<std::string> splitQbit;
		boost::split(splitQbit, qubitAllocationLine, boost::is_any_of(" "));
		qbitVarName = splitQbit[1].substr(0, splitQbit[1].find_first_of("["));
	}
63

64
    // Create Cbit to Qbit mapping
Mccaskey, Alex's avatar
Mccaskey, Alex committed
65
66
67
68
69
70
71
72
73
74
	std::regex cbitName("cbit\\s.*");
	auto it = std::sregex_iterator(kernelSource.begin(), kernelSource.end(),
			cbitName);
	if (it != std::sregex_iterator()) {
		cbitAllocationLine = (*std::sregex_iterator(kernelSource.begin(),
				kernelSource.end(), cbitName)).str() + "\n";
		std::vector<std::string> splitCbit;
		boost::split(splitCbit, cbitAllocationLine, boost::is_any_of(" "));
		cbitVarName = splitCbit[1].substr(0,
				splitCbit[1].find_first_of("["));
75

Mccaskey, Alex's avatar
Mccaskey, Alex committed
76
77
78
79
80
81
		std::regex measurements(".*Meas.*");
		for (auto i = std::sregex_iterator(kernelSource.begin(),
				kernelSource.end(), measurements); i != std::sregex_iterator();
				++i) {
			auto measurement = (*i).str();
			boost::trim(measurement);
82

Mccaskey, Alex's avatar
Mccaskey, Alex committed
83
84
85
86
87
88
89
90
			boost::erase_all(measurement, "MeasZ");
			boost::erase_all(measurement, "(");
			boost::erase_all(measurement, ")");
			boost::erase_all(measurement, cbitVarName);
			boost::erase_all(measurement, qbitVarName);
			// Should now have [#] = [#]
			boost::erase_all(measurement, "[");
			boost::erase_all(measurement, "]");
91

Mccaskey, Alex's avatar
Mccaskey, Alex committed
92
93
94
95
96
97
			std::vector<std::string> splitVec;
			boost::split(splitVec, measurement, boost::is_any_of("="));
			auto cbit = splitVec[0];
			auto qbit = splitVec[1];
			boost::trim(cbit);
			boost::trim(qbit);
98

Mccaskey, Alex's avatar
Mccaskey, Alex committed
99
100
101
102
			cbitToQubit.insert(
					std::make_pair(std::stoi(cbit), std::stoi(qbit)));
		}
	}
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122

	// conditional on measurements
	// FIXME FOR NOW WE ONLY ACCEPT format
	// 'if (creg[0] == 1) GATEOP'
	int counter = 1;
	std::vector<std::string> ifLines;
	std::regex ifstmts("if\\s?\\(\\w+\\[\\w+\\]\\s?=.*\\s?\\)\\s?");
	for (auto i = std::sregex_iterator(kernelSource.begin(), kernelSource.end(),
			ifstmts); i != std::sregex_iterator(); ++i) {
		std::vector<std::string> splitVec;
		std::string ifLine = (*i).str();
		boost::trim(ifLine);
		boost::split(splitVec, ifLine, boost::is_any_of(" "));
		conditionalCodeSegments.push_back("module foo" + std::to_string(counter) + "() {\n"
				+ qubitAllocationLine + "   " + splitVec[splitVec.size()-1] + ";\n}\nint main() {\n"
						"   foo" + std::to_string(counter) + "();\n}");
		counter++;

		ifLines.push_back(ifLine + ";\n");

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
        // Also get which cbit this conditional code belongs to
		int measurementGateId = -1;
		for (auto s : splitVec) {
			if (boost::contains(s, cbitVarName)) {
				boost::erase_all(s, "(");

				boost::erase_all(s, cbitVarName);
				boost::erase_all(s, "[");
				boost::erase_all(s, "]");

				conditionalCodeSegmentActingQubits.push_back(cbitToQubit[std::stoi(s)]);

				break;
			}
		}

139
140
	}

141
142
143
	// Erase the if lines from the main source
	// they are going to be represented with
	// conditional graphs.
144
145
146
147
148
	for (auto s : ifLines) {
		auto idx = kernelSource.find(s);
		kernelSource.erase(idx, s.size());
	}

149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
	std::regex gates("\\s+*\\w+\\(.*");
	it = std::sregex_iterator(kernelSource.begin(), kernelSource.end(),
			gates);
	counter = 0;
	for (auto it = std::sregex_iterator(kernelSource.begin(), kernelSource.end(),
			gates); it != std::sregex_iterator(); ++it) {
		auto gateLine = (*it).str();
		auto firstParen = gateLine.find_first_of('(');
		auto secondParen = gateLine.find_first_of(')', firstParen);
		auto functionArguments = gateLine.substr(firstParen+1, (secondParen-firstParen)-1);
		auto args = gateLine.substr(firstParen+1, (secondParen-firstParen)-1);
		std::vector<std::string> splitArgs;
		boost::split(splitArgs, args, boost::is_any_of(","));
		std::vector<std::string> params;
		for (auto a : splitArgs) {
			boost::trim(a);
			if (!boost::contains(a, qbitVarName)) {
				// This is a gate parameter... What do we do with it?
				params.push_back(a);
			}
		}
		gateIdToParameterMap.insert(std::make_pair(counter, params));
171
		counter++;
172
173
	}

174
175
176
177
178
179
	// Get the kernel name
	std::regex functionName("((\\w+)\\s*\\()\\s*");
	auto begin = std::sregex_iterator(kernelSource.begin(), kernelSource.end(),
			functionName);
	std::string fName = (*begin).str();

180
181
182
183
184
185
186
	std::string varAllocation = "", fargs;
	for (auto i : orderOfArgs) {
		auto key = i;
		auto value = typeToVarKernelArgs[i];
		if ("qbit" == key) {
			varAllocation += key + " " + value + ";\n   ";
			fargs += value.substr(0, value.find_first_of("[")) + ",";
187
		} else {
188
189
190

			varAllocation += key + " " + value + " = 0;\n   ";
			fargs += value + ",";
191
192
193
194
195
196
197
198
		}
	}

	if (!fargs.empty()) {
		fargs = fargs.substr(0, fargs.size()-1);
		boost::replace_first(fName, "(", "(" + fargs);
	}

199
	// Now wrap in a main function for ScaffCC
200
	kernelSource = kernelSource + std::string("\nint main() {\n   ") + varAllocation + fName
201
202
			+ std::string(");\n}");

203
//	std::cout << "\n" << kernelSource << "\n";
204
205
}

206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250


/**
 * Compiles src for the given accelerator. Every bit-typed kernel argument
 * (e.g. "qbit q") is given an explicit size taken from the accelerator
 * buffer with the same name ("qbit q[3]") before delegating to
 * compile(const std::string&).
 *
 * A null accelerator is accepted: no sizes are substituted and the source
 * is compiled by the single-argument overload.
 */
std::shared_ptr<IR> ScaffoldCompiler::compile(const std::string& src,
		std::shared_ptr<IAccelerator> acc) {

	kernelSource = src;

	kernelArgsToMap();

	// Set the accelerator
	accelerator = acc;

	// Without an accelerator there are no buffer sizes to look up; the
	// single-argument overload rebuilds the argument map and compiles.
	if (!accelerator) {
		return compile(kernelSource);
	}

	// Get the bit variable type string
	const std::string bitTypeStr = getBitType();

	// Size each bit-typed argument in the source. The map entry is updated
	// only after the scan so the container is never modified (and possibly
	// grown) while it is being iterated.
	std::string sizedVar;
	bool foundBitArg = false;
	for (const auto& entry : typeToVarKernelArgs) {
		if (!boost::contains(entry.first, bitTypeStr)) {
			continue;
		}
		const std::string varName = entry.second;
		const auto nBits = accelerator->getBufferSize(varName);
		sizedVar = varName + "[" + std::to_string(nBits) + "]";
		boost::replace_first(kernelSource, bitTypeStr + " " + varName,
				bitTypeStr + " " + sizedVar);
		foundBitArg = true;
	}

	// Replace the varname in the map with varName[#]
	if (foundBitArg) {
		typeToVarKernelArgs[bitTypeStr] = sizedVar;
	}

	return compile(kernelSource);
}
std::shared_ptr<IR> ScaffoldCompiler::compile(const std::string& src) {

	kernelSource = src;

	if (!accelerator) {
		kernelArgsToMap();
	}

	modifySource();
251

252
253
254
	// Create an instance of our ScaffCC API
	// so that we can interact with a locally installed
	// scaffcc executable
255
	scaffold::ScaffCCAPI scaffcc;
256
	using ScaffoldGraphIR = GraphIR<QuantumCircuit>;
257

258
259
	// Compile the source code and return the QASM form
	// This will throw if it fails.
260
261
	auto qasm = scaffcc.getFlatQASMFromSource(kernelSource);

262
263
264
	// Get the Qasm as a Graph...
	auto circuitGraph = QasmToGraph::getCircuitGraph(qasm);

265
266
267
268
269
270
	// HERE we have main circuit graph, before conditional
	// if branches... So then for each conditional code statement,
	// get its circuit graph and add it to the main graph after
	// the addition of a COND conditional node that will enable the
	// conditional nodes if the measured cbit is a 1.

271
272
	// Get measurement acting qubits

273
274
275
	// So a COND node needs to know the gate id of the measurement gate
	// and the nodes to mark enabled if the measurement is a 1,
	if (!conditionalCodeSegments.empty()) {
276
		std::vector<QuantumCircuit> condGraphs;
277
278
279
280
281
282
		for (auto cond : conditionalCodeSegments) {
			auto condQasm = scaffcc.getFlatQASMFromSource(cond);
			auto g = QasmToGraph::getCircuitGraph(condQasm);
			condGraphs.push_back(g);
		}

283
284
		QasmToGraph::linkConditionalQasm(circuitGraph, condGraphs,
				conditionalCodeSegmentActingQubits);
285
286
	}

287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
	// Modify the graph to reflect gate dependence on runtime variables.
	int counter = 0;
	for (int i = 0; i < circuitGraph.order(); i++) {
		auto props = circuitGraph.getVertexProperties(i);
		auto gateName = std::get<0>(props);
		if (gateName != "InitialState" && gateName != "FinalState"
				&& gateName != "conditional") {
			std::vector<std::string> possibleParams = std::get<5>(props);
			if (!possibleParams.empty()) {
				// This is a parameterized gate
				for (int j = 0; j < possibleParams.size(); j++) {
					std::get<5>(circuitGraph.getVertex(i).properties)[j] = gateIdToParameterMap[counter][j];
				}
			}
			counter++;
		}
	}
304

305
306
307
308
309
	// Create a GraphIR instance from that graph
	auto graphIR = std::make_shared<ScaffoldGraphIR>(circuitGraph);

	// Return the IR.
	return graphIR;
310
311
312
313
314
}

} // end namespace quantum

} // end namespace xacc
315
316
317
318
//
//// Required in CPP file to be discovered by factory pattern
//REGISTER_XACCOBJECT_WITH_XACCTYPE(xacc::quantum::ScaffoldCompiler, "compiler",
//		"scaffold");
319

320
321
// File-local registration object: constructing it during static
// initialization registers ScaffoldCompiler with the CompilerRegistry
// under the name "scaffold".
static xacc::RegisterCompiler<xacc::quantum::ScaffoldCompiler> scaffoldCompilerRegistration(
		"scaffold");