Commit 445667fc authored by Lefebvre, Jordan P's avatar Lefebvre, Jordan P

Working on N nearest neighbors.

parent a4a86410
Pipeline #84947 failed with stages
in 13 minutes and 30 seconds
#ifndef RADIX_RADIXALGORITHM_KDTREE_HH
#define RADIX_RADIXALGORITHM_KDTREE_HH
#include "radixalgorithm/ordering.hh"
#include "radixbug/bug.hh"
#include <algorithm>
#include <array>
#include <cassert>
......@@ -11,60 +14,15 @@
namespace radix
{
/**
* Class for representing a Point. coordinate_type must be a numeric type.
*/
template <typename coordinate_type, size_t dimensions>
class Point
{
public:
Point(const std::array<coordinate_type, dimensions>& c)
: mCoords(c)
{
}
Point(std::initializer_list<coordinate_type> list)
{
size_t n = std::min(dimensions, list.size());
std::copy_n(list.begin(), n, mCoords.begin());
}
/**
* Returns the coordinate in the given dimension.
*
* @param index dimension index (zero based)
* @return coordinate in the given dimension
*/
coordinate_type get(size_t index) const { return mCoords[index]; }
/**
* Returns the distance squared from this point to another
* point.
*
* @param pt another point
* @return distance squared from this point to the other point
*/
double distance(const Point& pt) const
{
double dist = 0;
for (size_t i = 0; i < dimensions; ++i)
{
double d = get(i) - pt.get(i);
dist += d * d;
}
return dist;
}
private:
std::array<coordinate_type, dimensions> mCoords;
};
template <typename coordinate_type, size_t dimensions>
std::ostream& operator<<(std::ostream& out,
const Point<coordinate_type, dimensions>& pt)
const std::array<coordinate_type, dimensions>& pt)
{
out << '(';
for (size_t i = 0; i < dimensions; ++i)
{
if (i > 0) out << ", ";
out << pt.get(i);
out << pt[i];
}
out << ')';
return out;
......@@ -77,9 +35,20 @@ template <typename coordinate_type, size_t dimensions>
class KDTree
{
public:
typedef Point<coordinate_type, dimensions> point_type;
typedef std::array<coordinate_type, dimensions> point_type;
private:
static double distance(const point_type& a, const point_type& b)
{
double dist = 0;
size_t len = std::min(a.size(), b.size());
for (size_t i = 0; i < len; ++i)
{
double d = a[i] - b[i];
dist += d * d;
}
return dist;
}
struct Node
{
Node(const point_type& pt)
......@@ -88,17 +57,25 @@ class KDTree
, mRight(nullptr)
{
}
coordinate_type get(size_t index) const { return mPoint.get(index); }
double distance(const point_type& pt) const { return mPoint.distance(pt); }
double distance(const point_type& pt) const
{
return KDTree::distance(mPoint, pt);
}
point_type mPoint;
Node* mLeft;
Node* mRight;
};
}; // struct Node
Node* mRoot;
Node* mBest;
double mBestDist;
size_t mVisited;
std::vector<Node> mNodes;
std::vector<Node*> mNearest;
struct AscendingCompare
{
bool operator()(double a, double b) { return a < b; }
};
struct NodeCompare
{
......@@ -108,7 +85,7 @@ class KDTree
}
bool operator()(const Node& n1, const Node& n2) const
{
return n1.mPoint.get(mIndex) < n2.mPoint.get(mIndex);
return n1.mPoint[mIndex] < n2.mPoint[mIndex];
}
size_t mIndex;
};
......@@ -135,13 +112,50 @@ class KDTree
mBestDist = d;
mBest = root;
}
if (mBestDist == 0) return;
double dx = root->get(index) - point.get(index);
if (std::fabs(mBestDist) <= 1e-12) return;
double dx = root->mPoint[index] - point[index];
index = (index + 1) % dimensions;
nearest(dx > 0 ? root->mLeft : root->mRight, point, index);
if (dx * dx >= mBestDist) return;
nearest(dx > 0 ? root->mRight : root->mLeft, point, index);
}
void nearest(Node* root, const point_type& point, size_t index, size_t N,
std::vector<point_type>& neighbors,
std::vector<double>& distances)
{
if (root == nullptr) return;
mVisited++;
radix_tagged_line("nearest(point='" << point << "',index=" << index);
double root_distance = root->distance(point);
radix_tagged_line("distance to root=" << root_distance);
AscendingCompare compare;
if (neighbors.size() == 0)
{
neighbors.push_back(root->mPoint);
distances.push_back(root_distance);
}
else if (root_distance <= distances.back())
{
neighbors.push_back(root->mPoint);
distances.push_back(root_distance);
std::vector<size_t> permutation = sort_permutation(distances, compare);
apply_permutation(neighbors, permutation);
apply_permutation(distances, permutation);
if (distances.size() > N)
{
neighbors.pop_back();
distances.pop_back();
}
}
double direction = root->mPoint[index] - point[index];
index = (index + 1) % dimensions;
nearest(direction > 0 ? root->mLeft : root->mRight, point, index, N,
neighbors, distances);
if (direction * direction > distances.front()) return;
nearest(direction > 0 ? root->mRight : root->mLeft, point, index, N,
neighbors, distances);
}
public:
KDTree(const KDTree&) = delete;
......@@ -216,6 +230,14 @@ class KDTree
nearest(mRoot, pt, 0);
return mBest->mPoint;
}
void nearest(const point_type& pt, size_t N /*number of neighbors*/,
std::vector<point_type>& neighbors,
std::vector<double>& distances)
{
neighbors.clear();
distances.clear();
nearest(mRoot, pt, 0, N, neighbors, distances);
}
};
} // namespace radix
......
......@@ -5,11 +5,10 @@
#include "radixbug/bug.hh"
using namespace radix;
typedef Point<double, 3> point3d;
typedef KDTree<double, 3> tree3d;
TEST(radixalgorithm, KDTree2D)
TEST(KDTree, TwoD)
{
typedef Point<int, 2> point2d;
typedef std::array<int, 2> point2d;
typedef KDTree<int, 2> tree2d;
point2d points[] = {{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}};
......@@ -19,23 +18,67 @@ TEST(radixalgorithm, KDTree2D)
EXPECT_NEAR(1.41421, tree.distance(), 1e-4);
EXPECT_EQ(3, tree.visited());
EXPECT_EQ(8, n.get(0));
EXPECT_EQ(1, n.get(1));
EXPECT_EQ(8, n[0]);
EXPECT_EQ(1, n[1]);
}
TEST(radixalgorithm, KDTree3D)
TEST(KDTree, Nearest3D)
{
typedef Point<int, 3> point3d;
typedef std::array<int, 3> point3d;
typedef KDTree<int, 3> tree3d;
point3d points[] = {{3, 2, 1}, {2, 7, 1}, {6, 9, 1}, {4, 5, 1}, {7, 4, 1},
{1, 8, 1}, {3, 5, 1}, {2, 4, 1}, {6, 8, 1}};
{
point3d points[] = {{3, 2, 1}, {2, 7, 1}, {6, 9, 1}, {4, 5, 1}, {7, 4, 1},
{1, 8, 1}, {3, 5, 1}, {2, 4, 1}, {6, 8, 1}};
tree3d tree(std::begin(points), std::end(points));
point3d n = tree.nearest({2, 9, 1});
tree3d tree(std::begin(points), std::end(points));
point3d n = tree.nearest({2, 9, 1});
EXPECT_NEAR(1.41421, tree.distance(), 1e-4);
EXPECT_EQ(7, tree.visited());
EXPECT_EQ(1, n.get(0));
EXPECT_EQ(8, n.get(1));
EXPECT_NEAR(1.41421, tree.distance(), 1e-4);
EXPECT_EQ(7, tree.visited());
EXPECT_EQ(1, n[0]);
EXPECT_EQ(8, n[1]);
EXPECT_EQ(1, n[2]);
}
{
point3d points[] = {{-10, 10, 10}, {10, 10, 10}, {-10, -10, 10},
{10, -10, 10}, {-10, 10, -10}, {10, 10, -10},
{-10, -10, -10}, {10, -10, -10}};
tree3d tree(std::begin(points), std::end(points));
point3d n = tree.nearest({0, 0, 0});
EXPECT_NEAR(17.3205080756888, tree.distance(), 1e-4);
EXPECT_EQ(8, tree.visited());
EXPECT_EQ(10, n[0]);
EXPECT_EQ(-10, n[1]);
EXPECT_EQ(-10, n[2]);
}
}
TEST(KDTree, NearestN3D)
{
typedef std::array<int, 3> point3d;
typedef KDTree<int, 3> tree3d;
{
point3d points[] = {{-10, 10, 10}, {10, 10, 10}, {-10, -10, 10},
{10, -10, 10}, {-10, 10, -10}, {10, 10, -10},
{-10, -10, -10}, {10, -10, -10}};
tree3d tree(std::begin(points), std::end(points));
std::vector<point3d> neighbors;
std::vector<double> distances;
point3d point{10, 10, 9};
tree.nearest(point, 4, neighbors, distances);
EXPECT_EQ(4, neighbors.size());
EXPECT_EQ(4, distances.size());
EXPECT_EQ(8, tree.visited());
for (size_t i = 0; i < neighbors.size(); ++i)
{
std::cout << i << ". " << neighbors[i] << std::endl;
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment