Commit 445667fc by Lefebvre, Jordan P

### Working on N nearest neighbors.

parent a4a86410
Pipeline #84947 failed with stages
in 13 minutes and 30 seconds
 #ifndef RADIX_RADIXALGORITHM_KDTREE_HH #define RADIX_RADIXALGORITHM_KDTREE_HH #include "radixalgorithm/ordering.hh" #include "radixbug/bug.hh" #include #include #include ... ... @@ -11,60 +14,15 @@ namespace radix { /** * Class for representing a Point. coordinate_type must be a numeric type. */ template class Point { public: Point(const std::array& c) : mCoords(c) { } Point(std::initializer_list list) { size_t n = std::min(dimensions, list.size()); std::copy_n(list.begin(), n, mCoords.begin()); } /** * Returns the coordinate in the given dimension. * * @param index dimension index (zero based) * @return coordinate in the given dimension */ coordinate_type get(size_t index) const { return mCoords[index]; } /** * Returns the distance squared from this point to another * point. * * @param pt another point * @return distance squared from this point to the other point */ double distance(const Point& pt) const { double dist = 0; for (size_t i = 0; i < dimensions; ++i) { double d = get(i) - pt.get(i); dist += d * d; } return dist; } private: std::array mCoords; }; template std::ostream& operator<<(std::ostream& out, const Point& pt) const std::array& pt) { out << '('; for (size_t i = 0; i < dimensions; ++i) { if (i > 0) out << ", "; out << pt.get(i); out << pt[i]; } out << ')'; return out; ... ... @@ -77,9 +35,20 @@ template class KDTree { public: typedef Point point_type; typedef std::array point_type; private: static double distance(const point_type& a, const point_type& b) { double dist = 0; size_t len = std::min(a.size(), b.size()); for (size_t i = 0; i < len; ++i) { double d = a[i] - b[i]; dist += d * d; } return dist; } struct Node { Node(const point_type& pt) ... ... @@ -88,17 +57,25 @@ class KDTree , mRight(nullptr) { } coordinate_type get(size_t index) const { return mPoint.get(index); } double distance(const point_type& pt) const { return mPoint.distance(pt); } double distance(const point_type& pt) const { return KDTree::distance(mPoint, pt); } point_type mPoint; Node* mLeft; Node* mRight; }; }; // struct Node Node* mRoot; Node* mBest; double mBestDist; size_t mVisited; std::vector mNodes; std::vector mNearest; struct AscendingCompare { bool operator()(double a, double b) { return a < b; } }; struct NodeCompare { ... ... @@ -108,7 +85,7 @@ class KDTree } bool operator()(const Node& n1, const Node& n2) const { return n1.mPoint.get(mIndex) < n2.mPoint.get(mIndex); return n1.mPoint[mIndex] < n2.mPoint[mIndex]; } size_t mIndex; }; ... ... @@ -135,13 +112,50 @@ class KDTree mBestDist = d; mBest = root; } if (mBestDist == 0) return; double dx = root->get(index) - point.get(index); if (std::fabs(mBestDist) <= 1e-12) return; double dx = root->mPoint[index] - point[index]; index = (index + 1) % dimensions; nearest(dx > 0 ? root->mLeft : root->mRight, point, index); if (dx * dx >= mBestDist) return; nearest(dx > 0 ? root->mRight : root->mLeft, point, index); } void nearest(Node* root, const point_type& point, size_t index, size_t N, std::vector& neighbors, std::vector& distances) { if (root == nullptr) return; mVisited++; radix_tagged_line("nearest(point='" << point << "',index=" << index); double root_distance = root->distance(point); radix_tagged_line("distance to root=" << root_distance); AscendingCompare compare; if (neighbors.size() == 0) { neighbors.push_back(root->mPoint); distances.push_back(root_distance); } else if (root_distance <= distances.back()) { neighbors.push_back(root->mPoint); distances.push_back(root_distance); std::vector permutation = sort_permutation(distances, compare); apply_permutation(neighbors, permutation); apply_permutation(distances, permutation); if (distances.size() > N) { neighbors.pop_back(); distances.pop_back(); } } double direction = root->mPoint[index] - point[index]; index = (index + 1) % dimensions; nearest(direction > 0 ? root->mLeft : root->mRight, point, index, N, neighbors, distances); if (direction * direction > distances.front()) return; nearest(direction > 0 ? root->mRight : root->mLeft, point, index, N, neighbors, distances); } public: KDTree(const KDTree&) = delete; ... ... @@ -216,6 +230,14 @@ class KDTree nearest(mRoot, pt, 0); return mBest->mPoint; } void nearest(const point_type& pt, size_t N /*number of neighbors*/, std::vector& neighbors, std::vector& distances) { neighbors.clear(); distances.clear(); nearest(mRoot, pt, 0, N, neighbors, distances); } }; } // namespace radix ... ...
 ... ... @@ -5,11 +5,10 @@ #include "radixbug/bug.hh" using namespace radix; typedef Point point3d; typedef KDTree tree3d; TEST(radixalgorithm, KDTree2D) TEST(KDTree, TwoD) { typedef Point point2d; typedef std::array point2d; typedef KDTree tree2d; point2d points[] = {{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}}; ... ... @@ -19,23 +18,67 @@ TEST(radixalgorithm, KDTree2D) EXPECT_NEAR(1.41421, tree.distance(), 1e-4); EXPECT_EQ(3, tree.visited()); EXPECT_EQ(8, n.get(0)); EXPECT_EQ(1, n.get(1)); EXPECT_EQ(8, n[0]); EXPECT_EQ(1, n[1]); } TEST(radixalgorithm, KDTree3D) TEST(KDTree, Nearest3D) { typedef Point point3d; typedef std::array point3d; typedef KDTree tree3d; point3d points[] = {{3, 2, 1}, {2, 7, 1}, {6, 9, 1}, {4, 5, 1}, {7, 4, 1}, {1, 8, 1}, {3, 5, 1}, {2, 4, 1}, {6, 8, 1}}; { point3d points[] = {{3, 2, 1}, {2, 7, 1}, {6, 9, 1}, {4, 5, 1}, {7, 4, 1}, {1, 8, 1}, {3, 5, 1}, {2, 4, 1}, {6, 8, 1}}; tree3d tree(std::begin(points), std::end(points)); point3d n = tree.nearest({2, 9, 1}); tree3d tree(std::begin(points), std::end(points)); point3d n = tree.nearest({2, 9, 1}); EXPECT_NEAR(1.41421, tree.distance(), 1e-4); EXPECT_EQ(7, tree.visited()); EXPECT_EQ(1, n.get(0)); EXPECT_EQ(8, n.get(1)); EXPECT_NEAR(1.41421, tree.distance(), 1e-4); EXPECT_EQ(7, tree.visited()); EXPECT_EQ(1, n[0]); EXPECT_EQ(8, n[1]); EXPECT_EQ(1, n[2]); } { point3d points[] = {{-10, 10, 10}, {10, 10, 10}, {-10, -10, 10}, {10, -10, 10}, {-10, 10, -10}, {10, 10, -10}, {-10, -10, -10}, {10, -10, -10}}; tree3d tree(std::begin(points), std::end(points)); point3d n = tree.nearest({0, 0, 0}); EXPECT_NEAR(17.3205080756888, tree.distance(), 1e-4); EXPECT_EQ(8, tree.visited()); EXPECT_EQ(10, n[0]); EXPECT_EQ(-10, n[1]); EXPECT_EQ(-10, n[2]); } } TEST(KDTree, NearestN3D) { typedef std::array point3d; typedef KDTree tree3d; { point3d points[] = {{-10, 10, 10}, {10, 10, 10}, {-10, -10, 10}, {10, -10, 10}, {-10, 10, -10}, {10, 10, -10}, {-10, -10, -10}, {10, -10, -10}}; tree3d tree(std::begin(points), std::end(points)); std::vector neighbors; std::vector distances; point3d point{10, 10, 9}; tree.nearest(point, 4, neighbors, distances); EXPECT_EQ(4, neighbors.size()); EXPECT_EQ(4, distances.size()); EXPECT_EQ(8, tree.visited()); for (size_t i = 0; i < neighbors.size(); ++i) { std::cout << i << ". " << neighbors[i] << std::endl; } } }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!