QCOR
llvm_ir_visitors.hpp
1 #pragma once
2 #include "llvm/ADT/StringRef.h"
3 #include "llvm/IR/Constants.h"
4 #include "llvm/IR/DerivedTypes.h"
5 #include "llvm/IR/IRBuilder.h"
6 #include "llvm/IR/InstVisitor.h"
7 #include "llvm/IR/InstrTypes.h"
8 #include "llvm/IR/Instructions.h"
9 
10 #include "xacc.hpp"
11 #include <Instruction.hpp>
12 #include <Utils.hpp>
13 #include <cxxabi.h>
14 
15 using namespace llvm;
16 
17 namespace qcor {
18 
// Maps qcor quantum-runtime (qrt) gate function names, i.e. the `<gate>` in
// `quantum::<gate>(...)`, to the corresponding XACC instruction names.
//
// This definition lives in a header, so it is declared `static` (internal
// linkage): a plain namespace-scope definition has external linkage and
// produces duplicate-symbol link errors as soon as two translation units
// include this header. (With C++17 guaranteed, an `inline` variable would
// give a single shared instance instead.) It is intentionally not `const`
// because other code looks entries up with `operator[]`.
static std::map<std::string, std::string> qrt_to_xacc{
    {"h", "H"},       {"rz", "Rz"},         {"ry", "Ry"},   {"rx", "Rx"},
    {"x", "X"},       {"y", "Y"},           {"z", "Z"},     {"s", "S"},
    {"t", "T"},       {"sdg", "Sdg"},       {"tdg", "Tdg"}, {"cy", "CY"},
    {"cz", "CZ"},     {"swap", "Swap"},     {"crz", "CRZ"}, {"ch", "CH"},
    {"cphase", "CPhase"}, {"i", "I"},       {"u", "U"},     {"u1", "U1"},
    {"cnot", "CNOT"}, {"mz", "Measure"}};
26 
27 
29  : public InstVisitor<FindFunctionVariableStoreInsts> {
30 protected:
31  std::vector<std::string> function_variable_names;
32 
33 public:
34  std::map<std::string, StoreInst *> stores;
35 
36  FindFunctionVariableStoreInsts(std::vector<std::string> &func_var_names)
37  : function_variable_names(func_var_names) {}
38 
39  void visitStoreInst(StoreInst &store) {
40  auto name = store.getOperand(1)->getName().str();
41  auto tmp_name = xacc::split(name, '.')[0];
42  if (xacc::container::contains(function_variable_names, tmp_name)) {
43  stores.insert({name, &store});
44  }
45  }
46 };
47 
48 class LLVM_IR_To_XACC : public InstVisitor<LLVM_IR_To_XACC> {
49 protected:
50  std::shared_ptr<xacc::IRProvider> provider;
51  std::vector<std::size_t> seen_qbit_idxs;
52  std::vector<std::string> buffer_names;
53  std::set<std::string> unique_buffer_names;
54  std::map<std::string, std::string> new_to_old_qreg_names;
55 
56 public:
57  std::shared_ptr<xacc::CompositeInstruction> composite;
58 
59  LLVM_IR_To_XACC(std::map<std::string, std::string> &m)
60  : new_to_old_qreg_names(m) {
61  provider = xacc::getIRProvider("quantum");
62  composite = provider->createComposite("tmp");
63  }
64 
65  auto demangle(const char *name) {
66  int status = -1;
67  std::unique_ptr<char, void (*)(void *)> res{
68  abi::__cxa_demangle(name, NULL, NULL, &status), std::free};
69  return (status == 0) ? res.get() : std::string(name);
70  }
71 
72  void visitCallInst(CallInst &call) {
73  auto f = call.getCalledFunction();
74  if (f && demangle(f->getName().str().c_str()).find("qreg::operator[]") !=
75  std::string::npos) {
76  if (auto *const_int = dyn_cast<ConstantInt>(call.getOperand(2))) {
77  auto bit_idx = const_int->getValue().getLimitedValue();
78  seen_qbit_idxs.push_back(bit_idx);
79  }
80  auto seen_buf_name = call.getOperand(1)->getName().str();
81  if (new_to_old_qreg_names.count(seen_buf_name)) {
82  seen_buf_name = new_to_old_qreg_names[seen_buf_name];
83  }
84  buffer_names.push_back(seen_buf_name);
85  unique_buffer_names.insert(seen_buf_name);
86  }
87  }
88 
89  void visitInvokeInst(InvokeInst &invoke) {
90  auto f = invoke.getCalledFunction();
91  if (f && demangle(f->getName().str().c_str()).find("qreg::operator[]") !=
92  std::string::npos) {
93  if (auto *const_int = dyn_cast<ConstantInt>(invoke.getOperand(2))) {
94  auto bit_idx = const_int->getValue().getLimitedValue();
95  seen_qbit_idxs.push_back(bit_idx);
96  }
97 
98  auto seen_buf_name = invoke.getOperand(1)->getName().str();
99  if (new_to_old_qreg_names.count(seen_buf_name)) {
100  seen_buf_name = new_to_old_qreg_names[seen_buf_name];
101  }
102  buffer_names.push_back(seen_buf_name);
103  unique_buffer_names.insert(seen_buf_name);
104  } else if (f && demangle(f->getName().str().c_str()).find("quantum::") !=
105  std::string::npos) {
106 
107  auto qrt_call_str = demangle(f->getName().str().c_str());
108  auto split = xacc::split(qrt_call_str, ':');
109  auto qrt_name = split[2].substr(0, split[2].find_first_of("("));
110 
111  if (qcor::qrt_to_xacc.count(qrt_name)) {
112  auto xacc_name = qcor::qrt_to_xacc[qrt_name];
113  auto inst = provider->createInstruction(xacc_name, seen_qbit_idxs);
114  inst->setBufferNames(buffer_names);
115 
116  if (inst->nParameters() > 0) {
117  xacc::InstructionParameter p;
118  if (auto constant_double =
119  dyn_cast<ConstantFP>(invoke.getOperand(1))) {
120  errs() << "Can get the double too "
121  << constant_double->getValueAPF().convertToDouble() << "\n";
122  p = constant_double->getValueAPF().convertToDouble();
123  } else {
124  auto prev_node = invoke.getPrevNode();
125  if (auto load = dyn_cast<LoadInst>(prev_node)) {
126 
127  // this was loading the parameter,
128  // lets get the name as a string
129  auto param_str = load->getOperand(0)->getName().str();
130  param_str = xacc::split(param_str, '.')[0];
131  errs() << "HELLO WORLD: " << param_str << "\n";
132  p = param_str;
133  composite->addVariable(param_str);
134  }
135  }
136 
137  inst->setParameter(0, p);
138  // exit(0);
139  }
140  seen_qbit_idxs.clear();
141  buffer_names.clear();
142  composite->addInstruction(inst);
143  }
144  }
145  }
146 };
147 
148 class XACC_To_LLVM_IR : public InstVisitor<XACC_To_LLVM_IR> {
149 protected:
150  bool has_run_once = false;
151  xacc::CompositeInstruction *program;
152  Module *module;
153  Function *simple_one_qbit;
154  LLVMContext &context;
155 
156  std::map<std::string, StoreInst *> &variable_store_insts;
157 
158 public:
159  BasicBlock *basic_block;
160  BasicBlock *execution_block;
161 
162  XACC_To_LLVM_IR(Module *mod, std::map<std::string, StoreInst *> &vsi,
163  xacc::CompositeInstruction *c)
164  : program(c), module(mod), context(mod->getContext()),
165  variable_store_insts(vsi) {}
166  auto demangle(const char *name) {
167  int status = -1;
168  std::unique_ptr<char, void (*)(void *)> res{
169  abi::__cxa_demangle(name, NULL, NULL, &status), std::free};
170  return (status == 0) ? res.get() : std::string(name);
171  }
172  void visitBasicBlock(BasicBlock &bb) {
173  // we are looking for if (__execute) block, so
174  // looking for __execute load inst, should be second in the block
175  if (bb.getInstList().size() > 1) {
176 
177  auto inst_iter = bb.getInstList().begin();
178  inst_iter++;
179  if (isa<LoadInst>(*inst_iter) &&
180  demangle(dyn_cast<LoadInst>(&*inst_iter)
181  ->getOperand(0)
182  ->getName()
183  .str()
184  .c_str()) == "xacc::internal_compiler::__execute") {
185  execution_block = &bb;
186  Instruction *first_inst = &*bb.getInstList().begin();
187  if (auto call = dyn_cast<CallInst>(first_inst)) {
188  auto f = call->getCalledFunction();
189  if (f &&
190  demangle(f->getName().str().c_str())
191  .find("std::pair<std::__cxx11::basic_string<char, "
192  "std::char_traits<char>, std::allocator<char> >, "
193  "unsigned long>::~pair()") != std::string::npos) {
194  call->eraseFromParent();
195  }
196  }
197  return;
198  }
199  }
200  }
201 
202  void visitInvokeInst(InvokeInst &invoke) {
203  auto f = invoke.getCalledFunction();
204  if (f && demangle(f->getName().str().c_str()).find("quantum::") !=
205  std::string::npos) {
206  auto qrt_call_str = demangle(f->getName().str().c_str());
207  auto split = xacc::split(qrt_call_str, ':');
208  auto qrt_name = split[2].substr(0, split[2].find_first_of("("));
209  if (qrt_to_xacc.count(qrt_name)) {
210  if (!has_run_once) {
211  auto normal_next = invoke.getNormalDest();
212  auto except_next = invoke.getUnwindDest();
213  Instruction *last_node = &invoke;
214 
215  Function *one_qubit = nullptr;
216  Function *one_qubit_param = nullptr;
217  Function *two_qubit = nullptr;
218  // this is our first quantum call...
219  for (int i = program->nInstructions() - 1; i >= 0; i--) {
220  auto inst = program->getInstruction(i);
221  // create call to qrt internal simple* call...
222  int n_bits = inst->nRequiredBits(), n_params = inst->nParameters();
223 
224  if (inst->name() == "Measure") {
225  n_params = 0;
226  }
227 
228  IRBuilder<> builder(last_node->getParent());
229 
230  // create the gate name string
231  Constant *gate_name = builder.CreateGlobalStringPtr(inst->name());
232  // create the buffer register name string
233  Constant *buf =
234  builder.CreateGlobalStringPtr(inst->getBufferNames()[0]);
235 
236  if (n_bits == 1) {
237  // If this is the first iteration, grab
238  // the invoke inst's previous node and erase it
239  // it corresponds to the qreg[IDX] call, we don't need it
240  if (i == program->nInstructions() - 1) {
241  last_node->getPrevNode()->eraseFromParent();
242  }
243 
244  std::vector<Type *> arg_types_vec{
245  gate_name->getType(), gate_name->getType(),
246  FunctionType::getInt64Ty(context)};
247  std::vector<Value *> args_vec{
248  gate_name, buf,
249  ConstantInt::get(IntegerType::getInt64Ty(context),
250  APInt(64, inst->bits()[0]))};
251  if (n_params > 0) {
252  // rotation gate
253  for (auto &p : inst->getParameters()) {
254 
255  // add to the arg types
256  arg_types_vec.push_back(FunctionType::getDoubleTy(context));
257 
258  if (p.isVariable()) {
259 
260  // this parameter string should correspond to
261  // and argument on the function
262  // TODO Add a LoadInst to load VARNAME.addr
263  // then add that return value to args.
264  // %0 = load double, double* %angle.addr, align 8
265  // LoadInst(Type *Ty, Value *Ptr, const Twine &NameStr,
266  // bool isVolatile,
267  // Instruction *InsertBefore = nullptr);
268  auto store_inst_key = p.toString() + ".addr";
269  auto load = new LoadInst(
270  FunctionType::getDoubleTy(context),
271  variable_store_insts[store_inst_key]->getOperand(1),
272  "tmp_" + p.toString(), false, last_node);
273  auto load_value = load->getOperand(0);
274  args_vec.push_back(load);
275  } else {
276  args_vec.push_back(ConstantFP::get(
277  FunctionType::getDoubleTy(context),
278  APFloat(xacc::InstructionParameterToDouble(p))));
279  }
280  }
281  }
282 
283  ArrayRef<Type *> arg_types(arg_types_vec);
284  ArrayRef<Value *> args(args_vec);
285 
286  FunctionType *ftype = FunctionType::get(
287  FunctionType::getVoidTy(context), arg_types, false);
288 
289  Instruction *new_inst = nullptr;
290 
291  if (n_params > 0) {
292  if (!one_qubit_param) {
293  one_qubit_param = Function::Create(
294  ftype, Function::ExternalLinkage,
295  "_ZN4xacc17internal_compiler38simplified_qrt_call_one_"
296  "qbit_one_paramEPKcS2_md",
297  module);
298  }
299 
300  // Create the call inst and add it to the Function
301  if (i == program->nInstructions() - 1) {
302  new_inst =
303  InvokeInst::Create(ftype, one_qubit_param, normal_next,
304  except_next, args, "", last_node);
305  } else {
306  new_inst = CallInst::Create(ftype, one_qubit_param, args, "",
307  last_node);
308  }
309  } else {
310  if (!one_qubit) {
311  one_qubit = Function::Create(
312  ftype, Function::ExternalLinkage,
313  "_ZN4xacc17internal_compiler28simplified_qrt_call_"
314  "one_qbitEPKcS2_m",
315  module);
316  }
317 
318  // Create the call inst and add it to the Function
319  if (i == program->nInstructions() - 1) {
320  new_inst =
321  InvokeInst::Create(ftype, one_qubit, normal_next,
322  except_next, args, "", last_node);
323  } else {
324  new_inst =
325  CallInst::Create(ftype, one_qubit, args, "", last_node);
326  }
327  }
328 
329  last_node = new_inst;
330 
331  if (i == program->nInstructions() - 1) {
332  // save this basic block...
333  basic_block = invoke.getParent();
334  invoke.eraseFromParent();
335  }
336 
337  } else {
338  // 2 qubit gates
339  if (n_params > 0) {
340  // TODO FIXME add 2 qubit gates with param
341  } else {
342  // _ZN4xacc17internal_compiler29simplified_qrt_call_two_qbitsEPKcS2_S2_mm
343 
344  // create the buffer register name string
345  Constant *buf_2 =
346  builder.CreateGlobalStringPtr(inst->getBufferNames()[1]);
347 
348  // set the argument Types for this function call (char *,
349  // char *, size_t)
350  ArrayRef<Type *> arg_types{
351  gate_name->getType(), gate_name->getType(),
352  gate_name->getType(), FunctionType::getInt64Ty(context),
353  FunctionType::getInt64Ty(context)};
354 
355  // void return type, create the FunctionType instance
356  FunctionType *ftype = FunctionType::get(
357  FunctionType::getVoidTy(context), arg_types, false);
358 
359  if (!two_qubit) {
360  two_qubit = Function::Create(
361  ftype, Function::ExternalLinkage,
362  "_ZN4xacc17internal_compiler29simplified_qrt_call_"
363  "two_qbitsEPKcS2_S2_mm",
364  module);
365  }
366  // create actual argument values
367  ArrayRef<Value *> args{
368  gate_name, buf, buf_2,
369  ConstantInt::get(IntegerType::getInt64Ty(context),
370  APInt(64, inst->bits()[0])),
371  ConstantInt::get(IntegerType::getInt64Ty(context),
372  APInt(64, inst->bits()[1]))};
373 
374  // Create the call inst and add it to the Function
375  Instruction *new_inst;
376  if (i == program->nInstructions() - 1) {
377  new_inst =
378  InvokeInst::Create(ftype, two_qubit, normal_next,
379  except_next, args, "", last_node);
380  } else {
381  new_inst =
382  CallInst::Create(ftype, two_qubit, args, "", last_node);
383  }
384 
385  last_node = new_inst;
386 
387  if (i == program->nInstructions() - 1) {
388  // save this basic block...
389  basic_block = invoke.getParent();
390  invoke.eraseFromParent();
391  }
392  }
393  }
394  }
395 
396  has_run_once = true;
397  } else {
398  auto containing_bb = invoke.getParent();
399 
400  int n_bits = xacc::getIRProvider("quantum")->getNRequiredBits(
401  qrt_to_xacc[qrt_name]);
402  if (n_bits == 1) {
403  auto node = invoke.getPrevNode();
404 
405  if (isa<LoadInst>(node)) {
406  // this is only for single parameter gates
407  // this means that we have load on a gate parameter
408  node = node->getPrevNode();
409  }
410  node->eraseFromParent();
411 
412  } else if (n_bits == 2) {
413  // invoke.getPrevNode()->getPrevNode()->eraseFromParent();
414  // invoke.getPrevNode()->eraseFromParent();
415  containing_bb->getPrevNode();
416  }
417  invoke.eraseFromParent();
418 
419  if (!containing_bb->hasNPredecessorsOrMore(1)) {
420  containing_bb->eraseFromParent();
421  }
422  }
423  }
424  }
425  }
426 };
427 
429  : public InstVisitor<SearchForCallsToThisFunction> {
430 
431 protected:
432  Function *function;
433 
434 public:
435  CallInst *found_call = nullptr;
436 
437  SearchForCallsToThisFunction(Function *f) : function(f) {}
438  auto demangle(const char *name) {
439  int status = -1;
440  std::unique_ptr<char, void (*)(void *)> res{
441  abi::__cxa_demangle(name, NULL, NULL, &status), std::free};
442  return (status == 0) ? res.get() : std::string(name);
443  }
444  void visitCallInst(CallInst &call) {
445  if (call.getCalledFunction() != nullptr) {
446  // errs() <<
447  // demangle(call.getCalledFunction()->getName().str().c_str()) <<
448  // "\n";
449  if (call.getCalledFunction() == function) {
450  errs() << "Call Found our function\n";
451  call.dump();
452  call.getParent()->getParent();
453  found_call = &call;
454  }
455  }
456  }
457  void visitInvokeInst(InvokeInst &call) {
458  if (call.getCalledFunction() != nullptr) {
459  // errs() <<
460  // demangle(call.getCalledFunction()->getName().str().c_str()) <<
461  // "\n";
462  if (call.getCalledFunction() == function) {
463  errs() << "Invoke Found our function\n";
464  call.dump();
465  }
466  }
467  }
468 };
469 } // namespace qcor
qcor::SearchForCallsToThisFunction
Definition: llvm_ir_visitors.hpp:428
qcor::FindFunctionVariableStoreInsts
Definition: llvm_ir_visitors.hpp:28
qcor::LLVM_IR_To_XACC
Definition: llvm_ir_visitors.hpp:48
qcor
Definition: qcor_syntax_handler.cpp:15
qcor::XACC_To_LLVM_IR
Definition: llvm_ir_visitors.hpp:148