From 1b4a574d45b2946a6ce7189bb979cc4d9b4ce30d Mon Sep 17 00:00:00 2001 From: Gleb Mazovetskiy Date: Wed, 2 Jul 2025 18:55:49 +0100 Subject: [PATCH] Faster k-d tree construction When populating the pivots, only keep the element counts instead of the actual elements. This removes the need to compute the counts in `getMedian`. ``` Benchmark Time CPU Time Old Time New CPU Old CPU New ----------------------------------------------------------------------------------------------------------------------------------- BM_GenerateBlendedLookupTable_pvalue 0.0002 0.0002 U Test, Repetitions: 10 vs 10 BM_GenerateBlendedLookupTable_mean -0.0224 -0.0225 2154569 2106218 2154215 2105837 BM_GenerateBlendedLookupTable_median -0.0224 -0.0225 2154463 2106143 2154147 2105729 BM_GenerateBlendedLookupTable_stddev -0.5094 -0.5714 1153 566 952 408 BM_GenerateBlendedLookupTable_cv -0.4981 -0.5616 0 0 0 0 BM_BuildTree_pvalue 0.0002 0.0002 U Test, Repetitions: 10 vs 10 BM_BuildTree_mean -0.3660 -0.3660 6520 4134 6519 4133 BM_BuildTree_median -0.3659 -0.3659 6519 4134 6518 4133 BM_BuildTree_stddev -0.2606 -0.2381 7 5 6 5 BM_BuildTree_cv +0.1661 +0.2016 0 0 0 0 BM_FindNearestNeighbor_pvalue 0.0002 0.0002 U Test, Repetitions: 10 vs 10 BM_FindNearestNeighbor_mean -0.0181 -0.0181 1980869037 1945027191 1980539225 1944679255 BM_FindNearestNeighbor_median -0.0181 -0.0181 1980930663 1945081501 1980593915 1944736796 BM_FindNearestNeighbor_stddev -0.4920 -0.4921 809594 411280 814962 413885 BM_FindNearestNeighbor_cv -0.4826 -0.4828 0 0 0 0 OVERALL_GEOMEAN -0.1526 -0.1526 0 0 0 0 ``` --- Source/utils/palette_kd_tree.hpp | 58 +++++++++++++++++--------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/Source/utils/palette_kd_tree.hpp b/Source/utils/palette_kd_tree.hpp index 9f6843304..7c4d291e6 100644 --- a/Source/utils/palette_kd_tree.hpp +++ b/Source/utils/palette_kd_tree.hpp @@ -234,59 +234,61 @@ public: } private: - [[nodiscard]] static uint8_t getMedian(std::span elements) - { - 
uint8_t min = 255; - uint8_t max = 0; - uint_fast16_t count[256] = {}; - for (const uint8_t x : elements) { - min = std::min(x, min); - max = std::max(x, max); - ++count[x]; - } + struct MedianInfo { + std::array counts = {}; + uint16_t numValues = 0; + }; - const auto medianTarget = static_cast((elements.size() + 1) / 2); - uint_fast16_t partialSum = count[min]; - for (uint_fast16_t i = min + 1; i <= max; ++i) { - if (partialSum >= medianTarget) return i; - partialSum += count[i]; + [[nodiscard]] static uint8_t getMedian(const MedianInfo &medianInfo) + { + const std::span counts = medianInfo.counts; + const uint_fast16_t numValues = medianInfo.numValues; + const auto medianTarget = static_cast((medianInfo.numValues + 1) / 2); + uint_fast16_t partialSum = 0; + uint_fast16_t i = 0; + for (; partialSum < medianTarget && partialSum != numValues; ++i) { + partialSum += counts[i]; } - // Can't find a helpful pivot so return 255 so that - // NN lookups through this node mostly go to the left child. - return 255; + // Special cases: + // 1. If the elements are empty, this will return 0. + // 2. If all the elements are the same, this will be `value + 1` (rolling over to 0 if value is 255). + // This means all the elements will be on one side of the pivot (left unless the value is 255). + return static_cast(i); } template static void maybeAddToSubdivisionForMedian( const PaletteKdTreeNode &node, const SDL_Color palette[256], unsigned paletteIndex, - std::span, N> out) + std::span medianInfos) { const uint8_t color = node.getColorCoordinate(palette[paletteIndex]); if constexpr (N == 1) { - out[0].emplace_back(color); + MedianInfo &medianInfo = medianInfos[0]; + ++medianInfo.counts[color]; + ++medianInfo.numValues; } else { const bool isLeft = color < node.pivot; maybeAddToSubdivisionForMedian(node.child(isLeft), palette, paletteIndex, isLeft - ? out.template subspan<0, N / 2>() - : out.template subspan()); + ? 
medianInfos.template subspan<0, N / 2>() + : medianInfos.template subspan()); } } template static void setPivotsRecursively( PaletteKdTreeNode &node, - std::span, N> values) + std::span medianInfos) { if constexpr (N == 1) { - node.pivot = getMedian(values[0]); + node.pivot = getMedian(medianInfos[0]); } else { - setPivotsRecursively(node.left, values.template subspan<0, N / 2>()); - setPivotsRecursively(node.right, values.template subspan()); + setPivotsRecursively(node.left, medianInfos.template subspan<0, N / 2>()); + setPivotsRecursively(node.right, medianInfos.template subspan()); } } @@ -294,8 +296,8 @@ private: void populatePivotsForTargetDepth(const SDL_Color palette[256], int skipFrom, int skipTo) { constexpr size_t NumSubdivisions = 1U << TargetDepth; - std::array, NumSubdivisions> subdivisions; - const std::span, NumSubdivisions> subdivisionsSpan { subdivisions }; + std::array subdivisions = {}; + const std::span subdivisionsSpan { subdivisions }; for (int i = 0; i < 256; ++i) { if (i >= skipFrom && i <= skipTo) continue; maybeAddToSubdivisionForMedian(tree_, palette, i, subdivisionsSpan);