2 #include "llvm/ADT/StringRef.h"
3 #include "llvm/IR/Constants.h"
4 #include "llvm/IR/DerivedTypes.h"
5 #include "llvm/IR/IRBuilder.h"
6 #include "llvm/IR/InstVisitor.h"
7 #include "llvm/IR/InstrTypes.h"
8 #include "llvm/IR/Instructions.h"
11 #include <Instruction.hpp>
/// Translation table from the lower-case gate names that appear in demangled
/// `quantum::<name>(...)` calls to the instruction names handed to the XACC IR
/// provider (`createInstruction` / `getNRequiredBits`).
///
/// The map is intentionally left non-const: the visitors in this file look
/// entries up with `operator[]` after guarding with `count()`.
std::map<std::string, std::string> qrt_to_xacc{
    // Single-qubit gates.
    {"h", "H"},
    {"rz", "Rz"},
    {"ry", "Ry"},
    {"rx", "Rx"},
    {"x", "X"},
    {"y", "Y"},
    {"z", "Z"},
    {"s", "S"},
    {"t", "T"},
    {"sdg", "Sdg"},
    {"tdg", "Tdg"},
    {"i", "I"},
    {"u", "U"},
    {"u1", "U1"},
    // Two-qubit gates.
    {"cy", "CY"},
    {"cz", "CZ"},
    {"swap", "Swap"},
    {"crz", "CRZ"},
    {"ch", "CH"},
    {"cphase", "CPhase"},
    {"cnot", "CNOT"},
    // Measurement.
    {"mz", "Measure"},
};
29 :
public InstVisitor<FindFunctionVariableStoreInsts> {
31 std::vector<std::string> function_variable_names;
34 std::map<std::string, StoreInst *> stores;
37 : function_variable_names(func_var_names) {}
39 void visitStoreInst(StoreInst &store) {
40 auto name = store.getOperand(1)->getName().str();
41 auto tmp_name = xacc::split(name,
'.')[0];
42 if (xacc::container::contains(function_variable_names, tmp_name)) {
43 stores.insert({name, &store});
// IR provider used to create instructions and the composite below.
50 std::shared_ptr<xacc::IRProvider> provider;
// Constant qubit indices collected from `qreg::operator[]` calls since the
// last instruction was emitted; cleared after each emitted instruction.
51 std::vector<std::size_t> seen_qbit_idxs;
// Buffer names collected alongside the indices above; cleared together.
52 std::vector<std::string> buffer_names;
// Every buffer name ever seen by this visitor (never cleared in the code
// visible here).
53 std::set<std::string> unique_buffer_names;
// Renaming table applied to buffer names before they are recorded: a name
// found as a key is replaced by the mapped value.
54 std::map<std::string, std::string> new_to_old_qreg_names;
// Composite instruction that receives each instruction this visitor builds.
57 std::shared_ptr<xacc::CompositeInstruction> composite;
60 : new_to_old_qreg_names(m) {
61 provider = xacc::getIRProvider(
"quantum");
62 composite = provider->createComposite(
"tmp");
// Demangle an Itanium-ABI symbol name.
// Returns the demangled text when abi::__cxa_demangle reports status 0,
// otherwise a copy of `name` unchanged.
65 auto demangle(
const char *name) {
// The unique_ptr releases the buffer handed back by __cxa_demangle through
// std::free once this function returns.
67 std::unique_ptr<char, void (*)(
void *)> res{
68 abi::__cxa_demangle(name, NULL, NULL, &status), std::free};
// Both arms yield std::string, so the result is copied out before `res`
// frees the buffer.
69 return (status == 0) ? res.get() : std::string(name);
// Handle a call instruction: when the callee's demangled name contains
// "qreg::operator[]", remember which qubit index and which buffer it touches.
72 void visitCallInst(CallInst &call) {
// Null when the call has no statically known callee; the `f &&` guard
// below skips those.
73 auto f = call.getCalledFunction();
74 if (f && demangle(f->getName().str().c_str()).find(
"qreg::operator[]") !=
// Only an index that is a compile-time integer constant (operand 2) is
// recorded.
76 if (
auto *const_int = dyn_cast<ConstantInt>(call.getOperand(2))) {
77 auto bit_idx = const_int->getValue().getLimitedValue();
78 seen_qbit_idxs.push_back(bit_idx);
// The buffer name is the IR name of operand 1, translated through
// new_to_old_qreg_names when that table has an entry for it.
80 auto seen_buf_name = call.getOperand(1)->getName().str();
81 if (new_to_old_qreg_names.count(seen_buf_name)) {
82 seen_buf_name = new_to_old_qreg_names[seen_buf_name];
84 buffer_names.push_back(seen_buf_name);
85 unique_buffer_names.insert(seen_buf_name);
// Handle an invoke instruction. Two kinds of callee are recognised by their
// demangled name:
//   * "qreg::operator[]" - record the qubit index and buffer name, exactly as
//     visitCallInst does for plain calls;
//   * "quantum::..."     - build the matching XACC instruction from the
//     indices/buffers recorded so far and append it to `composite`.
89 void visitInvokeInst(InvokeInst &invoke) {
90 auto f = invoke.getCalledFunction();
91 if (f && demangle(f->getName().str().c_str()).find(
"qreg::operator[]") !=
// Only a constant integer index (operand 2) is recorded.
93 if (
auto *const_int = dyn_cast<ConstantInt>(invoke.getOperand(2))) {
94 auto bit_idx = const_int->getValue().getLimitedValue();
95 seen_qbit_idxs.push_back(bit_idx);
// Buffer name comes from operand 1's IR name, remapped when the renaming
// table knows it.
98 auto seen_buf_name = invoke.getOperand(1)->getName().str();
99 if (new_to_old_qreg_names.count(seen_buf_name)) {
100 seen_buf_name = new_to_old_qreg_names[seen_buf_name];
102 buffer_names.push_back(seen_buf_name);
103 unique_buffer_names.insert(seen_buf_name);
104 }
else if (f && demangle(f->getName().str().c_str()).find(
"quantum::") !=
// Extract the function name: split the demangled text on ':' and keep
// element 2 up to the first '('.
// NOTE(review): assumes the demangled text begins with "quantum::" so that
// element 2 exists - confirm for return-type-prefixed or nested names.
107 auto qrt_call_str = demangle(f->getName().str().c_str());
108 auto split = xacc::split(qrt_call_str,
':');
109 auto qrt_name = split[2].substr(0, split[2].find_first_of(
"("));
// Names without an entry in qrt_to_xacc are ignored.
111 if (qcor::qrt_to_xacc.count(qrt_name)) {
112 auto xacc_name = qcor::qrt_to_xacc[qrt_name];
113 auto inst = provider->createInstruction(xacc_name, seen_qbit_idxs);
114 inst->setBufferNames(buffer_names);
// Parameterised gate: the parameter is read from operand 1 when that is a
// floating-point constant.
116 if (inst->nParameters() > 0) {
117 xacc::InstructionParameter p;
118 if (
auto constant_double =
119 dyn_cast<ConstantFP>(invoke.getOperand(1))) {
120 errs() <<
"Can get the double too "
121 << constant_double->getValueAPF().convertToDouble() <<
"\n";
122 p = constant_double->getValueAPF().convertToDouble();
// Otherwise the instruction just before the invoke is inspected; when it is
// a load, the loaded pointer's name (text before the first '.') is
// registered as a variable on the composite.
// NOTE(review): dyn_cast asserts on a null pointer, so this relies on the
// invoke not being the first instruction of its block - confirm.
124 auto prev_node = invoke.getPrevNode();
125 if (
auto load = dyn_cast<LoadInst>(prev_node)) {
129 auto param_str = load->getOperand(0)->getName().str();
130 param_str = xacc::split(param_str,
'.')[0];
131 errs() <<
"HELLO WORLD: " << param_str <<
"\n";
133 composite->addVariable(param_str);
137 inst->setParameter(0, p);
// Reset the per-instruction scratch state and keep the new instruction.
140 seen_qbit_idxs.clear();
141 buffer_names.clear();
142 composite->addInstruction(inst);
150 bool has_run_once =
false;
151 xacc::CompositeInstruction *program;
153 Function *simple_one_qbit;
154 LLVMContext &context;
156 std::map<std::string, StoreInst *> &variable_store_insts;
159 BasicBlock *basic_block;
160 BasicBlock *execution_block;
163 xacc::CompositeInstruction *c)
164 : program(c), module(mod), context(mod->getContext()),
165 variable_store_insts(vsi) {}
// Demangle an Itanium-ABI symbol name; falls back to a copy of `name` when
// abi::__cxa_demangle does not report success (status != 0).
166 auto demangle(
const char *name) {
// std::free is the deleter for the buffer __cxa_demangle hands back.
168 std::unique_ptr<char, void (*)(
void *)> res{
169 abi::__cxa_demangle(name, NULL, NULL, &status), std::free};
// The result is returned as std::string, i.e. copied before `res` is
// destroyed.
170 return (status == 0) ? res.get() : std::string(name);
// Inspect a basic block with more than one instruction:
//   * remember it in `execution_block` when the inspected load's name
//     demangles to "xacc::internal_compiler::__execute";
//   * erase a call to the destructor of
//     std::pair<std::string, unsigned long> when it is the block's first
//     instruction.
172 void visitBasicBlock(BasicBlock &bb) {
175 if (bb.getInstList().size() > 1) {
177 auto inst_iter = bb.getInstList().begin();
// NOTE(review): which operand's name is demangled is not visible at this
// point - confirm against the full expression.
179 if (isa<LoadInst>(*inst_iter) &&
180 demangle(dyn_cast<LoadInst>(&*inst_iter)
184 .c_str()) ==
"xacc::internal_compiler::__execute") {
185 execution_block = &bb;
186 Instruction *first_inst = &*bb.getInstList().begin();
187 if (
auto call = dyn_cast<CallInst>(first_inst)) {
188 auto f = call->getCalledFunction();
// The callee is matched by substring on its demangled name, spelled with the
// libstdc++ __cxx11 string ABI.
190 demangle(f->getName().str().c_str())
191 .find(
"std::pair<std::__cxx11::basic_string<char, "
192 "std::char_traits<char>, std::allocator<char> >, "
193 "unsigned long>::~pair()") != std::string::npos) {
194 call->eraseFromParent();
// Rewrite an invoke of a `quantum::<gate>` function.
//
// When the gate name (parsed from the demangled callee) is known to
// qrt_to_xacc, the instructions of `program` are walked from last to first
// and a call to one of the `xacc::internal_compiler::simplified_qrt_call_*`
// functions (declared here by mangled name) is inserted for each. Every new
// instruction is inserted before `last_node`, which then becomes that new
// instruction, so the emitted calls end up in program order. The instruction
// emitted for the final program entry is an InvokeInst that reuses the
// original normal/unwind destinations; the others are plain CallInsts.
// The original invoke is erased as part of the rewrite.
202 void visitInvokeInst(InvokeInst &invoke) {
203 auto f = invoke.getCalledFunction();
204 if (f && demangle(f->getName().str().c_str()).find(
"quantum::") !=
// Gate name = element 2 of the ':'-split demangled name, cut at the first '('.
206 auto qrt_call_str = demangle(f->getName().str().c_str());
207 auto split = xacc::split(qrt_call_str,
':');
208 auto qrt_name = split[2].substr(0, split[2].find_first_of(
"("));
209 if (qrt_to_xacc.count(qrt_name)) {
// Successor blocks of the original invoke, reused by the replacement invoke.
211 auto normal_next = invoke.getNormalDest();
212 auto except_next = invoke.getUnwindDest();
// Insertion point: new instructions go immediately before this node.
213 Instruction *last_node = &invoke;
// Lazily created declarations of the runtime entry points.
215 Function *one_qubit =
nullptr;
216 Function *one_qubit_param =
nullptr;
217 Function *two_qubit =
nullptr;
// Walk the program backwards; i == nInstructions() - 1 marks the last
// program instruction, which is the first one processed.
219 for (
int i = program->nInstructions() - 1; i >= 0; i--) {
220 auto inst = program->getInstruction(i);
222 int n_bits = inst->nRequiredBits(), n_params = inst->nParameters();
224 if (inst->name() ==
"Measure") {
// Builder used only to materialise the gate and buffer names as global
// string constants.
228 IRBuilder<> builder(last_node->getParent());
231 Constant *gate_name = builder.CreateGlobalStringPtr(inst->name());
234 builder.CreateGlobalStringPtr(inst->getBufferNames()[0]);
240 if (i == program->nInstructions() - 1) {
241 last_node->getPrevNode()->eraseFromParent();
// Base signature: (gate name, buffer name, qubit index); one double is
// appended per instruction parameter below.
244 std::vector<Type *> arg_types_vec{
245 gate_name->getType(), gate_name->getType(),
246 FunctionType::getInt64Ty(context)};
247 std::vector<Value *> args_vec{
249 ConstantInt::get(IntegerType::getInt64Ty(context),
250 APInt(64, inst->bits()[0]))};
253 for (
auto &p : inst->getParameters()) {
256 arg_types_vec.push_back(FunctionType::getDoubleTy(context));
// Variable parameter: load its current value from the pointer operand of the
// store recorded under "<name>.addr"; the load is inserted before last_node.
// NOTE(review): std::map::operator[] inserts a null entry for an unknown
// key, which would then be dereferenced - confirm the key is always present.
258 if (p.isVariable()) {
268 auto store_inst_key = p.toString() +
".addr";
269 auto load =
new LoadInst(
270 FunctionType::getDoubleTy(context),
271 variable_store_insts[store_inst_key]->getOperand(1),
272 "tmp_" + p.toString(),
false, last_node);
// NOTE(review): load_value is not used in the code visible here.
273 auto load_value = load->getOperand(0);
274 args_vec.push_back(load);
// The other visible push converts the parameter to a double and passes it as
// a floating-point constant.
276 args_vec.push_back(ConstantFP::get(
277 FunctionType::getDoubleTy(context),
278 APFloat(xacc::InstructionParameterToDouble(p))));
283 ArrayRef<Type *> arg_types(arg_types_vec);
284 ArrayRef<Value *> args(args_vec);
// void(<arg_types>), not variadic.
286 FunctionType *ftype = FunctionType::get(
287 FunctionType::getVoidTy(context), arg_types,
false);
289 Instruction *new_inst =
nullptr;
// Declare the one-qubit/one-parameter runtime function on first use.
292 if (!one_qubit_param) {
293 one_qubit_param = Function::Create(
294 ftype, Function::ExternalLinkage,
295 "_ZN4xacc17internal_compiler38simplified_qrt_call_one_"
296 "qbit_one_paramEPKcS2_md",
301 if (i == program->nInstructions() - 1) {
303 InvokeInst::Create(ftype, one_qubit_param, normal_next,
304 except_next, args,
"", last_node);
306 new_inst = CallInst::Create(ftype, one_qubit_param, args,
"",
311 one_qubit = Function::Create(
312 ftype, Function::ExternalLinkage,
313 "_ZN4xacc17internal_compiler28simplified_qrt_call_"
319 if (i == program->nInstructions() - 1) {
321 InvokeInst::Create(ftype, one_qubit, normal_next,
322 except_next, args,
"", last_node);
325 CallInst::Create(ftype, one_qubit, args,
"", last_node);
// The next (earlier) program instruction is inserted before this one.
329 last_node = new_inst;
// While handling the last program instruction the original invoke is
// erased; its block is remembered in basic_block first.
331 if (i == program->nInstructions() - 1) {
333 basic_block = invoke.getParent();
334 invoke.eraseFromParent();
346 builder.CreateGlobalStringPtr(inst->getBufferNames()[1]);
350 ArrayRef<Type *> arg_types{
351 gate_name->getType(), gate_name->getType(),
352 gate_name->getType(), FunctionType::getInt64Ty(context),
353 FunctionType::getInt64Ty(context)};
356 FunctionType *ftype = FunctionType::get(
357 FunctionType::getVoidTy(context), arg_types,
false);
360 two_qubit = Function::Create(
361 ftype, Function::ExternalLinkage,
362 "_ZN4xacc17internal_compiler29simplified_qrt_call_"
363 "two_qbitsEPKcS2_S2_mm",
367 ArrayRef<Value *> args{
368 gate_name, buf, buf_2,
369 ConstantInt::get(IntegerType::getInt64Ty(context),
370 APInt(64, inst->bits()[0])),
371 ConstantInt::get(IntegerType::getInt64Ty(context),
372 APInt(64, inst->bits()[1]))};
375 Instruction *new_inst;
376 if (i == program->nInstructions() - 1) {
378 InvokeInst::Create(ftype, two_qubit, normal_next,
379 except_next, args,
"", last_node);
382 CallInst::Create(ftype, two_qubit, args,
"", last_node);
385 last_node = new_inst;
387 if (i == program->nInstructions() - 1) {
389 basic_block = invoke.getParent();
390 invoke.eraseFromParent();
// Removal path: the invoke, instructions preceding it and possibly its whole
// block are erased.
// NOTE(review): the condition that selects this path is not visible here.
398 auto containing_bb = invoke.getParent();
400 int n_bits = xacc::getIRProvider(
"quantum")->getNRequiredBits(
401 qrt_to_xacc[qrt_name]);
403 auto node = invoke.getPrevNode();
405 if (isa<LoadInst>(node)) {
408 node = node->getPrevNode();
410 node->eraseFromParent();
412 }
else if (n_bits == 2) {
415 containing_bb->getPrevNode();
417 invoke.eraseFromParent();
// A block left without any predecessor is removed.
419 if (!containing_bb->hasNPredecessorsOrMore(1)) {
420 containing_bb->eraseFromParent();
429 :
public InstVisitor<SearchForCallsToThisFunction> {
435 CallInst *found_call =
nullptr;
// Demangle an Itanium-ABI symbol name. Yields the demangled text when
// abi::__cxa_demangle reports status 0 and a copy of `name` otherwise.
438 auto demangle(
const char *name) {
// The buffer from __cxa_demangle is owned by `res` and released with
// std::free.
440 std::unique_ptr<char, void (*)(
void *)> res{
441 abi::__cxa_demangle(name, NULL, NULL, &status), std::free};
442 return (status == 0) ? res.get() : std::string(name);
// Check whether a call instruction targets `function` and log a message to
// errs() when it does.
444 void visitCallInst(CallInst &call) {
// Calls without a statically known callee (getCalledFunction() == nullptr)
// are skipped.
445 if (call.getCalledFunction() !=
nullptr) {
// The match is by pointer identity with `function`.
449 if (call.getCalledFunction() ==
function) {
450 errs() <<
"Call Found our function\n";
// Parent of the call's basic block, i.e. the function containing the call.
// NOTE(review): what is done with this value is not visible here.
452 call.getParent()->getParent();
457 void visitInvokeInst(InvokeInst &call) {
458 if (call.getCalledFunction() !=
nullptr) {
462 if (call.getCalledFunction() ==
function) {
463 errs() <<
"Invoke Found our function\n";