QCOR
xasm_single_visitor.hpp
1 #pragma once
#include <cctype>
#include <map>
#include <memory>
#include <regex>
#include <sstream>
#include <stack>
#include <string>
#include <vector>

#include "IRProvider.hpp"
#include "qrt.hpp"
#include "xacc.hpp"
#include "xasm_singleVisitor.h"
8 
9 using namespace xasm;
10 
/// Maps instruction names accepted in XASM source to the names registered
/// with the XACC "quantum" IR provider (consulted by xasm_single_visitor
/// before looking an instruction up in the provider).
///
/// Declared `inline` (C++17) because this lives in a header: a plain
/// namespace-scope definition would be emitted by every translation unit
/// that includes the header and fail to link with multiple definitions.
inline std::map<std::string, std::string> common_name_map{
    {"CX", "CNOT"}, {"qcor::exp", "exp_i_theta"}, {"exp", "exp_i_theta"}};
13 using xasm_single_result_type =
14  std::pair<std::string, std::shared_ptr<xacc::Instruction>>;
15 
16 class xasm_single_visitor : public xasm::xasm_singleVisitor {
17  protected:
18  int n_cached_execs = 0;
19 
20  public:
21  xasm_single_result_type result;
22 
23  antlrcpp::Any visitStatement(
24  xasm_singleParser::StatementContext *context) override {
25  // should only have 1 child, if it is qinst
26  // we expect a xacc Instruction return type
27  // if cinst we expect a Cinst
28  return visitChildren(context);
29  }
30 
31  antlrcpp::Any visitQinst(xasm_singleParser::QinstContext *context) override {
32  if (!xacc::isInitialized()) {
33  xacc::Initialize();
34  }
35 
36  // if not in instruction registry, then forward to classical instructions
37  auto inst_name = context->inst_name->getText();
38  auto provider = xacc::getIRProvider("quantum");
39 
40  if (common_name_map.count(inst_name)) {
41  inst_name = common_name_map[inst_name];
42  }
43 
44  if (xacc::container::contains(provider->getInstructions(), inst_name)) {
45  // We don't really care about Instruction::bits(), qrt_mapper
46  // will look for bit expressions and use those, so just set
47  // everything as a string...
48 
49  // Create an instance of the Instruction with the given name
50  auto inst = provider->createInstruction(inst_name, 0);
51 
52  // If it is not composite, look for its bit expressions
53  // and parameter expressions
54  if (!inst->isComposite()) {
55  // Get the number of required bits and parameters
56  auto required_bits = inst->nRequiredBits();
57  auto required_params = inst->getParameters().size();
58 
59  if (required_bits + required_params !=
60  context->explist()->exp().size() &&
61  inst_name != "Measure") {
62  std::stringstream xx;
63  xx << "Invalid quantum instruction expression. " << inst_name
64  << " requires " << required_bits << " qubit args and "
65  << required_params << " parameter args.";
66  xacc::error(xx.str());
67  }
68 
69  // Get the qubit expresssions
70  std::vector<std::string> buffer_names;
71  int count = 1;
72  for (int i = 0; i < required_bits; i++) {
73  auto bit_expr = context->explist()->exp(i);
74  auto bit_expr_str = bit_expr->getText();
75 
76  auto found_bracket = bit_expr_str.find_first_of("[");
77  if (found_bracket != std::string::npos) {
78  auto buffer_name = bit_expr_str.substr(0, found_bracket);
79  auto bit_idx_expr = bit_expr_str.substr(
80  found_bracket + 1, bit_expr_str.length() - found_bracket - 2);
81 
82  buffer_names.push_back(buffer_name);
83  inst->setBitExpression(i, bit_idx_expr);
84  } else {
85  // Indicate this is a qubit(-1) or a qreg(-2)
86  inst->setBitExpression(-1*count, bit_expr_str);
87  buffer_names.push_back(bit_expr_str);
88  }
89  count++;
90  }
91 
92  inst->setBufferNames(buffer_names);
93 
94  // Get the parameter expressions
95  int counter = 0;
96  for (int i = required_bits; i < context->explist()->exp().size(); i++) {
97  inst->setParameter(counter, context->explist()->exp(i)->getText());
98  counter++;
99  }
100  } else {
101  // I don't want to use xasm circuit gen any more...
102  // So use it as a fallback, but first look for previous
103  if (xacc::container::contains(quantum::kernels_in_translation_unit,
104  context->inst_name->getText())) {
105  // If this is a previously seen quantum kernel
106  // then we want to update its signature to add the
107  // parent CompositeInstruction argument
108  std::stringstream ss;
109  for (auto c : context->children) {
110  if (c->getText() == "(") {
111  ss << c->getText() << "parent_kernel, ";
112 
113  } else {
114  ss << c->getText() << " ";
115  }
116  }
117 
118  result.first = ss.str() + "\n";
119  return 0;
120  } else {
121  // this is something like exp_i_theta(q,...);
122  auto comp_inst = xacc::ir::asComposite(inst);
123  inst->setBufferNames({context->explist()->exp(0)->getText()});
124  for (int i = 1; i < context->explist()->exp().size(); i++) {
125  comp_inst->addArgument(context->explist()->exp(i)->getText(), "");
126  }
127  }
128  }
129 
130  result.second = inst;
131  } else {
132  std::stringstream ss;
133 
134  if (xacc::container::contains(quantum::kernels_in_translation_unit,
135  context->inst_name->getText())) {
136  // If this is a previously seen quantum kernel
137  // then we want to update its signature to add the
138  // parent CompositeInstruction argument
139 
140  for (auto c : context->children) {
141  if (c->getText() == "(") {
142  ss << c->getText() << "parent_kernel, ";
143 
144  } else if (c->getText().find("qalloc") != std::string::npos) {
145  // Inline qalloc used in a kernel call:
146  // std::cout << "Qalloc: " << c->getText() << "\n";
147  std::string arg_str = c->getText();
148  const std::string qalloc_name = "qalloc";
149  auto qalloc_pos = arg_str.find(qalloc_name);
150  // Handle multiple temporary qalloc in a kernel call:
151  while (qalloc_pos != std::string::npos) {
152  // Matching '(' ')' to make sure we capture the content of the
153  // qalloc call.
154  std::stack<char> balance_matcher;
155  const auto open_pos =
156  arg_str.find_first_of("(", qalloc_pos);
157  if (open_pos == std::string::npos) {
158  xacc::error("Invalid Syntax: " + c->getText());
159  }
160  for (int i = open_pos; i < arg_str.size(); ++i) {
161  if (arg_str[i] == '(') {
162  balance_matcher.push('(');
163  }
164  if (arg_str[i] == ')') {
165  balance_matcher.pop();
166  }
167 
168  if (balance_matcher.empty()) {
169  arg_str.insert(i, ", quantum::getAncillaQubitAllocator()");
170  break;
171  }
172  }
173 
174  if (!balance_matcher.empty()) {
175  xacc::error("Invalid Syntax: " + c->getText());
176  }
177 
178  // Find the next one if any:
179  qalloc_pos = arg_str.find(qalloc_name, qalloc_pos + qalloc_name.size());
180  }
181  // Append the new arg string
182  ss << arg_str;
183  } else {
184  ss << c->getText() << " ";
185  }
186  }
187  } else {
188  for (auto c : context->children) {
189  ss << c->getText() << " ";
190  }
191  }
192  result.first = ss.str() + "\n";
193  n_cached_execs++;
194  }
195 
196  return 0;
197  }
198 
199  antlrcpp::Any visitCinst(xasm_singleParser::CinstContext *context) override {
200  // Strategy here is simple, we just want to
201  // preserve all classical code statements in
202  // the original quantum kernel
203  std::stringstream ss;
204 
205  if (context->getText().find("::adjoint") != std::string::npos) {
206  for (auto c : context->children) {
207  if (c->getText() == "(") {
208  ss << c->getText() << "parent_kernel, ";
209 
210  } else {
211  ss << c->getText() << " ";
212  }
213  }
214  } else if (context->getText().find("::ctrl") != std::string::npos ||
215  context->getText().find(".ctrl") != std::string::npos) {
216  for (auto c : context->children) {
217  if (c->getText() == "(") {
218  ss << c->getText() << "parent_kernel, ";
219 
220  } else {
221  ss << c->getText() << " ";
222  }
223  }
224  } else if (context->getText().find("Measure") != std::string::npos) {
225  // Found measure in a classical instruction.
226  // std::cout << "FOUND MEAS: " << context->getText() << "\n";
227  // To be extra careful, we use search and replace to handle edge case
228  // whereby `!Measure` is considered as 1 token.
229  const auto replaceAll = [](std::string &s, const std::string &search,
230  const std::string &replace) {
231  for (size_t pos = 0;; pos += replace.length()) {
232  pos = s.find(search, pos);
233  if (pos == std::string::npos) {
234  break;
235  }
236  if ((s.size() > pos + search.size()) &&
237  // If "Measure" is not followed by a space or '(',
238  // i.e. not having a function call signature,
239  // we don't replace.
240  // Not space **and** not '('
241  (!isspace(s[pos + search.length()]) &&
242  (s[pos + search.length()] != '('))) {
243  continue;
244  }
245  s.erase(pos, search.length());
246  s.insert(pos, replace);
247  }
248  };
249  for (auto c : context->children) {
250  auto origText = c->getText();
251  replaceAll(origText, "Measure", " quantum::mz");
252  ss << origText << " ";
253  }
254  } else {
255  if (context->var_value &&
256  context->var_value->getText().find("qalloc") != std::string::npos) {
257  // std::cout << "Qalloc encountered\n";
258  std::stringstream qalloc_ss;
259  for (auto c : context->children) {
260  qalloc_ss << c->getText() << " ";
261  }
262  std::string qalloc_call = qalloc_ss.str();
263  // std::cout << qalloc_call << "\n";
264  const auto close_pos = qalloc_call.find_last_of(")");
265  qalloc_call.insert(close_pos, ", quantum::getAncillaQubitAllocator()");
266  // std::cout << "After: " << qalloc_call << "\n";
267  ss << qalloc_call;
268  } else {
269  for (auto c : context->children) {
270  ss << c->getText() << " ";
271  }
272  }
273  }
274 
275  result.first = ss.str() + "\n";
276  return 0;
277  }
278 
279  antlrcpp::Any visitLine(xasm_singleParser::LineContext *context) override {
280  return 0;
281  }
282 
283  antlrcpp::Any visitComment(
284  xasm_singleParser::CommentContext *context) override {
285  return 0;
286  }
287  antlrcpp::Any visitCompare(
288  xasm_singleParser::CompareContext *context) override {
289  return 0;
290  }
291 
292  antlrcpp::Any visitCpp_type(
293  xasm_singleParser::Cpp_typeContext *context) override {
294  return 0;
295  }
296 
297  antlrcpp::Any visitExplist(
298  xasm_singleParser::ExplistContext *context) override {
299  return 0;
300  }
301 
302  antlrcpp::Any visitExp(xasm_singleParser::ExpContext *context) override {
303  return 0;
304  }
305 
306  antlrcpp::Any visitUnaryop(
307  xasm_singleParser::UnaryopContext *context) override {
308  return 0;
309  }
310 
311  antlrcpp::Any visitId(xasm_singleParser::IdContext *context) override {
312  return 0;
313  }
314 
315  antlrcpp::Any visitReal(xasm_singleParser::RealContext *context) override {
316  return 0;
317  }
318 
319  antlrcpp::Any visitString(
320  xasm_singleParser::StringContext *context) override {
321  return 0;
322  }
323 };
xasm_single_visitor
Definition: xasm_single_visitor.hpp:16