ACF $AcfVersion:0$
TKdTree.h
Go to the documentation of this file.
1// SPDX-License-Identifier: LGPL-2.1-or-later OR GPL-2.0-or-later OR GPL-3.0-or-later OR LicenseRef-ACF-Commercial
2#pragma once
3
4#include <vector>
5#include <array>
6#include <algorithm>
7#include <functional>
8
9#undef max
10
11namespace imath
12{
13
/**
	Static k-d tree for nearest-neighbour, k-nearest-neighbour and radius queries.

	\tparam TPoint      stored point type; points are copied into the tree by MakeTree().
	\tparam Dimensions  number of coordinates per point (must be greater than zero).

	Point coordinates are read through a user-supplied GetComponentFunc and distances are computed
	through a user-supplied GetDistanceFunc; both functors are copied by MakeTree().

	\note The search pruning compares per-axis coordinate differences directly with values returned
	by GetDistanceFunc, so that function must never return less than |component(y, i) - x[i]| for any
	axis i (Euclidean distance qualifies, squared Euclidean distance does not).
	\note Child links are raw pointers into the internal node array; a member-wise copy would leave them
	pointing into the source object, so copying is disabled.
*/
template<typename TPoint, uint8_t Dimensions>
class TKdTree
{
	static_assert(Dimensions > 0, "TKdTree requires at least one dimension");

public:
	/// Query coordinate, one value per dimension.
	typedef std::array<double, Dimensions> Coordinate;
	/// Returns coordinate \a index (in [0, Dimensions)) of a stored point.
	typedef std::function<double(const TPoint& p, uint8_t index)> GetComponentFunc;
	/// Returns the distance between query coordinate \a x and stored point \a y (see the class note).
	typedef std::function<double(const Coordinate& x, const TPoint& y)> GetDistanceFunc;

private:
	GetComponentFunc m_getComponentFunc;
	GetDistanceFunc m_getDistanceFunc;

	/// Tree node; children point into m_nodes or are nullptr.
	struct Node
	{
		explicit Node(const TPoint& pt) : m_point(pt), m_left(nullptr), m_right(nullptr)
		{
		}

		TPoint m_point;
		Node* m_left;
		Node* m_right;
	};

	/// (distance to the query, node) pair used by the k-nearest search.
	typedef std::pair<double, const Node*> DistanceNode;

	/// Sorted list of at most \c bound entries, ordered by ascending distance.
	class BoundedPriorityQueue
	{
	public:
		BoundedPriorityQueue() = delete;

		explicit BoundedPriorityQueue(size_t bound) : bound(bound)
		{
			// One spare slot because push() inserts before trimming back to the bound.
			elements.reserve(bound + 1);
		}

		/**
			Inserts \a val after every entry whose distance is smaller or equal, then drops the
			farthest entry if the bound is exceeded.
		*/
		void push(const DistanceNode& val)
		{
			const auto position = std::upper_bound(elements.begin(), elements.end(), val,
						[](const DistanceNode& lhs, const DistanceNode& rhs) { return lhs.first < rhs.first; });
			elements.insert(position, val);

			if (elements.size() > bound) {
				elements.pop_back();
			}
		}

		const DistanceNode& back() const { return elements.back(); }
		const DistanceNode& operator[](size_t index) const { return elements[index]; }
		size_t size() const { return elements.size(); }

	private:
		size_t bound;
		std::vector<DistanceNode> elements;
	};

	Node* m_root;
	std::vector<Node> m_nodes;

	/// Orders nodes by a single coordinate; used for the median split while building.
	struct NodeCmp
	{
		NodeCmp(uint8_t index, const GetComponentFunc& getComponentFunc) : m_index(index), m_getComponentFunc(getComponentFunc)
		{
		}

		bool operator()(const Node& a, const Node& b) const
		{
			return m_getComponentFunc(a.m_point, m_index) < m_getComponentFunc(b.m_point, m_index);
		}

		uint8_t m_index;
		const GetComponentFunc& m_getComponentFunc;
	};

	/**
		Recursively builds the subtree over the half-open range m_nodes[begin, end), splitting on
		coordinate \a index. Reorders that range in place.
		\return the subtree root, or nullptr for an empty range.
	*/
	Node* MakeTree(size_t begin, size_t end, uint8_t index)
	{
		if (end <= begin) {
			return nullptr;
		}

		const size_t median = begin + (end - begin) / 2;

		// Partial sort: no node before 'median' compares greater on this axis, none after compares less.
		const auto first = m_nodes.begin();
		std::nth_element(first + begin, first + median, first + end, NodeCmp(index, m_getComponentFunc));

		const uint8_t nextIndex = static_cast<uint8_t>((index + 1) % Dimensions);
		m_nodes[median].m_left = MakeTree(begin, median, nextIndex);
		m_nodes[median].m_right = MakeTree(median + 1, end, nextIndex);

		return &m_nodes[median];
	}

	/// Builds the tree over the whole current contents of m_nodes and updates m_root.
	void BuildRoot()
	{
		m_root = m_nodes.empty() ? nullptr : MakeTree(0, m_nodes.size(), 0);
	}

	/**
		Nearest-neighbour search below \a root.
		\a bestDistance is the smallest distance seen so far (even beyond \a maxDistance);
		\a best is updated only when that smallest distance is within \a maxDistance.
	*/
	void Nearest(const Node* root, const Coordinate& point, uint8_t index, double& bestDistance, const Node*& best, double maxDistance) const
	{
		if (root == nullptr) {
			return;
		}

		const double d = m_getDistanceFunc(point, root->m_point);
		if (d < bestDistance) {
			bestDistance = d;

			if (bestDistance <= maxDistance) {
				best = root;
			}
		}

		// Nothing can be closer than an exact hit.
		if (bestDistance == 0) {
			return;
		}

		const double dx = m_getComponentFunc(root->m_point, index) - point[index];
		const double absDx = std::abs(dx);
		const uint8_t nextIndex = static_cast<uint8_t>((index + 1) % Dimensions);

		// Descend first into the side of the split that contains the query.
		Nearest(dx > 0 ? root->m_left : root->m_right, point, nextIndex, bestDistance, best, maxDistance);

		// Far-side points are at least |dx| away. Points at exactly maxDistance are still accepted,
		// so only a strictly larger |dx| rules that side out.
		if (absDx >= bestDistance || absDx > maxDistance) {
			return;
		}

		Nearest(dx > 0 ? root->m_right : root->m_left, point, nextIndex, bestDistance, best, maxDistance);
	}

	/// k-nearest search below \a root; \a queue keeps the \a k best candidates found so far.
	void KNearest(const Node* root, const Coordinate& point, uint8_t index, BoundedPriorityQueue& queue, size_t k) const
	{
		if (root == nullptr) {
			return;
		}

		queue.push(std::make_pair(m_getDistanceFunc(point, root->m_point), root));

		const double dx = m_getComponentFunc(root->m_point, index) - point[index];
		const uint8_t nextIndex = static_cast<uint8_t>((index + 1) % Dimensions);

		KNearest(dx > 0 ? root->m_left : root->m_right, point, nextIndex, queue, k);

		// The far side matters only while the queue is not full or it may beat the current k-th entry.
		if (queue.size() < k || std::abs(dx) < queue.back().first) {
			KNearest(dx > 0 ? root->m_right : root->m_left, point, nextIndex, queue, k);
		}
	}

	/// Appends every node below \a root that is strictly closer than \a radius, in traversal order.
	void InRadius(const Node* root, const Coordinate& point, const double radius, uint8_t index, std::vector<std::pair<const Node*, double>>& inRadius) const
	{
		if (root == nullptr) {
			return;
		}

		const double d = m_getDistanceFunc(point, root->m_point);
		if (d < radius) {
			inRadius.push_back(std::make_pair(root, d));
		}

		const double dx = m_getComponentFunc(root->m_point, index) - point[index];
		const uint8_t nextIndex = static_cast<uint8_t>((index + 1) % Dimensions);

		InRadius(dx > 0 ? root->m_left : root->m_right, point, radius, nextIndex, inRadius);

		// Far-side points are at least |dx| away.
		if (std::abs(dx) > radius) {
			return;
		}

		InRadius(dx > 0 ? root->m_right : root->m_left, point, radius, nextIndex, inRadius);
	}

public:
	TKdTree(const TKdTree&) = delete;
	TKdTree& operator=(const TKdTree&) = delete;

	TKdTree() : m_root(nullptr)
	{
	}

	/**
		Rebuilds the tree from the points in [begin, end).
		\note The range is traversed twice (std::distance, then the copy), so \c iterator must be at
		least a forward iterator.
	*/
	template<typename iterator>
	void MakeTree(iterator begin, iterator end, const GetComponentFunc& getComponentFunc, const GetDistanceFunc& getDistanceFunc)
	{
		// Detach the root first so it never points into the node storage being replaced.
		m_root = nullptr;
		m_getComponentFunc = getComponentFunc;
		m_getDistanceFunc = getDistanceFunc;
		m_nodes.clear();
		m_nodes.reserve(static_cast<size_t>(std::distance(begin, end)));

		for (auto i = begin; i != end; ++i) {
			m_nodes.emplace_back(*i);
		}

		BuildRoot();
	}

	/// Rebuilds the tree from the points construct(0) ... construct(n - 1), created in that order.
	void MakeTree(const std::function<TPoint(size_t)>& construct, size_t n, const GetComponentFunc& getComponentFunc, const GetDistanceFunc& getDistanceFunc)
	{
		// Detach the root first so it never points into the node storage being replaced.
		m_root = nullptr;
		m_getComponentFunc = getComponentFunc;
		m_getDistanceFunc = getDistanceFunc;
		m_nodes.clear();
		m_nodes.reserve(n);

		for (size_t i = 0; i < n; ++i) {
			m_nodes.emplace_back(construct(i));
		}

		BuildRoot();
	}

	/// \return true if the tree holds no points.
	bool Empty() const
	{
		return m_nodes.empty();
	}

	/**
		Finds the stored point closest to \a pt.
		\param pt             query coordinate.
		\param p              receives the closest point on success.
		\param resultDistance receives the distance of that point to \a pt on success.
		\param maxDistance    points farther than this are not reported.
		\return true if a point within \a maxDistance was found; false if the tree is empty or no point
		qualifies (outputs are left untouched in that case).
	*/
	bool Nearest(const Coordinate& pt, TPoint& p, double& resultDistance, double maxDistance = std::numeric_limits<double>::max()) const
	{
		if (m_root == nullptr) {
			return false;
		}

		const Node* best = nullptr;
		double dist = std::numeric_limits<double>::max();
		Nearest(m_root, pt, 0, dist, best, maxDistance);

		if (best != nullptr) {
			p = best->m_point;
			resultDistance = dist;
			return true;
		}

		return false;
	}

	/**
		Collects the \a k stored points closest to \a pt, sorted by ascending distance
		(fewer if the tree holds fewer points; ties keep discovery order).
		\return false if the tree is empty (output untouched); true otherwise.
	*/
	bool KNearest(const Coordinate& pt, std::vector<std::pair<TPoint, double>>& neighborsWithDistance, size_t k) const
	{
		if (m_root == nullptr) {
			return false;
		}

		neighborsWithDistance.clear();

		if (k == 0) {
			return true;
		}

		// The result can never exceed the number of stored points; capping also keeps the
		// queue's reserve() sane for very large k.
		const size_t count = std::min(k, m_nodes.size());
		BoundedPriorityQueue queue(count);
		KNearest(m_root, pt, 0, queue, count);

		neighborsWithDistance.reserve(queue.size());
		for (size_t i = 0; i < queue.size(); ++i) {
			neighborsWithDistance.emplace_back(queue[i].second->m_point, queue[i].first);
		}

		return true;
	}

	/**
		Collects every stored point strictly closer than \a radius to \a pt, in tree traversal
		order (not sorted by distance).
		\return false if the tree is empty (output untouched); true otherwise.
	*/
	bool InRadius(const Coordinate& pt, double radius, std::vector<std::pair<TPoint, double>>& pointsWithDistance) const
	{
		if (m_root == nullptr) {
			return false;
		}

		std::vector<std::pair<const Node*, double>> inRadiusNode;
		InRadius(m_root, pt, radius, 0, inRadiusNode);

		pointsWithDistance.clear();
		pointsWithDistance.reserve(inRadiusNode.size());
		for (const auto& hit : inRadiusNode) {
			pointsWithDistance.emplace_back(hit.first->m_point, hit.second);
		}

		return true;
	}
};
273
274} // namespace imath
275
276
bool Empty() const
Definition TKdTree.h:211
void MakeTree(const std::function< TPoint(size_t)> &construct, size_t n, const GetComponentFunc &getComponentFunc, const GetDistanceFunc &getDistanceFunc)
Definition TKdTree.h:195
bool InRadius(const Coordinate &pt, double radius, std::vector< std::pair< TPoint, double > > &pointsWithDistance) const
Definition TKdTree.h:256
std::array< double, Dimensions > Coordinate
Definition TKdTree.h:18
bool KNearest(const Coordinate &pt, std::vector< std::pair< TPoint, double > > &neighborsWithDistance, size_t k) const
Definition TKdTree.h:235
bool Nearest(const Coordinate &pt, TPoint &p, double &resultDistance, double maxDistance=std::numeric_limits< double >::max()) const
Definition TKdTree.h:216
std::function< double(const TPoint &p, uint8_t index)> GetComponentFunc
Definition TKdTree.h:19
TKdTree & operator=(const TKdTree &)=delete
void MakeTree(iterator begin, iterator end, const GetComponentFunc &getComponentFunc, const GetDistanceFunc &getDistanceFunc)
Definition TKdTree.h:179
std::function< double(const Coordinate &x, const TPoint &y)> GetDistanceFunc
Definition TKdTree.h:20
TKdTree(const TKdTree &)=delete
Package with mathematical functions and algebraical primitives.