10#ifndef PBAT_COMMON_BINARYRADIXTREE_H
11#define PBAT_COMMON_BINARYRADIXTREE_H
26template <
class TIndex = Index>
/**
 * @brief Default constructor (compiler-generated).
 *
 * No tree is built here; Construct() is the routine that sizes and fills the child/parent arrays.
 */
30 BinaryRadixTree() =
default;
/**
 * @brief Construct a new Binary Radix Tree object.
 *
 * The out-of-line definition forwards both arguments to Construct().
 *
 * @tparam TDerived Concrete Eigen expression type of the codes
 * @param codes Codes to build the tree over; Construct() statically requires their scalar type
 * to be an unsigned, non-bool integral type
 * @param bStoreParent If true, Construct() allocates the parent array (one entry per node);
 * otherwise that array is sized to zero
 */
38 template <
class TDerived>
39 BinaryRadixTree(Eigen::DenseBase<TDerived>
const& codes,
bool bStoreParent =
false);
/**
 * @brief Build the tree from a list of unsigned integral codes.
 *
 * Resizes the child array to 2 x (codes.size() - 1) and, when requested, the parent array to
 * codes.size() + (codes.size() - 1) entries.
 *
 * @tparam TDerived Concrete Eigen expression type of the codes
 * @param codes Codes to build the tree over; the scalar type must be an unsigned, non-bool
 * integral type of at most 64 bits (enforced with static_assert in the definition).
 * NOTE(review): the definition binary-searches the codes, so they are presumably expected to be
 * sorted -- confirm against the original documentation.
 * @param bStoreParent If true, the parent array is allocated; otherwise it is sized to zero
 */
94 template <
class TDerived>
95 void Construct(Eigen::DenseBase<TDerived>
const& codes,
bool bStoreParent =
false);
103 TIndex
Left(TIndex i)
const {
return mChild(0, i); }
111 TIndex
Right(TIndex i)
const {
return mChild(1, i); }
118 TIndex
Parent(TIndex i)
const {
return mParent(i); }
143 constexpr TIndex
Root()
const {
return 0; }
170 auto Left()
const {
return mChild.row(0); }
176 auto Right()
const {
return mChild.row(1); }
/// 2 x N index table. NOTE(review): the declarator is cut off in this listing; presumably this
/// is `mChild`, which the accessors read as row 0 = left child, row 1 = right child, and which
/// Construct() resizes to one column per internal node -- confirm against the full header.
192 Eigen::Matrix<TIndex, 2, Eigen::Dynamic>
/// Parent index per node. Construct() sizes it to (leaves + internal nodes) entries when
/// bStoreParent is true and to zero entries otherwise.
196 Eigen::Vector<TIndex, Eigen::Dynamic> mParent;
/**
 * @brief Out-of-line definition of the constructing constructor.
 *
 * NOTE(review): this listing is missing the `bool bStoreParent)` parameter line and the braces
 * of the body; only the visible lines are documented.
 */
