Unverified Commit d110cb18 authored by Mccaskey, Alex's avatar Mccaskey, Alex Committed by GitHub
Browse files

Merge pull request #50 from tnguyen-ornl/pyxasm_ctrl_ajoint

More work on ctrl and adjoint
parents 6105e4ae dcdd6c93
Pipeline #123924 passed with stage
in 17 minutes and 44 seconds
...@@ -31,7 +31,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor { ...@@ -31,7 +31,9 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
pyxasmParser::Atom_exprContext *context) override { pyxasmParser::Atom_exprContext *context) override {
// Handle kernel::ctrl(...), kernel::adjoint(...) // Handle kernel::ctrl(...), kernel::adjoint(...)
if (!context->trailer().empty() && context->trailer()[0]->getText() == ".ctrl") { if (!context->trailer().empty() &&
(context->trailer()[0]->getText() == ".ctrl" ||
context->trailer()[0]->getText() == ".adjoint")) {
std::cout << "HELLO: " << context->getText() << "\n"; std::cout << "HELLO: " << context->getText() << "\n";
std::cout << context->trailer()[0]->getText() << "\n"; std::cout << context->trailer()[0]->getText() << "\n";
std::cout << context->atom()->getText() << "\n"; std::cout << context->atom()->getText() << "\n";
...@@ -41,7 +43,10 @@ class pyxasm_visitor : public pyxasmBaseVisitor { ...@@ -41,7 +43,10 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
auto arg_list = context->trailer()[1]->arglist(); auto arg_list = context->trailer()[1]->arglist();
std::stringstream ss; std::stringstream ss;
ss << context->atom()->getText() << "::ctrl(parent_kernel"; // Remove the first '.' character
const std::string methodName = context->trailer()[0]->getText().substr(1);
ss << context->atom()->getText() << "::" << methodName
<< "(parent_kernel";
for (int i = 0; i < arg_list->argument().size(); i++) { for (int i = 0; i < arg_list->argument().size(); i++) {
ss << ", " << arg_list->argument(i)->getText(); ss << ", " << arg_list->argument(i)->getText();
} }
...@@ -50,7 +55,6 @@ class pyxasm_visitor : public pyxasmBaseVisitor { ...@@ -50,7 +55,6 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
std::cout << "HELLO SS: " << ss.str() << "\n"; std::cout << "HELLO SS: " << ss.str() << "\n";
result.first = ss.str(); result.first = ss.str();
return 0; return 0;
} }
if (context->atom()->NAME() != nullptr) { if (context->atom()->NAME() != nullptr) {
auto inst_name = context->atom()->NAME()->getText(); auto inst_name = context->atom()->NAME()->getText();
...@@ -138,6 +142,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor { ...@@ -138,6 +142,29 @@ class pyxasm_visitor : public pyxasmBaseVisitor {
<< context->trailer()[0]->arglist()->argument(2)->getText() << context->trailer()[0]->arglist()->argument(2)->getText()
<< ");\n"; << ");\n";
result.first = ss.str(); result.first = ss.str();
}
// Handle potential name collision: user-defined kernel having the
// same name as an XACC circuit: e.g. common names such as qft, iqft
// Note: these circuits (except exp_i_theta) don't have QRT
// equivalents.
// Condition: first argument is a qubit register
else if (!context->trailer()[0]->arglist()->argument().empty() &&
xacc::container::contains(bufferNames, context->trailer()[0]
->arglist()
->argument(0)
->getText())) {
std::stringstream ss;
// Use the kernel call with a parent kernel arg.
ss << inst_name << "(parent_kernel, ";
const auto &argList = context->trailer()[0]->arglist()->argument();
for (size_t i = 0; i < argList.size(); ++i) {
ss << argList[i]->getText();
if (i != argList.size() - 1) {
ss << ", ";
}
}
ss << ");\n";
result.first = ss.str();
} else { } else {
xacc::error("Composite instruction '" + inst_name + xacc::error("Composite instruction '" + inst_name +
"' is not currently supported."); "' is not currently supported.");
......
...@@ -200,10 +200,12 @@ class qjit(object): ...@@ -200,10 +200,12 @@ class qjit(object):
# Handle nested kernels: # Handle nested kernels:
dependency = [] dependency = []
for kernelName in self.__compiled__kernels: for kernelName in self.__compiled__kernels:
kernelCall = kernelName + '('
# Check that this kernel *calls* a previously-compiled kernel: # Check that this kernel *calls* a previously-compiled kernel:
# pattern: "<white space> kernel(" # pattern: "<white space> kernel(" OR "kernel.adjoint(" OR "kernel.ctrl("
if re.search(r"\b" + re.escape(kernelCall), self.src): kernelCall = kernelName + '('
kernelAdjCall = kernelName + '.adjoint('
kernelCtrlCall = kernelName + '.ctrl('
if re.search(r"\b" + re.escape(kernelCall) + '|' + re.escape(kernelAdjCall) + '|' + re.escape(kernelCtrlCall), self.src):
dependency.append(kernelName) dependency.append(kernelName)
...@@ -312,6 +314,10 @@ class qjit(object): ...@@ -312,6 +314,10 @@ class qjit(object):
print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.') print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')
exit(1) exit(1)
def adjoint(self, *args):
print('This is an internal API call and will be translated to C++ via the QJIT.\nIt can only be called from within another quantum kernel.')
exit(1)
def __call__(self, *args): def __call__(self, *args):
""" """
Execute the decorated quantum kernel. This will directly Execute the decorated quantum kernel. This will directly
......
...@@ -221,7 +221,7 @@ class TestSimpleKernelJIT(unittest.TestCase): ...@@ -221,7 +221,7 @@ class TestSimpleKernelJIT(unittest.TestCase):
def test_iqft_kernel(self): def test_iqft_kernel(self):
@qjit @qjit
def iqft(q : qreg, startIdx : int, nbQubits : int): def inverse_qft(q : qreg, startIdx : int, nbQubits : int):
for i in range(nbQubits/2): for i in range(nbQubits/2):
Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1]) Swap(q[startIdx + i], q[startIdx + nbQubits - i - 1])
...@@ -235,7 +235,7 @@ class TestSimpleKernelJIT(unittest.TestCase): ...@@ -235,7 +235,7 @@ class TestSimpleKernelJIT(unittest.TestCase):
H(q[startIdx+nbQubits-1]) H(q[startIdx+nbQubits-1])
q = qalloc(5) q = qalloc(5)
comp = iqft.extract_composite(q, 0, 5) comp = inverse_qft.extract_composite(q, 0, 5)
print(comp.toString()) print(comp.toString())
self.assertEqual(comp.nInstructions(), 17) self.assertEqual(comp.nInstructions(), 17)
self.assertEqual(comp.getInstruction(0).name(), "Swap") self.assertEqual(comp.getInstruction(0).name(), "Swap")
...@@ -279,24 +279,52 @@ class TestSimpleKernelJIT(unittest.TestCase): ...@@ -279,24 +279,52 @@ class TestSimpleKernelJIT(unittest.TestCase):
@qjit @qjit
def qpe(q : qreg): def qpe(q : qreg):
nq = q.size() nq = q.size()
X(q[nq - 1])
for i in range(q.size()-1): for i in range(q.size()-1):
H(q[i]) H(q[i])
bitPrecision = nq-1 bitPrecision = nq-1
for i in range(bitPrecision): for i in range(bitPrecision):
nbCalls = 1 << i nbCalls = 2**i
for j in range(nbCalls): for j in range(nbCalls):
ctrl_bit = i ctrl_bit = i
oracle.ctrl(ctrl_bit, q) oracle.ctrl(ctrl_bit, q)
# Inverse QFT on the counting qubits
iqft(q, 0, bitPrecision)
for i in range(bitPrecision): for i in range(bitPrecision):
Measure(q[i]) Measure(q[i])
q = qalloc(4) q = qalloc(4)
qpe(q) qpe(q)
print(q.counts()) print(q.counts())
self.assertEqual(q.counts()['100'], 1024)
def test_adjoint_kernel(self):
@qjit
def test_kernel(q : qreg):
CX(q[0], q[1])
Rx(q[0], 1.234)
T(q[0])
X(q[0])
@qjit
def check_adjoint(q : qreg):
test_kernel.adjoint(q)
q = qalloc(2)
comp = check_adjoint.extract_composite(q)
print(comp.toString())
self.assertEqual(comp.nInstructions(), 4)
# Reverse
self.assertEqual(comp.getInstruction(0).name(), "X")
# Check T -> Tdg
self.assertEqual(comp.getInstruction(1).name(), "Tdg")
self.assertEqual(comp.getInstruction(2).name(), "Rx")
# Check angle -> -angle
self.assertAlmostEqual((float)(comp.getInstruction(2).getParameter(0)), -1.234)
self.assertEqual(comp.getInstruction(3).name(), "CNOT")
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -108,7 +108,6 @@ public: ...@@ -108,7 +108,6 @@ public:
} }
auto provider = qcor::__internal__::get_provider(); auto provider = qcor::__internal__::get_provider();
std::reverse(instructions.begin(), instructions.end());
for (int i = 0; i < instructions.size(); i++) { for (int i = 0; i < instructions.size(); i++) {
auto inst = derived.parent_kernel->getInstruction(i); auto inst = derived.parent_kernel->getInstruction(i);
// Parametric gates: // Parametric gates:
...@@ -127,8 +126,13 @@ public: ...@@ -127,8 +126,13 @@ public:
} }
} }
// We update/replace instructions in the derived.parent_kernel composite,
// hence collecting these new instructions and reversing the sequence.
auto new_instructions = derived.parent_kernel->getInstructions();
std::reverse(new_instructions.begin(), new_instructions.end());
// add the instructions to the current parent kernel // add the instructions to the current parent kernel
parent_kernel->addInstructions(instructions); parent_kernel->addInstructions(new_instructions);
// no measures, so no execute // no measures, so no execute
} }
......
...@@ -38,6 +38,11 @@ get_transformation(const std::string &transform_type) { ...@@ -38,6 +38,11 @@ get_transformation(const std::string &transform_type) {
xacc::internal_compiler::compiler_InitializeXACC(); xacc::internal_compiler::compiler_InitializeXACC();
return xacc::getService<xacc::IRTransformation>(transform_type); return xacc::getService<xacc::IRTransformation>(transform_type);
} }
std::shared_ptr<qcor::IRProvider> get_provider() {
return xacc::getIRProvider("quantum");
}
std::shared_ptr<qcor::CompositeInstruction> std::shared_ptr<qcor::CompositeInstruction>
decompose_unitary(const std::string algorithm, UnitaryMatrix &mat, decompose_unitary(const std::string algorithm, UnitaryMatrix &mat,
const std::string buffer_name) { const std::string buffer_name) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment