/*******************************************************************************
 * Copyright (c) 2019 UT-Battelle, LLC.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * and Eclipse Distribution License v1.0 which accompanies this
 * distribution. The Eclipse Public License is available at
 * http://www.eclipse.org/legal/epl-v10.html and the Eclipse Distribution
 * License is available at https://eclipse.org/org/documents/edl-v10.php
 *
 * Contributors:
 *   Alexander J. McCaskey - initial API and implementation
 *******************************************************************************/
#ifndef XACC_HETEROGENEOUS_HPP_
#define XACC_HETEROGENEOUS_HPP_
#include <cstdlib>
#include <functional>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

#include <experimental/type_traits>

#include "variant.hpp"

#include "Utils.hpp"
#include <Utils.hpp>

namespace xacc {

class HeterogeneousMap;

/// Type trait that is true only when its single argument is exactly
/// HeterogeneousMap. Any other type, or any other number of template
/// arguments (including zero), yields false.
template <class... Ts>
struct is_heterogeneous_map : std::integral_constant<bool, false> {};

/// The one case that answers "yes": HeterogeneousMap itself.
template <>
struct is_heterogeneous_map<HeterogeneousMap>
    : std::integral_constant<bool, true> {};

/// Empty compile-time list of types; only its template arguments matter.
template <class... Ts> struct type_list {};

/// Base class for visitors passed to HeterogeneousMap::visit().
/// Deriving from visitor_base<A, B, ...> publishes the handled value types
/// through the nested `types` alias, which visit() reads to decide which
/// per-type tables to walk.
template <class... Handled> struct visitor_base {
  using types = xacc::type_list<Handled...>;
};

class HeterogeneousMap {
public:
  HeterogeneousMap() = default;
  HeterogeneousMap(const HeterogeneousMap &_other) { *this = _other; }
46
  HeterogeneousMap(HeterogeneousMap &_other) { *this = _other; }
47

48
49
50
51
52
53
54
55
56
57
58
  HeterogeneousMap &operator=(const HeterogeneousMap &_other) {
    clear();
    clear_functions = _other.clear_functions;
    copy_functions = _other.copy_functions;
    size_functions = _other.size_functions;
    for (auto &&copy_function : copy_functions) {
      copy_function(_other, *this);
    }
    return *this;
  }

59
60
61
62
63
64
65
66
67
68
  template <typename T> void loop_pairs(T value) {
    insert(value.first, value.second);
  }

  template <typename First, typename... Rest>
  void loop_pairs(First firstValue, Rest... rest) {
    loop_pairs(firstValue);
    loop_pairs(rest...);
  }

69
70
71
72
  template <typename... TYPES,
            typename = std::enable_if_t<!is_heterogeneous_map<
                std::remove_cv_t<std::remove_reference_t<TYPES>>...>::value>>
  HeterogeneousMap(TYPES &&... list) {
73
74
75
    loop_pairs(list...);
  }

76
  template <typename... Ts> void print(std::ostream &os) const {
77
78
79
80
    _internal_print_visitor<Ts...> v(os);
    visit(v);
  }

81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
  template <class T> void insert(const std::string key, const T &_t) {
    // don't have it yet, so create functions for printing, copying, moving, and
    // destroying
    if (items<T>.find(this) == std::end(items<T>)) {
      clear_functions.emplace_back(
          [](HeterogeneousMap &_c) { items<T>.erase(&_c); });

      // if someone copies me, they need to call each copy_function and pass
      // themself
      copy_functions.emplace_back(
          [](const HeterogeneousMap &_from, HeterogeneousMap &_to) {
            items<T>[&_to] = items<T>[&_from];
          });
      size_functions.emplace_back(
          [](const HeterogeneousMap &_c) { return items<T>[&_c].size(); });
    }
    items<T>[this].insert({key, _t});
  }
99
100
101
102
103
104
105
106
107
108
109

  template <class T> T &get_mutable(const std::string key) const {
    if (!items<T>.count(this) && !items<T>[this].count(key)) {
      XACCLogger::instance()->error("Invalid type (" +
                                    std::string(typeid(T).name()) +
                                    ") or key (" + key + ").");
      print_backtrace();
    }
    return items<T>[this][key];
  }

110
  template <class T> const T &get(const std::string key) const {
Mccaskey, Alex's avatar
Mccaskey, Alex committed
111
    if (!keyExists<T>(key)) {
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
      XACCLogger::instance()->error("Invalid type (" +
                                    std::string(typeid(T).name()) +
                                    ") or key (" + key + ").");
      print_backtrace();
    }
    return items<T>[this][key];
  }

  template <class T> const T &get_with_throw(const std::string key) const {
    if (!items<T>.count(this) && !items<T>[this].count(key)) {
      throw new std::runtime_error("Invalid type (" +
                                   std::string(typeid(T).name()) + ").");
    }
    return items<T>[this][key];
  }

128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
  bool stringExists(const std::string key) const {
    if (keyExists<const char *>(key)) {
      return true;
    }
    if (keyExists<std::string>(key)) {
      return true;
    }
    return false;
  }

  const std::string getString(const std::string key) const {
    if (keyExists<const char *>(key)) {
      return get<const char *>(key);
    } else if (keyExists<std::string>(key)) {
      return get<std::string>(key);
    } else {
      XACCLogger::instance()->error("No string-like value at provided key (" +
                                    key + ").");
      print_backtrace();
    }
    return "";
  }

151
152
153
154
155
156
157
158
159
160
161
162
163
  void clear() {
    for (auto &&clear_func : clear_functions) {
      clear_func(*this);
    }
  }

