4 #include "IRProvider.hpp"
5 #include "pyxasmBaseVisitor.h"
6 #include "qcor_utils.hpp"
10 using namespace pyxasm;
/// Aliases from the gate names accepted in kernel source to the instruction
/// names used for the IR-provider lookup (e.g. `CX` is looked up as `CNOT`).
///
/// Declared `inline` (C++17) so that every translation unit including this
/// file shares a single definition instead of each emitting its own
/// external-linkage copy. The map is intentionally left non-const because it
/// is read through `operator[]`.
inline std::map<std::string, std::string> common_name_map{
    {"CX", "CNOT"},
    {"qcor::exp", "exp_i_theta"},
    {"exp", "exp_i_theta"},
};
15 using pyxasm_result_type =
16 std::pair<std::string, std::shared_ptr<xacc::Instruction>>;
20 std::shared_ptr<xacc::IRProvider> provider;
22 std::vector<std::string> bufferNames;
24 std::vector<std::string> declared_var_names;
28 const std::vector<std::string> &local_var_names = {})
29 : provider(xacc::getIRProvider(
"quantum")),
31 declared_var_names(local_var_names) {}
32 pyxasm_result_type result;
35 bool in_for_loop =
false;
38 std::stringstream sub_node_translation;
39 bool is_processing_sub_expr =
false;
// Visits an atom expression node and translates it.
//
// In sub-expression mode (is_processing_sub_expr) the visible branches append
// C++ text to sub_node_translation. Otherwise the node is treated as a call:
// a .ctrl / .adjoint trailer, a name known to the IR provider, a kernel-style
// call, or a generic call, with the outcome stored in result.
//
// context: the Atom_expr parse-tree node being visited.
41 antlrcpp::Any visitAtom_expr(
42 pyxasmParser::Atom_exprContext *context)
override {
// Sub-expression mode: emit translated text into sub_node_translation.
59 if (is_processing_sub_expr) {
// Bracketed list with a testlist_comp child: emit its elements as a
// comma-separated brace list.
60 if (context->atom() && context->atom()->OPEN_BRACK() &&
61 context->atom()->CLOSE_BRACK() && context->atom()->testlist_comp()) {
66 sub_node_translation <<
"{";
// Tracks whether a separator is needed before the next element.
67 bool firstElProcessed =
false;
68 for (
auto &testNode : context->atom()->testlist_comp()->test()) {
71 if (firstElProcessed) {
72 sub_node_translation <<
", ";
74 sub_node_translation << testNode->getText();
75 firstElProcessed =
true;
77 sub_node_translation <<
"}";
// Brace-delimited atom carrying a dictorsetmaker child.
82 if (context->atom() && context->atom()->OPEN_BRACE() &&
83 context->atom()->CLOSE_BRACE() && context->atom()->dictorsetmaker()) {
// String literal tokens: a literal wrapped in single quotes gets its first
// and last characters swapped for double quotes before being emitted.
91 if (context->atom() && !context->atom()->STRING().empty()) {
93 for (
auto &strNode : context->atom()->STRING()) {
94 std::string cppStrLiteral = strNode->getText();
96 if (cppStrLiteral.front() ==
'\'' && cppStrLiteral.back() ==
'\'') {
97 cppStrLiteral.front() =
'"';
98 cppStrLiteral.back() =
'"';
100 sub_node_translation << cppStrLiteral;
// Returns true when the expression has exactly one trailer whose
// subscript list holds a single subscript with more than one test term.
107 const auto isSliceOp =
108 [](pyxasmParser::Atom_exprContext *atom_expr_context) ->
bool {
109 if (atom_expr_context->trailer().size() == 1) {
110 auto subscriptlist = atom_expr_context->trailer(0)->subscriptlist();
111 if (subscriptlist && subscriptlist->subscript().size() == 1) {
112 auto subscript = subscriptlist->subscript(0);
113 const auto nbTestTerms = subscript->test().size();
115 return (nbTestTerms > 1);
// Slice applied to a name listed in bufferNames: emit an extract_range call
// on that name with the slice terms as a brace list.
124 if (context->atom() &&
125 xacc::container::contains(bufferNames, context->atom()->getText()) &&
126 isSliceOp(context)) {
128 sub_node_translation << context->atom()->getText()
129 <<
".extract_range({";
131 context->trailer(0)->subscriptlist()->subscript(0)->test();
132 assert(subscripts.size() > 1);
// Collect the text of every test term of the subscript.
133 std::vector<std::string> subscriptTerms;
134 for (
auto &test : subscripts) {
135 subscriptTerms.emplace_back(test->getText());
// A sliceop child with its own test contributes one more term.
139 context->trailer(0)->subscriptlist()->subscript(0)->sliceop();
140 if (sliceOp && sliceOp->test()) {
141 subscriptTerms.emplace_back(sliceOp->test()->getText());
143 assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3);
// Each term is emitted inside a static_cast to size_t, comma separated.
145 for (
int i = 0; i < subscriptTerms.size(); ++i) {
148 sub_node_translation <<
"static_cast<size_t>(" << subscriptTerms[i]
150 if (i != subscriptTerms.size() - 1) {
151 sub_node_translation <<
", ";
155 sub_node_translation <<
"})";
// First trailer is exactly .ctrl or .adjoint: build the rewritten call text.
167 if (!context->trailer().empty() &&
168 (context->trailer()[0]->getText() ==
".ctrl" ||
169 context->trailer()[0]->getText() ==
".adjoint")) {
// The call arguments come from the second trailer.
176 auto arg_list = context->trailer()[1]->arglist();
178 std::stringstream ss;
// Method name is the first trailer text without its leading dot.
180 const std::string methodName = context->trailer()[0]->getText().substr(1);
// The separator depends on whether the atom text is a declared variable.
183 const std::string separator =
184 (xacc::container::contains(declared_var_names,
185 context->atom()->getText()))
189 ss << context->atom()->getText() << separator << methodName
// Every argument is appended after a comma, rewritten through
// rewriteFunctionArgument.
191 for (
int i = 0; i < arg_list->argument().size(); i++) {
192 ss <<
", " << rewriteFunctionArgument(*(arg_list->argument(i)));
197 result.first = ss.str();
// Atom is a plain NAME: treat it as an instruction or function name.
200 if (context->atom()->NAME() !=
nullptr) {
201 auto inst_name = context->atom()->NAME()->getText();
// Apply the alias table before looking the name up.
203 if (common_name_map.count(inst_name)) {
204 inst_name = common_name_map[inst_name];
// Name is among the instructions reported by the IR provider.
207 if (xacc::container::contains(provider->getInstructions(), inst_name)) {
209 auto inst = provider->createInstruction(inst_name, 0);
// Non-composite instruction: validate the argument count, then fill in bit
// expressions, buffer names and parameters.
213 if (!inst->isComposite()) {
215 auto required_bits = inst->nRequiredBits();
216 auto required_params = inst->getParameters().size();
218 if (!context->trailer().empty()) {
220 context->trailer()[0]->arglist()->argument().size();
// Argument count must equal bits plus parameters, except for Measure.
222 if (required_bits + required_params != atom_n_args &&
223 inst_name !=
"Measure") {
224 std::stringstream xx;
225 xx <<
"Invalid quantum instruction expression. " << inst_name
226 <<
" requires " << required_bits <<
" qubit args and "
227 << required_params <<
" parameter args.";
228 xacc::error(xx.str());
// The first required_bits arguments are the qubit operands.
232 std::vector<std::string> buffer_names;
233 for (
int i = 0; i < required_bits; i++) {
234 auto bit_expr = context->trailer()[0]->arglist()->argument()[i];
235 auto bit_expr_str = rewriteFunctionArgument(*bit_expr);
// An argument containing an opening bracket is split into the buffer name
// (text before the bracket) and an index expression.
237 auto found_bracket = bit_expr_str.find_first_of(
"[");
238 if (found_bracket != std::string::npos) {
239 auto buffer_name = bit_expr_str.substr(0, found_bracket);
240 auto bit_idx_expr = bit_expr_str.substr(
242 bit_expr_str.length() - found_bracket - 2);
243 buffer_names.push_back(buffer_name);
244 inst->setBitExpression(i, bit_idx_expr);
// Whole-argument form: bit expression index -1 with the full argument text,
// which is also recorded as the buffer name.
247 inst->setBitExpression(-1, bit_expr_str);
248 buffer_names.push_back(bit_expr_str);
251 inst->setBufferNames(buffer_names);
// Arguments after the qubit operands become instruction parameters, with
// Python constants rewritten first.
255 for (
int i = required_bits; i < atom_n_args; i++) {
256 inst->setParameter(counter,
257 replacePythonConstants(context->trailer()[0]
264 result.second = inst;
// exp_i_theta: requires exactly three arguments and is emitted as a
// quantum::exp call in result.first.
267 if (inst_name ==
"exp_i_theta") {
269 if (context->trailer()[0]->arglist()->argument().size() != 3) {
271 "Invalid number of arguments for the 'exp_i_theta' "
272 "instruction. Expected 3, got " +
274 context->trailer()[0]->arglist()->argument().size()) +
275 ". Please check your input.");
278 std::stringstream ss;
280 ss <<
"quantum::exp("
281 << context->trailer()[0]->arglist()->argument(0)->getText()
283 << context->trailer()[0]->arglist()->argument(1)->getText()
285 << context->trailer()[0]->arglist()->argument(2)->getText()
287 result.first = ss.str();
// Name registered in kernels_in_translation_unit, or a non-empty argument
// list checked against bufferNames: emit a call whose first argument is
// parent_kernel.
294 else if (xacc::container::contains(
295 ::quantum::kernels_in_translation_unit, inst_name) ||
296 !context->trailer()[0]->arglist()->argument().empty() &&
297 xacc::container::contains(bufferNames,
298 context->trailer()[0]
302 std::stringstream ss;
304 ss << inst_name <<
"(parent_kernel, ";
305 const auto &argList = context->trailer()[0]->arglist()->argument();
306 for (
size_t i = 0; i < argList.size(); ++i) {
307 ss << argList[i]->getText();
308 if (i != argList.size() - 1) {
313 result.first = ss.str();
// Any other composite instruction is reported as unsupported.
315 xacc::error(
"Composite instruction '" + inst_name +
316 "' is not currently supported.");
// Kernel-style call: inst_name is tested against
// kernels_in_translation_unit, alongside a membership test on the text of
// the first call argument. The call is emitted with parent_kernel first.
324 if (xacc::container::contains(::quantum::kernels_in_translation_unit,
326 (!context->trailer().empty() && context->trailer()[0]->arglist() &&
327 !context->trailer()[0]->arglist()->argument().empty() &&
328 xacc::container::contains(
330 context->trailer()[0]->arglist()->argument(0)->getText()))) {
331 std::stringstream ss;
333 ss << inst_name <<
"(parent_kernel, ";
336 const auto &argList = context->trailer()[0]->arglist()->argument();
337 for (
size_t i = 0; i < argList.size(); ++i) {
338 ss << argList[i]->getText();
339 if (i != argList.size() - 1) {
344 result.first = ss.str();
// Generic expression with at least one trailer.
346 if (!context->trailer().empty()) {
351 std::stringstream ss;
// With a non-empty argument list, the call is re-emitted with every
// argument passed through rewriteFunctionArgument.
353 if (context->trailer()[0]->arglist() &&
354 !context->trailer()[0]->arglist()->argument().empty()) {
355 const auto &argList =
356 context->trailer()[0]->arglist()->argument();
357 ss << inst_name <<
"(";
358 for (
size_t i = 0; i < argList.size(); ++i) {
359 ss << rewriteFunctionArgument(*(argList[i]));
360 if (i != argList.size() - 1) {
// Verbatim form: the node text followed by a semicolon and a newline.
366 ss << context->getText() <<
";\n";
368 result.first = ss.str();
// Translates a for statement header into the opening line of a C++
// range-based for loop, stored in result.first.
//
// context: the For_stmt parse-tree node being visited.
376 antlrcpp::Any visitFor_stmt(pyxasmParser::For_stmtContext *context)
override {
// The iterated container is the text of the first test in the testlist;
// the loop variable is the text of the first expr in the exprlist.
384 auto iter_container = context->testlist()->test()[0]->getText();
385 std::string counter_expr = context->exprlist()->expr()[0]->getText();
// Remember the first loop variable in new_var.
387 new_var = counter_expr;
// Several loop variables: prefix an opening square bracket and append the
// remaining names, comma separated.
388 if (context->exprlist()->expr().size() > 1) {
389 counter_expr =
"[" + counter_expr;
390 for (
int i = 1; i < context->exprlist()->expr().size(); i++) {
391 counter_expr +=
", " + context->exprlist()->expr()[i]->getText();
// Emit the loop header, ending with an opening brace and a newline.
396 std::stringstream ss;
397 ss <<
"for (auto " << counter_expr <<
" : " << iter_container <<
") {\n";
398 result.first = ss.str();
// Translates an expression statement.
//
// A statement with exactly one ASSIGN token and two testlist_star_expr
// children is emitted as a C++ assignment into result.first; other
// statements fall back to whatever text the child visitors accumulate in
// sub_node_translation.
//
// ctx: the Expr_stmt parse-tree node being visited.
403 antlrcpp::Any visitExpr_stmt(pyxasmParser::Expr_stmtContext *ctx)
override {
404 if (ctx->ASSIGN().size() == 1 && ctx->testlist_star_expr().size() == 2) {
406 std::stringstream ss;
// The right-hand text goes through replaceMeasureAssignment and then
// replacePythonConstants.
407 const std::string lhs = ctx->testlist_star_expr(0)->getText();
408 std::string rhs = replacePythonConstants(
409 replaceMeasureAssignment(ctx->testlist_star_expr(1)->getText()));
// Comma on the left-hand side: split it into names and assign each one the
// right-hand text followed by the suffix at the same position.
411 if (lhs.find(
",") != std::string::npos) {
416 std::vector<std::string> suffix{
".first",
".second"};
417 auto vars = xacc::split(lhs,
',');
// NOTE(review): suffix holds two entries and is indexed by position, so
// this covers at most two left-hand names - confirm longer lists cannot
// reach this loop.
418 for (
auto [i, var] : qcor::enumerate(vars)) {
// Names already in declared_var_names are assigned without a new
// declaration; any other name is declared with auto.
419 if (xacc::container::contains(declared_var_names, var)) {
420 ss << var <<
" = " << rhs << suffix[i] <<
";\n";
422 ss <<
"auto " << var <<
" = " << rhs << suffix[i] <<
";\n";
// Switch to sub-expression mode, clear the accumulated translation and
// visit the right-hand side.
429 is_processing_sub_expr =
true;
431 sub_node_translation.str(std::string());
434 visitChildren(ctx->testlist_star_expr(1));
// A non-empty translation replaces the right-hand text, after the same two
// rewrites as above.
437 if (!sub_node_translation.str().empty()) {
439 rhs = replacePythonConstants(
440 replaceMeasureAssignment(sub_node_translation.str()));
// Emit the assignment: plain for declared names, with auto otherwise.
443 if (xacc::container::contains(declared_var_names, lhs)) {
444 ss << lhs <<
" = " << rhs <<
"; \n";
447 ss <<
"auto " << lhs <<
" = " << rhs <<
"; \n";
452 result.first = ss.str();
// A right-hand side containing the power operator is handed to the child
// visitors.
453 if (rhs.find(
"**") != std::string::npos) {
455 return visitChildren(ctx);
// Visit the children, take the accumulated translation and clear the buffer.
461 auto child_result = visitChildren(ctx);
462 const auto translated_src = sub_node_translation.str();
463 sub_node_translation.str(std::string());
// When no visitor produced result.first, use the accumulated text followed
// by a semicolon and a newline.
467 if (result.first.empty() && !translated_src.empty()) {
468 result.first = translated_src +
";\n";
// Rewrites a power expression inside result.first into a std::pow call.
//
// Applies only when the node text contains the power operator and the node
// has a factor child; every other node is passed to visitChildren.
//
// context: the Power parse-tree node being visited.
474 antlrcpp::Any visitPower(pyxasmParser::PowerContext *context)
override {
475 if (context->getText().find(
"**") != std::string::npos &&
476 context->factor() !=
nullptr) {
// Replaces every occurrence of search in s with replace, resuming the scan
// just after each inserted replacement.
478 auto replaceAll = [](std::string &s,
const std::string &search,
479 const std::string &replace) {
480 for (std::size_t pos = 0;; pos += replace.length()) {
482 pos = s.find(search, pos);
483 if (pos == std::string::npos)
break;
485 s.erase(pos, search.length());
486 s.insert(pos, replace);
// Base is the atom_expr text, exponent is the factor text.
489 auto factor = context->factor();
490 auto atom_expr = context->atom_expr();
492 "std::pow(" + atom_expr->getText() +
", " + factor->getText() +
")";
// Substitute the original node text inside the already generated source.
493 replaceAll(result.first, context->getText(), s);
496 return visitChildren(context);
// Translates an if statement that has exactly one test into result.first,
// then visits the children.
//
// ctx: the If_stmt parse-tree node being visited.
499 virtual antlrcpp::Any visitIf_stmt(
500 pyxasmParser::If_stmtContext *ctx)
override {
502 if (ctx->test().size() == 1) {
503 std::stringstream ss;
// The condition text goes through replaceMeasureAssignment and then
// replacePythonConstants.
505 << replacePythonConstants(
506 replaceMeasureAssignment(ctx->test(0)->getText()))
508 result.first = ss.str();
511 return visitChildren(ctx);
// Translates a while statement header into the opening line of a C++ while
// loop, stored in result.first. The condition text is copied as is.
//
// ctx: the While_stmt parse-tree node being visited.
514 virtual antlrcpp::Any
515 visitWhile_stmt(pyxasmParser::While_stmtContext *ctx)
override {
516 std::stringstream ss;
517 ss <<
"while (" << ctx->test()->getText() <<
") {\n";
518 result.first = ss.str();
// Appends the node text plus a trailing space to sub_node_translation when
// that text is a name listed in declared_var_names, then visits the
// children.
//
// context: the Testlist_star_expr parse-tree node being visited.
522 virtual antlrcpp::Any visitTestlist_star_expr(
523 pyxasmParser::Testlist_star_exprContext *context)
override {
525 const auto var_name = context->getText();
526 if (xacc::container::contains(declared_var_names, var_name)) {
527 sub_node_translation << var_name <<
" ";
530 return visitChildren(context);
// Appends the text of an augmented-assignment operator node, followed by a
// space, to sub_node_translation.
//
// context: the Augassign parse-tree node being visited.
533 virtual antlrcpp::Any
534 visitAugassign(pyxasmParser::AugassignContext *context)
override {
536 sub_node_translation << context->getText() <<
" ";
// Appends the text of a testlist node, followed by a space, to
// sub_node_translation.
//
// context: the Testlist parse-tree node being visited.
540 virtual antlrcpp::Any
541 visitTestlist(pyxasmParser::TestlistContext *context)
override {
543 sub_node_translation << context->getText() <<
" ";
/// Rewrites Python math constants in an expression to their C++ spelling.
///
/// Every occurrence of `math.pi` and `numpy.pi` is replaced by `M_PI`
/// (previously only the first occurrence of each key was replaced, so an
/// expression such as `math.pi/2 + math.pi` kept its second `math.pi`).
///
/// Declared `static` because no instance state is read; calls made through
/// an object or unqualified from other member functions are unaffected.
///
/// @param in_pyExpr Expression text taken from the Python-like source.
/// @return A copy of `in_pyExpr` with all known constants substituted.
static std::string replacePythonConstants(const std::string &in_pyExpr) {
  // Python spelling -> C++ spelling.
  const std::map<std::string, std::string> REPLACE_MAP{{"math.pi", "M_PI"},
                                                       {"numpy.pi", "M_PI"}};
  std::string newSrc = in_pyExpr;
  for (const auto &[key, value] : REPLACE_MAP) {
    // Resume just past each inserted value so replaced text is never rescanned.
    for (auto pos = newSrc.find(key); pos != std::string::npos;
         pos = newSrc.find(key, pos + value.length())) {
      newSrc.replace(pos, key.length(), value);
    }
  }
  return newSrc;
}
/// Rewrites `Measure` calls in an expression to `quantum::mz`.
///
/// An occurrence of `Measure` is rewritten only when the character right
/// after it is `(` or whitespace; any other occurrence (for example a longer
/// identifier that starts with `Measure`) is left untouched.
///
/// After a rejected occurrence the scan resumes right behind it. Advancing by
/// the length of the replacement instead would jump over a genuine
/// `Measure(` that starts within the next few characters.
///
/// Declared `static` because no instance state is read; calls made through
/// an object or unqualified from other member functions are unaffected.
///
/// @param in_expr Expression text to scan.
/// @return A copy of `in_expr` with every qualifying `Measure` replaced.
static std::string replaceMeasureAssignment(const std::string &in_expr) {
  const std::string search = "Measure";
  const std::string replacement = "quantum::mz";

  std::string rewritten = in_expr;
  std::size_t pos = 0;
  while ((pos = rewritten.find(search, pos)) != std::string::npos) {
    // Indexing at size() is valid and yields '\0', which fails both checks.
    const char next = rewritten[pos + search.length()];
    // isspace requires a value representable as unsigned char.
    const bool looks_like_call =
        next == '(' || std::isspace(static_cast<unsigned char>(next)) != 0;
    if (!looks_like_call) {
      pos += search.length();
      continue;
    }
    rewritten.replace(pos, search.length(), replacement);
    pos += replacement.length();
  }
  return rewritten;
}
// Produces the text used for one call argument.
//
// Switches to sub-expression mode, clears sub_node_translation and visits
// the argument. If the visit produced any translated text, that text is
// returned; otherwise the original text of the argument is returned.
//
// in_argContext: the Argument parse-tree node to rewrite.
600 rewriteFunctionArgument(pyxasmParser::ArgumentContext &in_argContext) {
607 is_processing_sub_expr =
true;
609 sub_node_translation.str(std::string());
612 visitChildren(&in_argContext);
// Prefer the translated form when the child visitors emitted one.
615 if (!sub_node_translation.str().empty()) {
617 return sub_node_translation.str();
// Fall back to the untouched argument text.
620 return in_argContext.getText();