QCOR
pyxasm_visitor.hpp
1 #pragma once
2 #include <regex>
3 
4 #include "IRProvider.hpp"
5 #include "pyxasmBaseVisitor.h"
6 #include "qcor_utils.hpp"
7 #include "qrt.hpp"
8 #include "xacc.hpp"
9 
10 using namespace pyxasm;
11 
// Aliases for instruction names as they may be written in kernel source:
// the key is the spelling found in the input, the value is the name that is
// looked up in its place.
// Declared `inline` (C++17) because this variable is defined in a header:
// without it every translation unit including the header would emit its own
// definition and the program would violate the one-definition rule.
inline std::map<std::string, std::string> common_name_map{
    {"CX", "CNOT"}, {"qcor::exp", "exp_i_theta"}, {"exp", "exp_i_theta"}};
14 
15 using pyxasm_result_type =
16  std::pair<std::string, std::shared_ptr<xacc::Instruction>>;
17 
18 class pyxasm_visitor : public pyxasmBaseVisitor {
19  protected:
20  std::shared_ptr<xacc::IRProvider> provider;
21  // List of buffers in the *context* of this XASM visitor
22  std::vector<std::string> bufferNames;
23  // List of *declared* variables
24  std::vector<std::string> declared_var_names;
25 
26  public:
27  pyxasm_visitor(const std::vector<std::string> &buffers = {},
28  const std::vector<std::string> &local_var_names = {})
29  : provider(xacc::getIRProvider("quantum")),
30  bufferNames(buffers),
31  declared_var_names(local_var_names) {}
32  pyxasm_result_type result;
33  // New var declared (auto type) after visiting this node.
34  std::string new_var;
35  bool in_for_loop = false;
36  // Var to keep track of sub-node rewrite:
37  // e.g., traverse down the AST recursively.
38  std::stringstream sub_node_translation;
39  bool is_processing_sub_expr = false;
40 
41  antlrcpp::Any visitAtom_expr(
42  pyxasmParser::Atom_exprContext *context) override {
43  // std::cout << "Atom_exprContext: " << context->getText() << "\n";
44  // Strategy:
45  // At the top level, we analyze the trailer to determine the
46  // list of function call arguments.
47  // Then, traverse down the arg. node to see if there is a potential rewrite rules
48  // e.g. for arrays (as testlist_comp nodes)
49  // Otherwise, just get the argument text as is.
50  /*
51  atom_expr: (AWAIT)? atom trailer*;
52  atom: ('(' (yield_expr|testlist_comp)? ')' |
53  '[' (testlist_comp)? ']' |
54  '{' (dictorsetmaker)? '}' |
55  NAME | NUMBER | STRING+ | '...' | 'None' | 'True' | 'False');
56  */
57  // Only processes these for sub-expressesions,
58  // e.g. re-entries to this function
59  if (is_processing_sub_expr) {
60  if (context->atom() && context->atom()->OPEN_BRACK() &&
61  context->atom()->CLOSE_BRACK() && context->atom()->testlist_comp()) {
62  // Array type expression:
63  // std::cout << "Array atom expression: "
64  // << context->atom()->testlist_comp()->getText() << "\n";
65  // Use braces
66  sub_node_translation << "{";
67  bool firstElProcessed = false;
68  for (auto &testNode : context->atom()->testlist_comp()->test()) {
69  // std::cout << "Array elem: " << testNode->getText() << "\n";
70  // Add comma if needed (there is a previous element)
71  if (firstElProcessed) {
72  sub_node_translation << ", ";
73  }
74  sub_node_translation << testNode->getText();
75  firstElProcessed = true;
76  }
77  sub_node_translation << "}";
78  return 0;
79  }
80 
81  // We don't have a re-write rule for this one (py::dict)
82  if (context->atom() && context->atom()->OPEN_BRACE() &&
83  context->atom()->CLOSE_BRACE() && context->atom()->dictorsetmaker()) {
84  // Dict:
85  // std::cout << "Dict atom expression: "
86  // << context->atom()->dictorsetmaker()->getText() << "\n";
87  // TODO:
88  return 0;
89  }
90 
91  if (context->atom() && !context->atom()->STRING().empty()) {
92  // Strings:
93  for (auto &strNode : context->atom()->STRING()) {
94  std::string cppStrLiteral = strNode->getText();
95  // Handle Python single-quotes
96  if (cppStrLiteral.front() == '\'' && cppStrLiteral.back() == '\'') {
97  cppStrLiteral.front() = '"';
98  cppStrLiteral.back() = '"';
99  }
100  sub_node_translation << cppStrLiteral;
101  // std::cout << "String expression: " << strNode->getText() << " --> "
102  // << cppStrLiteral << "\n";
103  }
104  return 0;
105  }
106 
107  const auto isSliceOp =
108  [](pyxasmParser::Atom_exprContext *atom_expr_context) -> bool {
109  if (atom_expr_context->trailer().size() == 1) {
110  auto subscriptlist = atom_expr_context->trailer(0)->subscriptlist();
111  if (subscriptlist && subscriptlist->subscript().size() == 1) {
112  auto subscript = subscriptlist->subscript(0);
113  const auto nbTestTerms = subscript->test().size();
114  // Multiple test terms (separated by ':')
115  return (nbTestTerms > 1);
116  }
117  }
118 
119  return false;
120  };
121 
122  // Handle slicing operations (multiple array subscriptions separated by
123  // ':') on a qreg.
124  if (context->atom() &&
125  xacc::container::contains(bufferNames, context->atom()->getText()) &&
126  isSliceOp(context)) {
127  // std::cout << "Slice op: " << context->getText() << "\n";
128  sub_node_translation << context->atom()->getText()
129  << ".extract_range({";
130  auto subscripts =
131  context->trailer(0)->subscriptlist()->subscript(0)->test();
132  assert(subscripts.size() > 1);
133  std::vector<std::string> subscriptTerms;
134  for (auto &test : subscripts) {
135  subscriptTerms.emplace_back(test->getText());
136  }
137 
138  auto sliceOp =
139  context->trailer(0)->subscriptlist()->subscript(0)->sliceop();
140  if (sliceOp && sliceOp->test()) {
141  subscriptTerms.emplace_back(sliceOp->test()->getText());
142  }
143  assert(subscriptTerms.size() == 2 || subscriptTerms.size() == 3);
144 
145  for (int i = 0; i < subscriptTerms.size(); ++i) {
146  // Need to cast to prevent compiler errors,
147  // e.g. when using q.size() which returns an int.
148  sub_node_translation << "static_cast<size_t>(" << subscriptTerms[i]
149  << ")";
150  if (i != subscriptTerms.size() - 1) {
151  sub_node_translation << ", ";
152  }
153  }
154 
155  sub_node_translation << "})";
156 
157  // convert the slice op to initializer list:
158  // std::cout << "Slice Convert: " << context->getText() << " --> "
159  // << sub_node_translation.str() << "\n";
160  return 0;
161  }
162 
163  return 0;
164  }
165 
166  // Handle kernel::ctrl(...), kernel::adjoint(...)
167  if (!context->trailer().empty() &&
168  (context->trailer()[0]->getText() == ".ctrl" ||
169  context->trailer()[0]->getText() == ".adjoint")) {
170  // std::cout << "HELLO: " << context->getText() << "\n";
171  // std::cout << context->trailer()[0]->getText() << "\n";
172  // std::cout << context->atom()->getText() << "\n";
173 
174  // std::cout << context->trailer()[1]->getText() << "\n";
175  // std::cout << context->trailer()[1]->arglist() << "\n";
176  auto arg_list = context->trailer()[1]->arglist();
177 
178  std::stringstream ss;
179  // Remove the first '.' character
180  const std::string methodName = context->trailer()[0]->getText().substr(1);
181  // If this is a *variable*, then using '.' for control/adjoint.
182  // Otherwise, use '::' (global scope kernel names)
183  const std::string separator =
184  (xacc::container::contains(declared_var_names,
185  context->atom()->getText()))
186  ? "."
187  : "::";
188 
189  ss << context->atom()->getText() << separator << methodName
190  << "(parent_kernel";
191  for (int i = 0; i < arg_list->argument().size(); i++) {
192  ss << ", " << rewriteFunctionArgument(*(arg_list->argument(i)));
193  }
194  ss << ");\n";
195 
196  // std::cout << "HELLO SS: " << ss.str() << "\n";
197  result.first = ss.str();
198  return 0;
199  }
200  if (context->atom()->NAME() != nullptr) {
201  auto inst_name = context->atom()->NAME()->getText();
202 
203  if (common_name_map.count(inst_name)) {
204  inst_name = common_name_map[inst_name];
205  }
206 
207  if (xacc::container::contains(provider->getInstructions(), inst_name)) {
208  // Create an instance of the Instruction with the given name
209  auto inst = provider->createInstruction(inst_name, 0);
210 
211  // If it is not composite, look for its bit expressions
212  // and parameter expressions
213  if (!inst->isComposite()) {
214  // Get the number of required bits and parameters
215  auto required_bits = inst->nRequiredBits();
216  auto required_params = inst->getParameters().size();
217 
218  if (!context->trailer().empty()) {
219  auto atom_n_args =
220  context->trailer()[0]->arglist()->argument().size();
221 
222  if (required_bits + required_params != atom_n_args &&
223  inst_name != "Measure") {
224  std::stringstream xx;
225  xx << "Invalid quantum instruction expression. " << inst_name
226  << " requires " << required_bits << " qubit args and "
227  << required_params << " parameter args.";
228  xacc::error(xx.str());
229  }
230 
231  // Get the qubit expresssions
232  std::vector<std::string> buffer_names;
233  for (int i = 0; i < required_bits; i++) {
234  auto bit_expr = context->trailer()[0]->arglist()->argument()[i];
235  auto bit_expr_str = rewriteFunctionArgument(*bit_expr);
236 
237  auto found_bracket = bit_expr_str.find_first_of("[");
238  if (found_bracket != std::string::npos) {
239  auto buffer_name = bit_expr_str.substr(0, found_bracket);
240  auto bit_idx_expr = bit_expr_str.substr(
241  found_bracket + 1,
242  bit_expr_str.length() - found_bracket - 2);
243  buffer_names.push_back(buffer_name);
244  inst->setBitExpression(i, bit_idx_expr);
245  } else {
246  // Indicate this is a qubit(-1) or a qreg(-2)
247  inst->setBitExpression(-1, bit_expr_str);
248  buffer_names.push_back(bit_expr_str);
249  }
250  }
251  inst->setBufferNames(buffer_names);
252 
253  // Get the parameter expressions
254  int counter = 0;
255  for (int i = required_bits; i < atom_n_args; i++) {
256  inst->setParameter(counter,
257  replacePythonConstants(context->trailer()[0]
258  ->arglist()
259  ->argument()[i]
260  ->getText()));
261  counter++;
262  }
263  }
264  result.second = inst;
265  } else {
266  // Composite instructions, e.g. exp_i_theta
267  if (inst_name == "exp_i_theta") {
268  // Expected 3 params:
269  if (context->trailer()[0]->arglist()->argument().size() != 3) {
270  xacc::error(
271  "Invalid number of arguments for the 'exp_i_theta' "
272  "instruction. Expected 3, got " +
273  std::to_string(
274  context->trailer()[0]->arglist()->argument().size()) +
275  ". Please check your input.");
276  }
277 
278  std::stringstream ss;
279  // Delegate to the QRT call directly.
280  ss << "quantum::exp("
281  << context->trailer()[0]->arglist()->argument(0)->getText()
282  << ", "
283  << context->trailer()[0]->arglist()->argument(1)->getText()
284  << ", "
285  << context->trailer()[0]->arglist()->argument(2)->getText()
286  << ");\n";
287  result.first = ss.str();
288  }
289  // Handle potential name collision: user-defined kernel having the
290  // same name as an XACC circuit: e.g. common names such as qft, iqft
291  // Note: these circuits (except exp_i_theta) don't have QRT
292  // equivalents.
293  // Condition: first argument is a qubit register
294  else if (xacc::container::contains(
295  ::quantum::kernels_in_translation_unit, inst_name) ||
296  !context->trailer()[0]->arglist()->argument().empty() &&
297  xacc::container::contains(bufferNames,
298  context->trailer()[0]
299  ->arglist()
300  ->argument(0)
301  ->getText())) {
302  std::stringstream ss;
303  // Use the kernel call with a parent kernel arg.
304  ss << inst_name << "(parent_kernel, ";
305  const auto &argList = context->trailer()[0]->arglist()->argument();
306  for (size_t i = 0; i < argList.size(); ++i) {
307  ss << argList[i]->getText();
308  if (i != argList.size() - 1) {
309  ss << ", ";
310  }
311  }
312  ss << ");\n";
313  result.first = ss.str();
314  } else {
315  xacc::error("Composite instruction '" + inst_name +
316  "' is not currently supported.");
317  }
318  }
319  } else {
320  // This kernel *callable* is not an intrinsic instruction, just
321  // reassemble the call:
322  // Check that the *first* argument is a *qreg* in the current context of
323  // *this* kernel or the function name is a kernel in translation unit.
324  if (xacc::container::contains(::quantum::kernels_in_translation_unit,
325  inst_name) ||
326  (!context->trailer().empty() && context->trailer()[0]->arglist() &&
327  !context->trailer()[0]->arglist()->argument().empty() &&
328  xacc::container::contains(
329  bufferNames,
330  context->trailer()[0]->arglist()->argument(0)->getText()))) {
331  std::stringstream ss;
332  // Use the kernel call with a parent kernel arg.
333  ss << inst_name << "(parent_kernel, ";
334  // TODO: We potentially need to handle *inline* expressions in the
335  // function call.
336  const auto &argList = context->trailer()[0]->arglist()->argument();
337  for (size_t i = 0; i < argList.size(); ++i) {
338  ss << argList[i]->getText();
339  if (i != argList.size() - 1) {
340  ss << ", ";
341  }
342  }
343  ss << ");\n";
344  result.first = ss.str();
345  } else {
346  if (!context->trailer().empty()) {
347  // A classical call-like expression: i.e. not a kernel call:
348  // Just output it *as-is* to the C++ stream.
349  // We can hook more sophisticated code-gen here if required.
350  // std::cout << "Callable: " << context->getText() << "\n";
351  std::stringstream ss;
352 
353  if (context->trailer()[0]->arglist() &&
354  !context->trailer()[0]->arglist()->argument().empty()) {
355  const auto &argList =
356  context->trailer()[0]->arglist()->argument();
357  ss << inst_name << "(";
358  for (size_t i = 0; i < argList.size(); ++i) {
359  ss << rewriteFunctionArgument(*(argList[i]));
360  if (i != argList.size() - 1) {
361  ss << ", ";
362  }
363  }
364  ss << ");\n";
365  } else {
366  ss << context->getText() << ";\n";
367  }
368  result.first = ss.str();
369  }
370  }
371  }
372  }
373  return 0;
374  }
375 
376  antlrcpp::Any visitFor_stmt(pyxasmParser::For_stmtContext *context) override {
377  // Rewrite:
378  // Python: "for <var> in <expr>:"
379  // C++: for (auto var: <expr>) {}
380  // Note: we add range(int) as a C++ function to support this common pattern.
381  // or
382  // Python: "for <idx>,<var> in enumerate(<listvar>):"
383  // C++: for (auto [idx, var] : enumerate(listvar))
384  auto iter_container = context->testlist()->test()[0]->getText();
385  std::string counter_expr = context->exprlist()->expr()[0]->getText();
386  // Add the for loop variable to the tracking list as well.
387  new_var = counter_expr;
388  if (context->exprlist()->expr().size() > 1) {
389  counter_expr = "[" + counter_expr;
390  for (int i = 1; i < context->exprlist()->expr().size(); i++) {
391  counter_expr += ", " + context->exprlist()->expr()[i]->getText();
392  }
393  counter_expr += "]";
394  }
395 
396  std::stringstream ss;
397  ss << "for (auto " << counter_expr << " : " << iter_container << ") {\n";
398  result.first = ss.str();
399  in_for_loop = true;
400  return 0;
401  }
402 
403  antlrcpp::Any visitExpr_stmt(pyxasmParser::Expr_stmtContext *ctx) override {
404  if (ctx->ASSIGN().size() == 1 && ctx->testlist_star_expr().size() == 2) {
405  // Handle simple assignment: a = expr
406  std::stringstream ss;
407  const std::string lhs = ctx->testlist_star_expr(0)->getText();
408  std::string rhs = replacePythonConstants(
409  replaceMeasureAssignment(ctx->testlist_star_expr(1)->getText()));
410 
411  if (lhs.find(",") != std::string::npos) {
412  // this is
413  // var1, var2, ... = some_tuple_thing
414  // We only support var1, var2 = ... for now
415  // where ... is a pair-like object
416  std::vector<std::string> suffix{".first", ".second"};
417  auto vars = xacc::split(lhs, ',');
418  for (auto [i, var] : qcor::enumerate(vars)) {
419  if (xacc::container::contains(declared_var_names, var)) {
420  ss << var << " = " << rhs << suffix[i] << ";\n";
421  } else {
422  ss << "auto " << var << " = " << rhs << suffix[i] << ";\n";
423  new_var = lhs;
424  }
425  }
426  } else {
427  // Strategy: try to traverse the rhs to see if there is a possible rewrite;
428  // Otherwise, use the text as is.
429  is_processing_sub_expr = true;
430  // clear the sub_node_translation
431  sub_node_translation.str(std::string());
432 
433  // visit arg sub-node:
434  visitChildren(ctx->testlist_star_expr(1));
435 
436  // Check if there is a rewrite:
437  if (!sub_node_translation.str().empty()) {
438  // Update RHS
439  rhs = replacePythonConstants(
440  replaceMeasureAssignment(sub_node_translation.str()));
441  }
442 
443  if (xacc::container::contains(declared_var_names, lhs)) {
444  ss << lhs << " = " << rhs << "; \n";
445  } else {
446  // New variable: need to add *auto*
447  ss << "auto " << lhs << " = " << rhs << "; \n";
448  new_var = lhs;
449  }
450  }
451 
452  result.first = ss.str();
453  if (rhs.find("**") != std::string::npos) {
454  // keep processing
455  return visitChildren(ctx);
456  } else {
457  return 0;
458  }
459  } else {
460  // Visit child node:
461  auto child_result = visitChildren(ctx);
462  const auto translated_src = sub_node_translation.str();
463  sub_node_translation.str(std::string());
464  // If no child nodes, perform the codegen (result.first is not set)
465  // but just appending the incremental translation collector;
466  // return the collected C++ statement.
467  if (result.first.empty() && !translated_src.empty()) {
468  result.first = translated_src + ";\n";
469  }
470  return child_result;
471  }
472  }
473 
474  antlrcpp::Any visitPower(pyxasmParser::PowerContext *context) override {
475  if (context->getText().find("**") != std::string::npos &&
476  context->factor() != nullptr) {
477  // Here we handle x**y from parent assignment expression
478  auto replaceAll = [](std::string &s, const std::string &search,
479  const std::string &replace) {
480  for (std::size_t pos = 0;; pos += replace.length()) {
481  // Locate the substring to replace
482  pos = s.find(search, pos);
483  if (pos == std::string::npos) break;
484  // Replace by erasing and inserting
485  s.erase(pos, search.length());
486  s.insert(pos, replace);
487  }
488  };
489  auto factor = context->factor();
490  auto atom_expr = context->atom_expr();
491  std::string s =
492  "std::pow(" + atom_expr->getText() + ", " + factor->getText() + ")";
493  replaceAll(result.first, context->getText(), s);
494  return 0;
495  }
496  return visitChildren(context);
497  }
498 
499  virtual antlrcpp::Any visitIf_stmt(
500  pyxasmParser::If_stmtContext *ctx) override {
501  // Only support single clause atm
502  if (ctx->test().size() == 1) {
503  std::stringstream ss;
504  ss << "if ("
505  << replacePythonConstants(
506  replaceMeasureAssignment(ctx->test(0)->getText()))
507  << ") {\n";
508  result.first = ss.str();
509  return 0;
510  }
511  return visitChildren(ctx);
512  }
513 
514  virtual antlrcpp::Any
515  visitWhile_stmt(pyxasmParser::While_stmtContext *ctx) override {
516  std::stringstream ss;
517  ss << "while (" << ctx->test()->getText() << ") {\n";
518  result.first = ss.str();
519  return 0;
520  }
521 
522  virtual antlrcpp::Any visitTestlist_star_expr(
523  pyxasmParser::Testlist_star_exprContext *context) override {
524  // std::cout << "Testlist_star_exprContext:" << context->getText() << "\n";
525  const auto var_name = context->getText();
526  if (xacc::container::contains(declared_var_names, var_name)) {
527  sub_node_translation << var_name << " ";
528  return 0;
529  }
530  return visitChildren(context);
531  }
532 
533  virtual antlrcpp::Any
534  visitAugassign(pyxasmParser::AugassignContext *context) override {
535  // std::cout << "Augassign:" << context->getText() << "\n";
536  sub_node_translation << context->getText() << " ";
537  return 0;
538  }
539 
540  virtual antlrcpp::Any
541  visitTestlist(pyxasmParser::TestlistContext *context) override {
542  // std::cout << "visitTestlist:" << context->getText() << "\n";
543  sub_node_translation << context->getText() << " ";
544  return 0;
545  }
546 
547  private:
// Replaces common Python constants ('math.pi', 'numpy.pi') with their C++
// counterpart (M_PI).
// Note: the library names have been resolved to their original names.
// Every occurrence is replaced, not only the first one, so an expression
// such as "math.pi + math.pi" is translated completely.
// The helper does not use any visitor state, hence `static`.
//   in_pyExpr: Python expression text.
//   returns:   the expression with all known constants substituted.
static std::string replacePythonConstants(const std::string &in_pyExpr) {
  // All keywords to be replaced; built once.
  static const std::map<std::string, std::string> REPLACE_MAP{
      {"math.pi", "M_PI"}, {"numpy.pi", "M_PI"}};
  std::string newSrc = in_pyExpr;
  for (const auto &[key, value] : REPLACE_MAP) {
    // Resume after each inserted value so that it is never rescanned.
    for (auto pos = newSrc.find(key); pos != std::string::npos;
         pos = newSrc.find(key, pos + value.length())) {
      newSrc.replace(pos, key.length(), value);
    }
  }
  return newSrc;
}
563 
// Rewrites Measure calls used as values (assigned to a variable or tested in
// a condition) into quantum::mz calls.
// An occurrence of "Measure" is replaced only when it is directly followed
// by '(' or by whitespace; other occurrences (e.g. "Measurement", or
// "Measure" at the very end of the text) are left untouched.
// After a rejected occurrence the scan resumes right behind it, so a valid
// call following closely is still found.
// The helper does not use any visitor state, hence `static`.
//   in_expr: expression text.
//   returns: the expression with the qualifying occurrences replaced.
static std::string replaceMeasureAssignment(const std::string &in_expr) {
  const std::string search = "Measure";
  const std::string replacement = "quantum::mz";

  std::string translated = in_expr;
  std::size_t pos = translated.find(search);
  while (pos != std::string::npos) {
    const std::size_t after = pos + search.length();
    const char next = after < translated.length() ? translated[after] : '\0';
    // isspace requires a value representable as unsigned char.
    const bool is_call =
        next == '(' || isspace(static_cast<unsigned char>(next)) != 0;
    if (is_call) {
      translated.replace(pos, search.length(), replacement);
      pos = translated.find(search, pos + replacement.length());
    } else {
      pos = translated.find(search, after);
    }
  }
  return translated;
}
592 
593  // A helper to rewrite function argument by traversing the node to see
594  // if there is a potential rewrite.
595  // Use case: inline expressions
596  // e.g. X(q[0:3])
597  // slicing of the qreg 'q' then call the broadcast X op.
598  // i.e., we need to rewrite the arg to q.extract_range({0, 3}).
599  std::string
600  rewriteFunctionArgument(pyxasmParser::ArgumentContext &in_argContext) {
601  // Strategy: try to traverse the argument context to see if there is a
602  // possible rewrite; i.e. it may be another atom_expression that we have a
603  // handler for. Otherwise, use the text as is.
604  // We need this flag to prevent parsing quantum instructions as sub-expressions.
605  // e.g. QCOR operators (X, Y, Z) in an observable definition shouldn't be
606  // processed as instructions.
607  is_processing_sub_expr = true;
608  // clear the sub_node_translation
609  sub_node_translation.str(std::string());
610 
611  // visit arg sub-node:
612  visitChildren(&in_argContext);
613 
614  // Check if there is a rewrite:
615  if (!sub_node_translation.str().empty()) {
616  // Update RHS
617  return sub_node_translation.str();
618  }
619  // Returns the string as is
620  return in_argContext.getText();
621  }
622 };
pyxasm_visitor
Definition: pyxasm_visitor.hpp:18