Commit 05f6f87b authored by Doak, Peter W.'s avatar Doak, Peter W.
Browse files

improved debug support for matrix

parent 98f9c286
Loading
Loading
Loading
Loading
+20 −0
Original line number Diff line number Diff line
@@ -238,6 +238,7 @@ public:
  // Returns the allocated device memory in bytes.
  std::size_t deviceFingerprint() const;

  std::string toStr() const;
private:
  static std::pair<int, int> capacityMultipleOfBlockSize(std::pair<int, int> size);
  inline static size_t nrElements(std::pair<int, int> size) {
@@ -527,6 +528,25 @@ void Matrix<ScalarType, device_name, ALLOC>::print() const {
  std::cout << ss.str() << std::endl;
}

template <typename ScalarType, DeviceType device_name, class ALLOC>
std::string Matrix<ScalarType, device_name,  ALLOC>::toStr() const {
  if (device_name == GPU)
    return Matrix<ScalarType, CPU>(*this).toStr();

  std::stringstream ss;
  ss.precision(16);
  ss << std::scientific;

  ss << "\n";
  for (int i = 0; i < nrRows(); ++i) {
    for (int j = 0; j < nrCols(); ++j)
      ss << "\t" << operator()(i, j);
    ss << "\n";
  }

  return ss.str();
}
  
template <typename ScalarType, DeviceType device_name, class ALLOC>
void Matrix<ScalarType, device_name,  ALLOC>::printFingerprint() const {
  std::stringstream ss;
+2 −2
Original line number Diff line number Diff line
@@ -197,9 +197,9 @@ inline void copyRows(const Matrix<Scalar, GPU>& mat_x, const Vector<int, GPU>& i
template <typename Scalar, class ALLOC>
auto difference(const Matrix<Scalar, CPU, ALLOC>& a, const Matrix<Scalar, CPU, ALLOC>& b,
                double diff_threshold = 1e-3) {
  auto max_diff = std::abs(Scalar(0));
  assert(a.size() == b.size());

  auto max_diff = std::abs(Scalar(0));

  for (int j = 0; j < a.nrCols(); ++j) {
    for (int i = 0; i < a.nrRows(); ++i) {
@@ -222,7 +222,7 @@ auto difference(const Matrix<Scalar, CPU, ALLOC>& a, const Matrix<Scalar, CPU, A
    s << std::endl;
    std::cout << s.str();
#endif  // NDEBUG

    std::cerr << "matrix difference in excess of threshold!\n";
    throw std::logic_error(__FUNCTION__);
  }