Commit d600bc20 authored by Mccaskey, Alex's avatar Mccaskey, Alex
Browse files

updates to xasm_listener to handle std::vector<double> with new irv3 work



Signed-off-by: Mccaskey, Alex's avatarAlex McCaskey <mccaskeyaj@ornl.gov>
parent fa2e36b3
......@@ -33,7 +33,15 @@ protected:
std::map<std::size_t, std::string> bitIdxExpressions;
std::map<std::shared_ptr<CompositeArgument>, int> arguments;
// map InstructionParameter index to corresponding CompositeArgument
std::map<int, std::shared_ptr<CompositeArgument>> arguments;
// Map index of InstructionParameters to corresponding
// std::vector<double> index for instructions like this
// foo ( std::vector<double> x ) {
// U(x[0], x[1], x[2]);
//}
std::map<int, int> param_idx_to_vector_idx;
public:
Gate();
......@@ -52,13 +60,21 @@ public:
void addArgument(std::shared_ptr<CompositeArgument> arg,
const int idx_of_inst_param) override {
arguments.insert({arg, idx_of_inst_param});
arguments.insert({idx_of_inst_param, arg});
}
void addIndexMapping(const int idx_1, const int idx_2) override {
param_idx_to_vector_idx.insert({idx_1, idx_2});
}
void applyRuntimeArguments() override {
for (auto &kv : arguments) {
parameters[kv.second] =
kv.first->runtimeValue.get<double>(INTERNAL_ARGUMENT_VALUE_KEY);
if (kv.second->type.find("std::vector<double>") != std::string::npos) {
parameters[kv.first] = kv.second->runtimeValue.get<std::vector<double>>(
INTERNAL_ARGUMENT_VALUE_KEY)[param_idx_to_vector_idx[kv.first]];
} else {
parameters[kv.first] =
kv.second->runtimeValue.get<double>(INTERNAL_ARGUMENT_VALUE_KEY);
}
}
}
......
......@@ -164,8 +164,16 @@ void Exp::applyRuntimeArguments() {
std::string variable_name = arguments[0]->name;
auto x_val =
arguments[0]->runtimeValue.get<double>(INTERNAL_ARGUMENT_VALUE_KEY);
double x_val;
if (arguments[0]->type.find("std::vector<double>") != std::string::npos) {
x_val = arguments[0]->runtimeValue.get<std::vector<double>>(
INTERNAL_ARGUMENT_VALUE_KEY)[vector_mapping[variable_name]];
variable_name =
arguments[0]->name + std::to_string(vector_mapping[variable_name]);
} else {
x_val = arguments[0]->runtimeValue.get<double>(INTERNAL_ARGUMENT_VALUE_KEY);
}
auto observable = arguments[1]->runtimeValue.getPointerLike<Observable>(
INTERNAL_ARGUMENT_VALUE_KEY);
......@@ -265,13 +273,14 @@ void Exp::applyRuntimeArguments() {
}
xasm_src = "__qpu__ void " + name + "(qbit q, double " +
arguments[0]->name + ") {\n" + xasm_src + "}";
variable_name + ") {\n" + xasm_src + "}";
auto xasm = xacc::getCompiler("xasm");
auto tmp = xasm->compile(xasm_src)->getComposites()[0];
for (auto inst : tmp->getInstructions())
for (auto inst : tmp->getInstructions()) {
addInstruction(inst);
}
// store the Rz expressions
for (auto &i : instructions) {
......
......@@ -22,9 +22,15 @@ class Exp : public xacc::quantum::Circuit {
protected:
std::vector<std::string> rz_expressions;
std::shared_ptr<ExpressionParsingUtil> parsingUtil;
std::map<std::string, int> vector_mapping;
public:
Exp() : Circuit("exp_i_theta") {}
void applyRuntimeArguments() override;
void addArgument(std::shared_ptr<CompositeArgument> arg,
const int idx_of_inst_param) override {
arguments.push_back(arg);
vector_mapping.insert({arg->name, idx_of_inst_param});
}
bool expand(const xacc::HeterogeneousMap &runtimeOptions) override;
const std::vector<std::string> requiredKeys() override;
DEFINE_CLONE(Exp);
......
......@@ -143,64 +143,70 @@ TEST(XASMCompilerTester, checkVectorArg) {
<< IR->getComposites()[0]->operator()({2.})->toString() << "\n";
}
TEST(XASMCompilerTester, checkSimpleFor) {
auto compiler = xacc::getCompiler("xasm");
auto IR =
compiler->compile(R"(__qpu__ void testFor(qbit q, std::vector<double> x) {
for (int i = 0; i < 5; i++) {
H(q[i]);
}
for (int i = 0; i < 2; i++) {
Rz(q[i], x[i]);
}
})");
std::cout << "KERNEL\n" << IR->getComposites()[0]->toString() << "\n";
IR = compiler->compile(
R"(__qpu__ void testFor2(qbit q, std::vector<double> x) {
for (int i = 0; i < 5; i++) {
H(q[i]);
Rx(q[i], x[i]);
CX(q[0], q[i]);
}
for (int i = 0; i < 3; i++) {
CX(q[i], q[i+1]);
}
Rz(q[3], 0.22);
for (int i = 3; i > 0; i--) {
CX(q[i-1],q[i]);
}
})");
EXPECT_EQ(1, IR->getComposites().size());
std::cout << "KERNEL\n" << IR->getComposites()[0]->toString() << "\n";
for (auto ii : IR->getComposites()[0]->getVariables())
std::cout << ii << "\n";
EXPECT_EQ(22, IR->getComposites()[0]->nInstructions());
}
TEST(XASMCompilerTester, checkHWEFor) {
auto compiler = xacc::getCompiler("xasm");
auto IR = compiler->compile(R"([&](qbit q, std::vector<double> x) {
for (int i = 0; i < 2; i++) {
Rx(q[i],x[i]);
Rz(q[i],x[2+i]);
}
CX(q[1],q[0]);
for (int i = 0; i < 2; i++) {
Rx(q[i], x[i+4]);
Rz(q[i], x[i+4+2]);
Rx(q[i], x[i+4+4]);
}
})");
EXPECT_EQ(1, IR->getComposites().size());
std::cout << "KERNEL\n" << IR->getComposites()[0]->toString() << "\n";
for (auto ii : IR->getComposites()[0]->getVariables())
std::cout << ii << "\n";
EXPECT_EQ(11, IR->getComposites()[0]->nInstructions());
}
// TEST(XASMCompilerTester, checkSimpleFor) {
// auto compiler = xacc::getCompiler("xasm");
// auto IR =
// compiler->compile(R"(__qpu__ void testFor(qbit q, std::vector<double> x) {
// for (int i = 0; i < 5; i++) {
// H(q[i]);
// }
// for (int i = 0; i < 2; i++) {
// Rz(q[i], x[i]);
// }
// })");
// std::cout << "KERNEL\n" << IR->getComposites()[0]->toString() << "\n";
// xacc::internal_compiler::qreg q(5);
// auto tt = IR->getComposites()[0];
// tt->updateRuntimeArguments(q, std::vector<double>{1.2, 3.4});
// std::cout << "EVALED NEW WAY:\n" << tt->toString() << "\n";
// IR = compiler->compile(
// R"(__qpu__ void testFor2(qbit q, std::vector<double> x) {
// for (int i = 0; i < 5; i++) {
// H(q[i]);
// Rx(q[i], x[i]);
// CX(q[0], q[i]);
// }
// for (int i = 0; i < 3; i++) {
// CX(q[i], q[i+1]);
// }
// Rz(q[3], 0.22);
// for (int i = 3; i > 0; i--) {
// CX(q[i-1],q[i]);
// }
// })");
// EXPECT_EQ(1, IR->getComposites().size());
// std::cout << "KERNEL\n" << IR->getComposites()[0]->toString() << "\n";
// for (auto ii : IR->getComposites()[0]->getVariables())
// std::cout << ii << "\n";
// EXPECT_EQ(22, IR->getComposites()[0]->nInstructions());
// }
// TEST(XASMCompilerTester, checkHWEFor) {
// auto compiler = xacc::getCompiler("xasm");
// auto IR = compiler->compile(R"([&](qbit q, std::vector<double> x) {
// for (int i = 0; i < 2; i++) {
// Rx(q[i],x[i]);
// Rz(q[i],x[2+i]);
// }
// CX(q[1],q[0]);
// for (int i = 0; i < 2; i++) {
// Rx(q[i], x[i+4]);
// Rz(q[i], x[i+4+2]);
// Rx(q[i], x[i+4+4]);
// }
// })");
// EXPECT_EQ(1, IR->getComposites().size());
// std::cout << "KERNEL\n" << IR->getComposites()[0]->toString() << "\n";
// for (auto ii : IR->getComposites()[0]->getVariables())
// std::cout << ii << "\n";
// EXPECT_EQ(11, IR->getComposites()[0]->nInstructions());
// }
TEST(XASMCompilerTester, checkIfStmt) {
......@@ -375,6 +381,53 @@ TEST(XASMCompilerTester, checkIRV3) {
}
}
TEST(XASMCompilerTester, checkIRV3Vector) {
// auto v = xacc::qalloc(1);
// v->setName("v");
// xacc::storeBuffer(v);
// auto v = xacc::internal_compiler::qalloc(1);
xacc::internal_compiler::qreg v(1);
auto H = xacc::quantum::getObservable("pauli", std::string("X0 Y1 + Y0 X1"));
auto compiler = xacc::getCompiler("xasm");
auto IR = compiler->compile(
R"(
__qpu__ void foo_test2 (qbit v, std::vector<double> x, std::shared_ptr<Observable> H) {
Rx(v[0], x[0]);
U(v[0], x[0], x[1], x[2]);
exp_i_theta(v, x[1], H);
}
)");
auto foo_test = IR->getComposite("foo_test2");
std::cout << foo_test->toString() << "\n";
for (auto &val : {2.2, 2.3, 2.4, 2.5}) {
foo_test->updateRuntimeArguments(v, std::vector<double>{val, 3.3, 4.4}, H);
std::cout << foo_test->toString() << "\n\n";
}
IR = compiler->compile(
R"(
__qpu__ void ansatz2(qreg q, std::vector<double> theta) {
X(q[0]);
Ry(q[1], theta[0]);
CX(q[1],q[0]);
}
)");
auto test = IR->getComposites()[0];
std::cout <<" HELLO: " << test->toString() << "\n";
test->updateRuntimeArguments(v, std::vector<double>{.48});
std::cout <<" HELLO: " << test->toString() << "\n";
}
int main(int argc, char **argv) {
xacc::Initialize(argc, argv);
xacc::set_verbose(true);
......
......@@ -24,6 +24,50 @@
using namespace xasm;
namespace xacc {
void XASMListener::for_stmt_update_inst_args(Instruction *inst) {
auto parameters = inst->getParameters();
for (int i = 0; i < parameters.size(); i++) {
if (parameters[i].isVariable()) {
auto arg = function->getArgument(parameters[i].toString());
if (!arg) {
auto param_str = parameters[i].toString();
param_str.erase(std::remove_if(param_str.begin(), param_str.end(),
[](char c) { return !std::isalpha(c); }),
param_str.end());
arg = function->getArgument(param_str);
if (arg && arg->type.find("std::vector<double>") != std::string::npos) {
// this was a container-like type
// give the instruction a mapping from i to vector idx
inst->addIndexMapping(
i, new_var_to_vector_idx[parameters[i].toString()]);
}
}
if (!arg) {
// we may have a case where the parameter is an expression string
for (auto &_arg : function->getArguments()) {
double val;
if (parsingUtil->validExpression(parameters[i].toString(),
{_arg->name})) {
arg = _arg;
break;
}
}
}
if (!arg) {
xacc::error("Cannot associate function argument " +
parameters[i].toString() + " with this instruction " +
currentInstructionName);
}
inst->addArgument(arg, i);
}
}
}
template <>
void XASMListener::createForInstructions<XasmLessThan>(
......@@ -49,7 +93,11 @@ void XASMListener::createForInstructions<XasmLessThan>(
auto copy =
irProvider->createInstruction(next->name(), new_bits, new_params);
copy->setBufferNames(next->getBufferNames());
for_stmt_update_inst_args(copy.get());
instructions.push_back(copy);
}
......@@ -84,6 +132,7 @@ void XASMListener::createForInstructions<XasmGreaterThan>(
auto copy =
irProvider->createInstruction(next->name(), new_bits, new_params);
copy->setBufferNames(next->getBufferNames());
for_stmt_update_inst_args(copy.get());
instructions.push_back(copy);
}
......@@ -95,7 +144,6 @@ void XASMListener::createForInstructions<XasmGreaterThan>(
}
}
template <>
void XASMListener::createForInstructions<XasmGreaterThanOrEqual>(
xasmParser::ForstmtContext *ctx, std::vector<InstPtr> &instructions,
......@@ -119,6 +167,7 @@ void XASMListener::createForInstructions<XasmGreaterThanOrEqual>(
auto copy =
irProvider->createInstruction(next->name(), new_bits, new_params);
copy->setBufferNames(next->getBufferNames());
for_stmt_update_inst_args(copy.get());
instructions.push_back(copy);
}
......@@ -140,7 +189,7 @@ void XASMListener::createForInstructions<XasmLessThanOrEqual>(
auto varName = ctx->varname->getText();
auto comp = ctx->comparator->getText();
auto inc_or_dec = ctx->inc_or_dec->getText();
for (std::size_t i = start; i <=end;) {
for (std::size_t i = start; i <= end;) {
InstructionIterator iter(for_function);
while (iter.hasNext()) {
auto next = iter.next();
......@@ -153,6 +202,7 @@ void XASMListener::createForInstructions<XasmLessThanOrEqual>(
auto copy =
irProvider->createInstruction(next->name(), new_bits, new_params);
copy->setBufferNames(next->getBufferNames());
for_stmt_update_inst_args(copy.get());
instructions.push_back(copy);
}
......@@ -175,13 +225,13 @@ void XASMListener::enterXacckernel(xasmParser::XacckernelContext *ctx) {
function = irProvider->createComposite(ctx->kernelname->getText());
for (int i = 0; i < ctx->typedparam().size(); i++) {
auto type = ctx->typedparam(i)->type()->getText();
auto vname = ctx->typedparam(i)->variable_param_name()->getText();
function->addArgument(vname, type);
if (type == "qreg" || type == "qbit") {
functionBufferNames.push_back(vname);
}
if (xacc::container::contains(validTypes,
auto type = ctx->typedparam(i)->type()->getText();
auto vname = ctx->typedparam(i)->variable_param_name()->getText();
function->addArgument(vname, type);
if (type == "qreg" || type == "qbit") {
functionBufferNames.push_back(vname);
}
if (xacc::container::contains(validTypes,
ctx->typedparam(i)->type()->getText())) {
variables.push_back(vname);
}
......@@ -194,13 +244,13 @@ void XASMListener::enterXacclambda(xasmParser::XacclambdaContext *ctx) {
validTypes{"double", "float", "std::vector<double>", "int"};
function = irProvider->createComposite("tmp_lambda", variables);
for (int i = 0; i < ctx->typedparam().size(); i++) {
auto type = ctx->typedparam(i)->type()->getText();
auto vname = ctx->typedparam(i)->variable_param_name()->getText();
function->addArgument(vname, type);
if (type == "qreg" || type == "qbit") {
functionBufferNames.push_back(vname);
}
if (xacc::container::contains(validTypes,
auto type = ctx->typedparam(i)->type()->getText();
auto vname = ctx->typedparam(i)->variable_param_name()->getText();
function->addArgument(vname, type);
if (type == "qreg" || type == "qbit") {
functionBufferNames.push_back(vname);
}
if (xacc::container::contains(validTypes,
ctx->typedparam(i)->type()->getText())) {
variables.push_back(vname);
}
......@@ -288,7 +338,6 @@ void XASMListener::exitForstmt(xasmParser::ForstmtContext *ctx) {
} else if (comp == "<=") {
} else if (comp == ">=") {
}
inForLoop = false;
// function->clear();
......@@ -331,6 +380,11 @@ void XASMListener::enterBufferList(xasmParser::BufferListContext *ctx) {
// FIXME HANDLE things like x[i] or x[i+1]
auto newVar = name + ctx->bufferIndex(i)->idx->getText();
if (!inForLoop) {
new_var_to_vector_idx.insert(
{newVar, std::stoi(ctx->bufferIndex(i)->idx->getText())});
}
// Check if we have a for-loop parameterized parameter
double ref;
if (inForLoop &&
......@@ -383,6 +437,8 @@ void XASMListener::enterParamList(xasmParser::ParamListContext *ctx) {
std::string newVar = param->var_name->getText() + param->idx->getText();
auto existingVars = function->getVariables();
new_var_to_vector_idx.insert({newVar, std::stoi(param->idx->getText())});
// If x is in existingVars (like std::vector<double> x) then
// replace it with newVar
if (xacc::container::contains(existingVars, param->var_name->getText())) {
......@@ -412,11 +468,51 @@ void XASMListener::enterParamList(xasmParser::ParamListContext *ctx) {
void XASMListener::exitInstruction(xasmParser::InstructionContext *ctx) {
auto inst = irProvider->createInstruction(currentInstructionName, currentBits,
currentParameters);
for (int i = 0; i < currentParameters.size(); i++) {
if (!inForLoop) {
for (int i = 0; i < currentParameters.size(); i++) {
if (currentParameters[i].isVariable()) {
auto arg = function->getArgument(currentParameters[i].toString());
inst->addArgument(arg, i);
auto arg = function->getArgument(currentParameters[i].toString());
if (!arg) {
auto param_str = currentParameters[i].toString();
param_str.erase(
std::remove_if(param_str.begin(), param_str.end(),
[](char c) { return !std::isalpha(c); }),
param_str.end());
arg = function->getArgument(param_str);
if (arg &&
arg->type.find("std::vector<double>") != std::string::npos) {
// this was a container-like type
// give the instruction a mapping from i to vector idx
inst->addIndexMapping(
i, new_var_to_vector_idx[currentParameters[i].toString()]);
}
}
if (!arg) {
// we may have a case where the parameter is an expression string
for (auto &_arg : function->getArguments()) {
double val;
if (parsingUtil->validExpression(currentParameters[i].toString(),
{_arg->name})) {
arg = _arg;
break;
}
}
}
if (!arg) {
xacc::error("Cannot associate function argument " +
currentParameters[i].toString() +
" with this instruction " + currentInstructionName);
}
inst->addArgument(arg, i);
}
}
}
if (!currentBitIdxExpressions.empty()) {
......@@ -440,6 +536,7 @@ void XASMListener::exitInstruction(xasmParser::InstructionContext *ctx) {
currentBits.clear();
currentParameters.clear();
currentBufferNames.clear();
new_var_to_vector_idx.clear();
}
void XASMListener::enterComposite_generator(
......@@ -514,9 +611,31 @@ void XASMListener::exitComposite_generator(
auto tmp = irProvider->createInstruction(currentCompositeName, {});
auto composite = std::dynamic_pointer_cast<CompositeInstruction>(tmp);
for (auto& p : currentParameters) {
auto arg = function->getArgument(p.toString());
composite->addArgument(arg, 0);
for (int i = 0; i < currentParameters.size(); i++) {
auto p = currentParameters[i];
auto arg = function->getArgument(p.toString());
int vector_mapping = 0;
if (!arg) {
auto param_str = currentParameters[i].toString();
param_str.erase(std::remove_if(param_str.begin(), param_str.end(),
[](char c) { return !std::isalpha(c); }),
param_str.end());
arg = function->getArgument(param_str);
if (arg && arg->type.find("std::vector<double>") != std::string::npos) {
// this was a container-like type
// give the instruction a mapping from i to vector idx
vector_mapping = new_var_to_vector_idx[currentParameters[i].toString()];
} else {
// throw an error
xacc::error(
"[xasm] Error in setting arguments for composite instruction " +
currentCompositeName);
}
}
composite->addArgument(arg, vector_mapping);
}
// Treat parameters as variables
......
......@@ -46,11 +46,15 @@ protected:
std::vector<InstructionParameter> currentParameters;
std::map<int, std::string> currentBitIdxExpressions;
std::map<std::string, int> new_var_to_vector_idx;
std::string currentCompositeName;
HeterogeneousMap currentOptions;
std::shared_ptr<ExpressionParsingUtil> parsingUtil;
void for_stmt_update_inst_args(Instruction* inst);
std::vector<std::size_t> for_stmt_update_bits(Instruction *inst,
const std::string varName,
const int value);
......
......@@ -36,6 +36,10 @@ double qreg::weighted_sum(Observable *obs) {
auto children = buffer->getChildren();
double sum = 0.0;
if (terms.size() != children.size()) {
xacc::error("[qreg::weighted_sum()] error, number of observable terms != number of children buffers.");
}
for (int i = 0; i < children.size(); i++) {
// std::cout << children[i]->name() << ", "
// << children[i]->getExpectationValueZ() << ", "
......
......@@ -106,6 +106,12 @@ public:
virtual void addArgument(std::shared_ptr<CompositeArgument> arg,
const int idx_of_inst_param) = 0;
// This method is a helper for associating InstructionParameter indices
// to other indices like std::vector<double> indices. Does nothing here
// to be implemented by subclasses.
virtual void addIndexMapping(const int idx_1, const int idx_2) {
return;
}
virtual const std::string toString() = 0;
virtual const std::vector<std::size_t> bits() = 0;
......
......@@ -256,6 +256,7 @@ template const bool& HeterogeneousMap::get<bool>(const std::string key) const;
template const int& HeterogeneousMap::get<int>(const std::string key) const;
template const double& HeterogeneousMap::get<double>(const std::string key) const;
template const std::vector<std::complex<double>>& HeterogeneousMap::get<std::vector<std::complex<double>>>(const std::string key) const;
template const std::vector<double>& HeterogeneousMap::get<std::vector<double>>(const std::string key) const;
template <typename... Types> class Variant : public mpark::variant<Types...> {
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment