Commit 818f4d16 authored by Nguyen, Thien Minh's avatar Nguyen, Thien Minh
Browse files

Add int& type support



Add bit-flip ftqc test (syndome value returned as int&)

Signed-off-by: default avatarThien Nguyen <nguyentm@ornl.gov>
parent 48217869
Loading
Loading
Loading
Loading
+12 −6
Original line number Diff line number Diff line
@@ -152,14 +152,19 @@ class qjit(object):

        # Construct the C++ kernel arg string
        cpp_arg_str = ''
        self.float_ref_args = []
        self.ref_type_args = []
        self.qRegName = ''
        for arg, _type in self.type_annotations.items():
            if _type is FLOAT_REF:
                self.float_ref_args.append(arg)
                self.ref_type_args.append(arg)
                cpp_arg_str += ',' + \
                    'double& ' + arg
                continue
            if _type is INT_REF:
                self.ref_type_args.append(arg)
                cpp_arg_str += ',' + \
                    'int& ' + arg
                continue
            if str(_type) not in self.allowed_type_cpp_map:
                print('Error, this quantum kernel arg type is not allowed: ', str(_type))
                exit(1)
@@ -204,7 +209,7 @@ class qjit(object):

        # Persist *pass by ref* variables to the accelerator buffer:
        persist_by_ref_var_code = ''
        for ref_var in self.float_ref_args:
        for ref_var in self.ref_type_args:
            persist_by_ref_var_code += '\npersist_var_to_qreq(\"' + ref_var + '\", ' + ref_var + ', '+ self.qRegName + ')' 

        # Create the qcor quantum kernel function src for QJIT and the Clang syntax handler
@@ -348,7 +353,7 @@ class qjit(object):
        
        # Update any *by-ref* arguments: annotated with the custom type: FLOAT_REF, INT_REF, etc.
        # If there are *pass-by-ref* variables:
        if len(self.float_ref_args) > 0:
        if len(self.ref_type_args) > 0:
            # Access the register:
            qReg = args_dict[self.qRegName]
            # Retrieve *original* variable names of the argument pack
@@ -365,13 +370,14 @@ class qjit(object):
                    caller_var_names.append(i)
            
            # Get the updated value:
            for by_ref_var in self.float_ref_args:
            for by_ref_var in self.ref_type_args:
                updated_var = qReg.getInformation(by_ref_var)
                caller_var_name = caller_var_names[self.arg_names.index(by_ref_var)]
                if (caller_var_name in inspect.stack()[1][0].f_globals):
                    # Make sure it is the correct type:
                    by_ref_instane = inspect.stack()[1][0].f_globals[caller_var_name] 
                    if (isinstance(by_ref_instane, float)):
                    # We only support float and int atm
                    if (isinstance(by_ref_instane, float) or isinstance(by_ref_instane, int)):
                        inspect.stack()[1][0].f_globals[caller_var_name] = updated_var

        return
+62 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ import unittest
from qcor import *

float_result = 0.0
int_result = 0

class TestKernelFTQC(unittest.TestCase):
    def test_pass_by_ref(self):
@@ -65,6 +66,67 @@ class TestKernelFTQC(unittest.TestCase):
        # No change because we pass by value
        self.assertAlmostEqual(float_result, 0.0)

    def test_bit_flip_code(self):
        @qjit
        def encodeLogicalQubit(q : qreg):
            CX(q[0], q[1])
            CX(q[0], q[2])

        @qjit
        def measureSyndrome(q : qreg, syndrome: INT_REF):
            # Make sure to clear syndrome
            syndrome = 0
            ancIdx = 3
            CX(q[0], q[ancIdx])
            CX(q[1], q[ancIdx])
            parity01 = Measure(q[ancIdx])
            if parity01: 
                # Reset anc qubit
                X(q[ancIdx])
                syndrome = syndrome + 1
            
            CX(q[1], q[ancIdx])
            CX(q[2], q[ancIdx])
            parity12 = Measure(q[ancIdx])
            if parity12:
                #Reset anc qubit
                X(q[ancIdx])
                syndrome = syndrome + 2

        @qjit
        def reset_all_qubits(q : qreg):
            for i in range(q.size()):
                if Measure(q[i]):
                    X(q[i])
        
        @qjit
        def testBitflipCode(q : qreg, qIdx: int, syndrome: INT_REF):
            H(q[0])
            encodeLogicalQubit(q)      
            # Apply error:
            if qIdx >= 0:
                X(q[qIdx])
            measureSyndrome(q, syndrome)
            reset_all_qubits(q)

        # Allocate 4 qubits: 3 qubits + 1 ancilla
        q = qalloc(4)
        global int_result
        # Init a minus value to make sure it got updated
        int_result = -1
        # No error: 
        testBitflipCode(q, -1, int_result)
        self.assertEqual(int_result, 0)
        testBitflipCode(q, 0, int_result)
        # X @ q0 -> Syndrome = 10
        self.assertEqual(int_result, 1)
        testBitflipCode(q, 1, int_result)
        # X @ q1 -> Syndrome = 11 == 3
        self.assertEqual(int_result, 3)
        testBitflipCode(q, 2, int_result)
        # X @ q2 -> Syndrome = 01 == 2
        self.assertEqual(int_result, 2)

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()