200template <
class TIndex>
201template <
class TDerived>
202inline BinaryRadixTree<TIndex>::BinaryRadixTree(
203 Eigen::DenseBase<TDerived>
const& codes,
// All work is delegated to Construct(), passing the concrete (derived) Eigen expression.
206 Construct(codes.derived(), bStoreParent);
/**
 * @brief Out-of-line definition of Construct(): builds the child (and optionally parent) arrays.
 *
 * NOTE(review): this listing is fragmentary. The function signature, the node/stack
 * declarations, several guards (`if`/`else`), the declaration of `rc` and all braces are not
 * visible. The comments below describe only what the visible statements do; anything about the
 * missing lines is marked as an assumption.
 */
209template <
class TIndex>
210template <
class TDerived>
// Scalar type of the code expression.
214 using CodeType =
typename TDerived::Scalar;
// Compile-time requirement (the opening `static_assert(` line is not visible): codes must be
// an unsigned integral type other than bool.
216 std::is_integral_v<CodeType> and std::is_unsigned_v<CodeType> and
217 not std::is_same_v<CodeType, bool>,
218 "Codes must be of unsigned integral type");
// Width of a code in bits; the most-significant-bit mask below is held in 64 bits.
219 auto constexpr nBits =
sizeof(CodeType) * 8;
220 static_assert(nBits <= 64,
"CodeType must have at most 64 bits");
// Mask with only the highest bit of CodeType set.
222 std::uint64_t
constexpr msb = 0b1ULL << (nBits - 1);
// One leaf per code and one fewer internal node than leaves.
// NOTE(review): with an empty `codes`, nInternal becomes -1 before the resize -- confirm callers
// never pass an empty list.
223 TIndex
const nLeaves = codes.size();
224 TIndex
const nInternal = nLeaves - 1;
// Two children per internal node.
225 mChild.resize(2, nInternal);
// bStoreParent acts as a 0/1 multiplier: one parent slot per node, or an empty array.
226 mParent.resize(bStoreParent * (nLeaves + nInternal));
// Number of leading bits on which two codes agree (leading zeros of their XOR).
// NOTE(review): for CodeType narrower than int, `ci ^ cj` is promoted to int, which
// std::countl_zero does not accept (it requires an unsigned type) -- confirm which code widths
// are meant to be supported.
233 auto const fCommonPrefixLength = [](CodeType ci, CodeType cj) {
234 return std::countl_zero(ci ^ cj);
// Seed the traversal with the root node: begin = 0, end = last code index.
237 stack.
Push({0, nLeaves - 1});
// Process nodes until none are pending.
239 while (not stack.IsEmpty())
241 Node
const node = stack.Pop();
// A node's `begin` may be larger than its `end`; order the two branch-free so that
// first <= last.
245 bool const bReversed = node.begin > node.end;
246 auto const first = (not bReversed) * node.begin + bReversed * node.end;
247 auto const last = (not bReversed) * node.end + bReversed * node.begin;
// The split position is accumulated starting from the first code of the range.
249 TIndex split = first;
250 auto const cfirst = codes(first);
251 auto const clast = codes(last);
// Offset by half the range length. NOTE(review): the guard for this statement is not visible;
// presumably it is taken when the first and last codes are identical -- confirm.
254 split += (last - first + 1) / 2;
// Otherwise (presumably): isolate the highest bit in which the first and last codes differ...
258 auto const mask = msb >> fCommonPrefixLength(cfirst, clast);
259 auto const begin = codes.begin() + first;
260 auto const end = codes.begin() + last + 1;
// ...and binary-search [first, last] for the first code whose masked bit exceeds that of the
// first code.
261 auto const upper = std::upper_bound(begin, end, cfirst, [&](CodeType ci, CodeType cj) {
262 return (mask & ci) < (mask & cj);
264 split += std::distance(begin, upper);
// Left child ends just before the split. NOTE(review): the declaration of `rc` is not visible;
// presumably it is `split`.
268 TIndex lc = split - 1;
// A child whose index coincides with the range boundary is flagged as a leaf.
272 bool const bIsLeftLeaf = (lc == first);
273 bool const bIsRightLeaf = (rc == last);
// Leaf indices are shifted by nInternal, i.e. leaves are numbered after the internal nodes.
274 lc += bIsLeftLeaf * nInternal;
275 rc += bIsRightLeaf * nInternal;
// The current internal node is identified by node.begin; record its two children.
277 mChild(0, node.begin) = lc;
278 mChild(1, node.begin) = rc;
// Schedule the children: the left child is pushed as {lc, first}, the right as {rc, last}.
// NOTE(review): the guard on the left push is not visible; presumably `if (not bIsLeftLeaf)`.
281 stack.Push({lc, first});
282 if (not bIsRightLeaf)
283 stack.Push({rc, last});
// Derive the parent array from the child arrays. NOTE(review): presumably guarded by
// bStoreParent (guard not visible); `auto i = 0` deduces int while nInternal is TIndex.
289 for (
auto i = 0; i < nInternal; ++i)
291 mParent(
Left(i)) = i;
292 mParent(
Right(i)) = i;
auto Parent() const
Get the parent array of the tree.
Definition BinaryRadixTree.h:183
auto Children() const
Get the children array of the tree.
Definition BinaryRadixTree.h:164
TIndex Parent(TIndex i) const
Get the parent of a node.
Definition BinaryRadixTree.h:118
auto Left() const
Left child array of the tree.
Definition BinaryRadixTree.h:170
auto Right() const
Right child array of the tree.
Definition BinaryRadixTree.h:176
bool HasParentRelationship() const
Check if the tree has parent relationships.
Definition BinaryRadixTree.h:189
TIndex CodeIndex(TIndex leaf) const
Get the index of the code associated with a leaf node.
Definition BinaryRadixTree.h:150
TIndex Right(TIndex i) const
Get the right child of internal node i.
Definition BinaryRadixTree.h:111
constexpr TIndex Root() const
Get the root of the tree.
Definition BinaryRadixTree.h:143
TIndex InternalNodeCount() const
Get the number of internal nodes.
Definition BinaryRadixTree.h:124
BinaryRadixTree(Eigen::DenseBase< TDerived > const &codes, bool bStoreParent=false)
Construct a new Binary Radix Tree object.
Definition BinaryRadixTree.h:202
TIndex LeafCount() const
Get the number of leaf nodes.
Definition BinaryRadixTree.h:130
void Construct(Eigen::DenseBase< TDerived > const &codes, bool bStoreParent=false)
Construct a Radix Tree from a sorted list of unsigned integral codes.
Definition BinaryRadixTree.h:212
TIndex LeafIndex(TIndex codeIdx) const
Get the index of the leaf node associated with a code.
Definition BinaryRadixTree.h:157
TIndex Left(TIndex i) const
Get the left child of internal node i.
Definition BinaryRadixTree.h:103
bool IsLeaf(TIndex i) const
Check if a node is a leaf.
Definition BinaryRadixTree.h:137
Fixed-size stack implementation.
Definition Stack.h:26
PBAT_HOST_DEVICE void Push(T value)
Add element to the stack.
Definition Stack.h:37
Fixed-size stack implementation usable in both host and device code.
Common functionality.
Definition ArgSort.h:20