XACC
rbm_classification.hpp
1 /*******************************************************************************
2  * Copyright (c) 2019 UT-Battelle, LLC.
3  * All rights reserved. This program and the accompanying materials
4  * are made available under the terms of the Eclipse Public License v1.0
5  * and Eclipse Distribution License v1.0 which accompanies this
6  * distribution. The Eclipse Public License is available at
7  * http://www.eclipse.org/legal/epl-v10.html and the Eclipse Distribution
8  *License is available at https://eclipse.org/org/documents/edl-v10.php
9  *
10  * Contributors:
11  * Alexander J. McCaskey - initial API and implementation
12  *******************************************************************************/
13 #ifndef XACC_ALGORITHM_RBM_CLASSIFICATION_HPP_
14 #define XACC_ALGORITHM_RBM_CLASSIFICATION_HPP_
15 
#include "Algorithm.hpp"

#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include <Eigen/Dense>
20 
21 namespace xacc {
22 namespace algorithm {
23 
25 public:
26  virtual std::tuple<Eigen::MatrixXd, Eigen::VectorXd, Eigen::VectorXd>
27  compute(Eigen::MatrixXd &features, Eigen::MatrixXd &w, Eigen::VectorXd &v,
28  Eigen::VectorXd &h, HeterogeneousMap options = {}) = 0;
29 };
30 
31 class RBMClassification : public Algorithm {
32 protected:
34  Eigen::MatrixXd training_data;
35  std::string modelExp;
36 
37  HeterogeneousMap _parameters;
38 
39  template <typename M> M load_csv(const std::string &path) {
40  std::ifstream indata;
41  indata.open(path);
42  std::string line;
43  std::vector<double> values;
44  int rows = 0;
45  while (std::getline(indata, line)) {
46  if (line[0] != '#') {
47  std::stringstream lineStream(line);
48  std::string cell;
49  while (std::getline(lineStream, cell, ',')) {
50  values.push_back(std::stod(cell));
51  }
52  ++rows;
53  }
54  }
55  return Eigen::Map<
56  const Eigen::Matrix<typename M::Scalar, M::RowsAtCompileTime,
57  M::ColsAtCompileTime, Eigen::RowMajor>>(
58  values.data(), rows, values.size() / rows);
59  }
60 
61 public:
62  bool initialize(const HeterogeneousMap &parameters) override;
63  const std::vector<std::string> requiredParameters() const override;
64 
65  void execute(const std::shared_ptr<AcceleratorBuffer> buffer) const override;
66  const std::string name() const override { return "rbm-classification"; }
67  const std::string description() const override { return ""; }
68  DEFINE_ALGORITHM_CLONE(RBMClassification)
69 };
70 } // namespace algorithm
71 } // namespace xacc
72 #endif
Definition: Algorithm.hpp:34
const std::string name() const override
Definition: rbm_classification.hpp:66
Definition: Accelerator.hpp:25
Definition: rbm_classification.hpp:24
const std::string description() const override
Definition: rbm_classification.hpp:67
Definition: heterogeneous.hpp:45
Definition: Identifiable.hpp:25
Definition: CompositeInstruction.hpp:72
Definition: rbm_classification.hpp:31