  template <class T> size_t number_of() const {
    auto iter = items<T>.find(this);
    if (iter != items<T>.cend())
      return items<T>[this].size();
    return 0;
  }

164
165
166
167
168
  template <typename T> bool keyExists(const std::string key) const {
    if (items<T>.count(this) && items<T>[this].count(key)) {
      return true;
    }
    return false;
169
170
171
172
173
174
175
176
177
178
179
180
181
  }

  size_t size() const {
    size_t sum = 0;
    for (auto &&size_func : size_functions) {
      sum += size_func(*this);
    }
    // gotta be careful about this overflowing
    return sum;
  }

  ~HeterogeneousMap() { clear(); }

182
  template <class T> void visit(T &&visitor) const {
183
184
185
186
    visit_impl(visitor, typename std::decay_t<T>::types{});
  }

private:
187
  template <typename... Ts>
188
  class _internal_print_visitor : public visitor_base<Ts...> {
189
190
191
192
193
194
195
196
197
  private:
    std::ostream &ss;

  public:
    _internal_print_visitor(std::ostream &s) : ss(s) {}

    template <typename T> void operator()(const std::string &key, const T &t) {
      ss << key << ": " << t << "\n";
    }
198
199
  };

200
201
202
203
204
  template <class T>
  static std::unordered_map<const HeterogeneousMap *, std::map<std::string, T>>
      items;

  template <class T, class U>
205
206
  using visit_function = decltype(std::declval<T>().operator()(
      std::declval<const std::string &>(), std::declval<U &>()));
207
208
209
210
211
  template <class T, class U>
  static constexpr bool has_visit_v =
      std::experimental::is_detected<visit_function, T, U>::value;

  template <class T, template <class...> class TLIST, class... TYPES>
212
  void visit_impl(T &&visitor, TLIST<TYPES...>) const {
213
214
215
216
    using expander = int[];
    (void)expander{0, (void(visit_impl_help<T, TYPES>(visitor)), 0)...};
  }

217
  template <class T, class U> void visit_impl_help(T &visitor) const {
218
219
220
    static_assert(has_visit_v<T, U>, "Visitors must provide a visit function "
                                     "accepting a reference for each type");
    for (auto &&element : items<U>[this]) {
221
      visitor(element.first, element.second);
222
223
224
225
226
227
228
229
230
231
232
233
234
    }
  }

  std::vector<std::function<void(HeterogeneousMap &)>> clear_functions;
  std::vector<std::function<void(const HeterogeneousMap &, HeterogeneousMap &)>>
      copy_functions;
  std::vector<std::function<size_t(const HeterogeneousMap &)>> size_functions;
};

// Out-of-class definition of HeterogeneousMap's per-type storage: one
// table per value type T, keyed by the address of the owning map.
template <typename T>
std::unordered_map<const HeterogeneousMap *, std::map<std::string, T>>
    HeterogeneousMap::items;

/// Variant built on mpark::variant. Adds string conversion, checked
/// extraction (as<T>) and simple type queries.
template <typename... Types> class Variant : public mpark::variant<Types...> {

private:
  /// Renders the active alternative into a string via operator<<.
  class ToStringVisitor {
  public:
    template <typename T> std::string operator()(const T &t) const {
      std::stringstream ss;
      ss << t;
      return ss.str();
    }
  };

  /// Reports whether the active alternative's type is arithmetic.
  class IsArithmeticVisitor {
  public:
    template <typename T> bool operator()(const T &) const {
      return std::is_arithmetic<T>::value;
    }
  };

  /// Cast visitor for a fixed source type (not used within this class).
  template <typename To, typename From> class CastVisitor {
  public:
    To operator()(const From &t) const { return (To)t; }
  };

public:
  Variant() : mpark::variant<Types...>() {}

  /// Construct from an lvalue accepted by one of the alternatives.
  template <typename T>
  Variant(T &element) : mpark::variant<Types...>(element) {}

  /// Construct from an rvalue; forwarded so the value is moved, not copied.
  template <typename T>
  Variant(T &&element) : mpark::variant<Types...>(std::forward<T>(element)) {}

  Variant(const Variant &element) : mpark::variant<Types...>(element) {}

  /// Return the held value as T.
  /// If T is not the active alternative, logs the failure with a backtrace
  /// and terminates the process with EXIT_FAILURE.
  template <typename T> T as() const {
    try {
      // First off just try to get it
      return mpark::get<T>(*this);
    } catch (const std::exception &) {
      std::stringstream s;
      s << "InstructionParameter::this->toString() = " << toString() << "\n";
      s << "This InstructionParameter type id is " << this->which() << "\n";
      XACCLogger::instance()->error("Cannot cast Variant to (" +
                                    std::string(typeid(T).name()) + "):\n" +
                                    s.str());
      print_backtrace();
      // A failed conversion is an error; exit status 0 would report success.
      std::exit(EXIT_FAILURE);
    }
    return T();
  }

  /// Return the held value as T; mpark::get's exception propagates on a
  /// type mismatch.
  template <typename T> T as_no_error() const {
    return mpark::get<T>(*this);
  }

  /// Zero-based index of the active alternative.
  int which() const { return this->index(); }

  /// True if the active alternative has an arithmetic type.
  bool isNumeric() const {
    IsArithmeticVisitor v;
    return mpark::visit(v, *this);
  }

  /// True if the active alternative is a std::string.
  bool isVariable() const {
    return mpark::holds_alternative<std::string>(*this);
  }

  /// String form of the active alternative (via operator<<).
  const std::string toString() const {
    ToStringVisitor vis;
    return mpark::visit(vis, *this);
  }

  /// Equality by string form: values whose toString() results match compare
  /// equal even if their alternatives differ (e.g. int 1 vs string "1").
  bool operator==(const Variant<Types...> &v) const {
    return v.toString() == toString();
  }

  bool operator!=(const Variant<Types...> &v) const { return !operator==(v); }
};
} // namespace xacc
#endif