// QCOR
// qcor_jit.hpp
#pragma once
#include <cxxabi.h>

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
8 
// Forward declarations of external types that this header only names through
// pointers, references, or smart pointers, so their full definitions (and
// heavyweight headers) are not needed here.
namespace xacc {
class HeterogeneousMap;
}  // namespace xacc

namespace qcor {
class CompositeInstruction;
}  // namespace qcor

namespace llvm {
class Module;
}  // namespace llvm
18 
19 namespace qcor {
20 class LLVMJIT;
21 
22 class QJIT {
23  template <typename... Args>
24  using kernel_functor_t = void (*)(Args...);
25 
26  private:
27  std::map<std::size_t, std::string> cached_kernel_codes;
28  std::string demangle(const char *name) {
29  int status = -1;
30  std::unique_ptr<char, void (*)(void *)> res{
31  abi::__cxa_demangle(name, NULL, NULL, &status), std::free};
32  return (status == 0) ? res.get() : std::string(name);
33  };
34 
35  std::string qjit_cache_path = "";
36 
37  protected:
38  std::map<std::string, std::uint64_t> kernel_name_to_f_ptr;
39  std::map<std::string, std::uint64_t> kernel_name_to_f_ptr_with_parent;
40  std::map<std::string, std::uint64_t> kernel_name_to_f_ptr_hetmap;
41  std::map<std::string, std::uint64_t> kernel_name_to_f_ptr_parent_hetmap;
42 
43  std::unique_ptr<LLVMJIT> jit;
44  std::unique_ptr<llvm::Module> module;
45 
46  public:
47  QJIT();
48  ~QJIT();
49  const std::pair<std::string, std::string> run_syntax_handler(
50  const std::string &quantum_kernel_src,
51  const bool add_het_map_kernel_ctor = false);
52  void jit_compile(const std::string &quantum_kernel_src,
53  const bool add_het_map_kernel_ctor = false,
54  const std::vector<std::string> &kernel_dependency = {},
55  const std::string &extra_functions_src = "",
56  std::vector<std::string> extra_headers = {});
57 
58  void jit_compile(std::unique_ptr<llvm::Module> m,
59  std::vector<std::string> extra_shared_lib_paths = {});
60 
61  void write_cache();
62 
63  template <typename... Args>
64  void invoke(const std::string &kernel_name, Args... args) {
65  // Debug: print the Args... type
66  // std::cout << "QJIT Invoke: " << __PRETTY_FUNCTION__ << "\n";
67  auto f_ptr = kernel_name_to_f_ptr[kernel_name];
68  void (*kernel_functor)(Args...) = (void (*)(Args...))f_ptr;
69  kernel_functor(std::forward<Args>(args)...);
70  }
71 
72  template <typename... Args>
73  void invoke_with_parent(const std::string &kernel_name,
74  std::shared_ptr<qcor::CompositeInstruction> parent,
75  Args... args) {
76  // Debug: print the Args... type
77  // std::cout << "QJIT Invoke with Parent: " << __PRETTY_FUNCTION__ << "\n";
78  auto f_ptr = kernel_name_to_f_ptr_with_parent[kernel_name];
79  void (*kernel_functor)(std::shared_ptr<qcor::CompositeInstruction>,
80  Args...) =
81  (void (*)(std::shared_ptr<qcor::CompositeInstruction>, Args...))f_ptr;
82  kernel_functor(parent, std::forward<Args>(args)...);
83  }
84 
85  // Invoke with type forwarding: Args &&
86  template <typename... Args>
87  void invoke_forwarding(const std::string &kernel_name, Args &&... args) {
88  // std::cout << "QJIT Invoke: " << __PRETTY_FUNCTION__ << "\n";
89  auto f_ptr = kernel_name_to_f_ptr[kernel_name];
90  void (*kernel_functor)(Args...) = (void (*)(Args...))f_ptr;
91  kernel_functor(std::forward<Args>(args)...);
92  }
93 
94  // Invoke with type forwarding: Args &&
95  template <typename... Args>
96  void invoke_with_parent_forwarding(
97  const std::string &kernel_name,
98  std::shared_ptr<qcor::CompositeInstruction> parent, Args &&... args) {
99  // std::cout << "QJIT Invoke with Parent: " << __PRETTY_FUNCTION__ << "\n";
100  auto f_ptr = kernel_name_to_f_ptr_with_parent[kernel_name];
101  void (*kernel_functor)(std::shared_ptr<qcor::CompositeInstruction>,
102  Args...) =
103  (void (*)(std::shared_ptr<qcor::CompositeInstruction>, Args...))f_ptr;
104  kernel_functor(parent, std::forward<Args>(args)...);
105  }
106 
107  int invoke_main(int argc, char **argv) {
108  auto f_ptr = kernel_name_to_f_ptr["main"];
109  int (*kernel_functor)(int, char **) = (int (*)(int, char **))f_ptr;
110  return kernel_functor(argc, argv);
111  }
112 
113  void invoke_with_hetmap(const std::string &kernel_name,
114  xacc::HeterogeneousMap &args);
115  std::shared_ptr<qcor::CompositeInstruction> extract_composite_with_hetmap(
116  const std::string name, xacc::HeterogeneousMap &m);
117 
118  template <typename... Args>
119  kernel_functor_t<Args...> get_kernel(const std::string &kernel_name) {
120  auto f_ptr = kernel_name_to_f_ptr[kernel_name];
121  void (*kernel_functor)(Args...) = (void (*)(Args...))f_ptr;
122  return kernel_functor;
123  }
124 
125  // The type of kernel functions:
126  enum class KernelType { Regular, HetMapArg, HetMapArgWithParent };
127  // Return kernel function pointer (as an integer)
128  // Returns 0 if the kernel doesn't exist.
129  std::uint64_t get_kernel_function_ptr(
130  const std::string &kernelName,
131  KernelType subType = KernelType::Regular) const;
132 };
133 
134 } // namespace qcor
// Cross-reference: qcor::QJIT is defined at qcor_jit.hpp:22.
// Cross-reference: namespace qcor is defined at qcor_syntax_handler.cpp:15.