diff --git a/Source/CMakeLists.txt b/Source/CMakeLists.txt index 2b6824d1c..7289d7f92 100644 --- a/Source/CMakeLists.txt +++ b/Source/CMakeLists.txt @@ -252,6 +252,15 @@ target_link_dependencies(libdevilutionx_padmapper PUBLIC libdevilutionx_options ) +add_devilutionx_object_library(libdevilutionx_palette_kd_tree + utils/palette_kd_tree.cpp +) +target_link_dependencies(libdevilutionx_palette_kd_tree PUBLIC + DevilutionX::SDL + fmt::fmt + libdevilutionx_strings +) + add_devilutionx_object_library(libdevilutionx_paths utils/paths.cpp ) @@ -546,6 +555,7 @@ add_devilutionx_object_library(libdevilutionx_palette_blending target_link_dependencies(libdevilutionx_palette_blending PUBLIC DevilutionX::SDL fmt::fmt + libdevilutionx_palette_kd_tree libdevilutionx_strings ) diff --git a/Source/utils/palette_kd_tree.cpp b/Source/utils/palette_kd_tree.cpp new file mode 100644 index 000000000..d3c569f56 --- /dev/null +++ b/Source/utils/palette_kd_tree.cpp @@ -0,0 +1,217 @@ +#include "utils/palette_kd_tree.hpp" + +#include +#include +#include +#include +#include +#include + +#ifdef USE_SDL1 +#include +#else +#include +#endif + +#include + +#include "utils/static_vector.hpp" +#include "utils/str_cat.hpp" + +#if DEVILUTIONX_PRINT_PALETTE_BLENDING_TREE_GRAPHVIZ +#include +#endif + +namespace devilution { +namespace { + +template +uint8_t GetColorComponent(const SDL_Color &); +template <> +inline uint8_t GetColorComponent<0>(const SDL_Color &c) { return c.r; } +template <> +inline uint8_t GetColorComponent<1>(const SDL_Color &c) { return c.g; } +template <> +inline uint8_t GetColorComponent<2>(const SDL_Color &c) { return c.b; } + +template +[[nodiscard]] PaletteKdTreeNode<0> &LeafByIndex(PaletteKdTreeNode &node, uint8_t index) +{ + if constexpr (RemainingDepth == 1) { + return node.child(index % 2 == 0); + } else { + return LeafByIndex(node.child(index % 2 == 0), index / 2); + } +} + +template +[[nodiscard]] uint8_t LeafIndexForColor(const PaletteKdTreeNode &node, const SDL_Color &color, uint8_t result = 0) +{ + const bool isLeft = GetColorComponent::Coord>(color) < node.pivot; + if constexpr (RemainingDepth == 1) { + return (2 * result) + (isLeft ? 0 : 1); + } else { + return (2 * LeafIndexForColor(node.child(isLeft), color, result)) + (isLeft ? 0 : 1); + } +} + +struct MedianInfo { + std::array counts = {}; + uint16_t numValues = 0; +}; + +[[nodiscard]] static uint8_t GetMedian(const MedianInfo &medianInfo) +{ + const std::span counts = medianInfo.counts; + const uint_fast16_t numValues = medianInfo.numValues; + const auto medianTarget = static_cast((medianInfo.numValues + 1) / 2); + uint_fast16_t partialSum = 0; + uint_fast16_t i = 0; + for (; partialSum < medianTarget && partialSum != numValues; ++i) { + partialSum += counts[i]; + } + + // Special cases: + // 1. If the elements are empty, this will return 0. + // 2. If all the elements are the same, this will be `value + 1` (rolling over to 0 if value is 256). + // This means all the elements will be on one side of the pivot (left unless the value is 255). + return static_cast(i); +} + +template +void MaybeAddToSubdivisionForMedian( + const PaletteKdTreeNode &node, + const SDL_Color palette[256], unsigned paletteIndex, + std::span medianInfos) +{ + const uint8_t color = GetColorComponent::Coord>(palette[paletteIndex]); + if constexpr (N == 1) { + MedianInfo &medianInfo = medianInfos[0]; + ++medianInfo.counts[color]; + ++medianInfo.numValues; + } else { + const bool isLeft = color < node.pivot; + MaybeAddToSubdivisionForMedian(node.child(isLeft), + palette, + paletteIndex, + isLeft + ? medianInfos.template subspan<0, N / 2>() + : medianInfos.template subspan()); + } +} + +template +void SetPivotsRecursively( + PaletteKdTreeNode &node, + std::span medianInfos) +{ + if constexpr (N == 1) { + node.pivot = GetMedian(medianInfos[0]); + } else { + SetPivotsRecursively(node.left, medianInfos.template subspan<0, N / 2>()); + SetPivotsRecursively(node.right, medianInfos.template subspan()); + } +} + +template +void PopulatePivotsForTargetDepth(PaletteKdTreeNode &root, + const SDL_Color palette[256], int skipFrom, int skipTo) +{ + constexpr size_t NumSubdivisions = 1U << TargetDepth; + std::array subdivisions = {}; + const std::span subdivisionsSpan { subdivisions }; + for (int i = 0; i < 256; ++i) { + if (i >= skipFrom && i <= skipTo) continue; + MaybeAddToSubdivisionForMedian(root, palette, i, subdivisionsSpan); + } + SetPivotsRecursively(root, subdivisionsSpan); +} + +template +void PopulatePivotsImpl(PaletteKdTreeNode &root, + const SDL_Color palette[256], int skipFrom, int skipTo, std::index_sequence intSeq) // NOLINT(misc-unused-parameters) +{ + (PopulatePivotsForTargetDepth(root, palette, skipFrom, skipTo), ...); +} + +void PopulatePivots(PaletteKdTreeNode &root, + const SDL_Color palette[256], int skipFrom, int skipTo) +{ + PopulatePivotsImpl(root, palette, skipFrom, skipTo, std::make_index_sequence {}); +} + +} // namespace + +PaletteKdTree::PaletteKdTree(const SDL_Color palette[256], int skipFrom, int skipTo) +{ + PopulatePivots(tree_, palette, skipFrom, skipTo); + StaticVector leafValues[NumLeaves]; + for (int i = 0; i < 256; ++i) { + if (i >= skipFrom && i <= skipTo) continue; + leafValues[LeafIndexForColor(tree_, palette[i])].emplace_back(i); + } + + size_t totalLen = 0; + for (uint8_t leafIndex = 0; leafIndex < NumLeaves; ++leafIndex) { + PaletteKdTreeNode<0> &leaf = LeafByIndex(tree_, leafIndex); + const std::span values = leafValues[leafIndex]; + if (values.empty()) { + leaf.valuesBegin = 1; + leaf.valuesEndInclusive = 0; + } else { + leaf.valuesBegin = static_cast(totalLen); + leaf.valuesEndInclusive = static_cast(totalLen - 1 + values.size()); + + for (size_t i = 0; i < values.size(); ++i) { + const uint8_t value = values[i]; + values_[totalLen + i] = std::make_pair(RGB { palette[value].r, palette[value].g, palette[value].b }, value); + } + totalLen += values.size(); + } + } + +#if DEVILUTIONX_PRINT_PALETTE_BLENDING_TREE_GRAPHVIZ + // To generate palette.dot.svg, run: + // dot -O -Tsvg palette.dot + FILE *out = std::fopen("palette.dot", "w"); + std::string dot = toGraphvizDot(); + std::fwrite(dot.data(), dot.size(), 1, out); + std::fclose(out); +#endif +} + +std::string PaletteKdTree::toGraphvizDot() const +{ + std::string dot = "graph palette_tree {\n rankdir=LR\n"; + tree_.toGraphvizDot(0, values_, dot); + dot.append("}\n"); + return dot; +} + +void PaletteKdTreeNode<0>::toGraphvizDot( + size_t id, std::span::RGB, uint8_t>, 256> values, std::string &dot) const +{ + StrAppend(dot, " node_", id, R"( [shape=plain label=< + + )"); + const std::pair *const end = values.data() + valuesEndInclusive; + for (const std::pair *it = values.data() + valuesBegin; it <= end; ++it) { + const auto &[rgb, paletteIndex] = *it; + char hexColor[6]; + fmt::format_to(hexColor, "{:02x}{:02x}{:02x}", rgb[0], rgb[1], rgb[2]); + StrAppend(dot, R"("); + } + if (valuesBegin > valuesEndInclusive) StrAppend(dot, ""); + StrAppend(dot, "\n
"); + const bool useWhiteText = rgb[0] + rgb[1] + rgb[2] < 350; + if (useWhiteText) StrAppend(dot, R"()"); + StrAppend(dot, + static_cast(rgb[0]), " ", + static_cast(rgb[1]), " ", + static_cast(rgb[2]), R"(
)", + static_cast(paletteIndex)); + if (useWhiteText) StrAppend(dot, "
"); + StrAppend(dot, "
>]\n"); +} + +} // namespace devilution diff --git a/Source/utils/palette_kd_tree.hpp b/Source/utils/palette_kd_tree.hpp index 7c4d291e6..af9208b27 100644 --- a/Source/utils/palette_kd_tree.hpp +++ b/Source/utils/palette_kd_tree.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -9,45 +8,18 @@ #include #include -#include -#include +#ifdef USE_SDL1 +#include +#else +#include +#endif -#include "utils/static_vector.hpp" #include "utils/str_cat.hpp" -#define DEVILUTIONX_PRINT_PALETTE_BLENDING_TREE_GRAPHVIZ 0 - -#if DEVILUTIONX_PRINT_PALETTE_BLENDING_TREE_GRAPHVIZ -#include -#endif +#define DEVILUTIONX_PRINT_PALETTE_BLENDING_TREE_GRAPHVIZ 0 // NOLINT(modernize-macro-to-enum) namespace devilution { -[[nodiscard]] inline uint32_t GetColorDistance(const std::array &a, const std::array &b) -{ - const int diffr = a[0] - b[0]; - const int diffg = a[1] - b[1]; - const int diffb = a[2] - b[2]; - return (diffr * diffr) + (diffg * diffg) + (diffb * diffb); -} - -[[nodiscard]] inline uint32_t GetColorDistanceToPlane(int x1, int x2) -{ - // Our planes are axis-aligned, so a distance from a point to a plane - // can be calculated based on just the axis coordinate. - const int delta = x1 - x2; - return static_cast(delta * delta); -} - -template -uint8_t GetColorComponent(const SDL_Color &); -template <> -inline uint8_t GetColorComponent<0>(const SDL_Color &c) { return c.r; } -template <> -inline uint8_t GetColorComponent<1>(const SDL_Color &c) { return c.g; } -template <> -inline uint8_t GetColorComponent<2>(const SDL_Color &c) { return c.b; } - /** * @brief Depth (number of levels) of the tree. */ @@ -77,48 +49,15 @@ struct PaletteKdTreeNode { return isLeft ? left : right; } - [[nodiscard]] static constexpr uint8_t getColorCoordinate(const SDL_Color &color) - { - return GetColorComponent(color); - } - - [[nodiscard]] uint8_t leafIndexForColor(const SDL_Color &color, uint8_t result = 0) - { - const bool isLeft = getColorCoordinate(color) < pivot; - if constexpr (RemainingDepth == 1) { - return (2 * result) + (isLeft ? 0 : 1); - } else { - return (2 * child(isLeft).leafIndexForColor(color, result)) + (isLeft ? 0 : 1); - } - } - - [[nodiscard]] PaletteKdTreeNode<0> &leafByIndex(uint8_t index) - { - if constexpr (RemainingDepth == 1) { - return child(index % 2 == 0); - } else { - return child(index % 2 == 0).leafByIndex(index / 2); - } - } - [[maybe_unused]] void toGraphvizDot(size_t id, std::span, 256> values, std::string &dot) const { - StrAppend(dot, " node_", id, " [label=\""); - if (Coord == 0) { - dot += 'r'; - } else if (Coord == 1) { - dot += 'g'; - } else { - dot += 'b'; - } - StrAppend(dot, ": ", pivot, "\"]\n"); - + StrAppend(dot, " node_", id, " [label=\"", "rgb"[Coord], ": ", pivot, "\"]\n"); const size_t leftId = (2 * id) + 1; const size_t rightId = (2 * id) + 2; left.toGraphvizDot(leftId, values, dot); right.toGraphvizDot(rightId, values, dot); - StrAppend(dot, " node_", id, " -- node_", leftId, "\n"); - StrAppend(dot, " node_", id, " -- node_", rightId, "\n"); + StrAppend(dot, " node_", id, " -- node_", leftId, + "\n node_", id, " -- node_", rightId, "\n"); } }; @@ -134,30 +73,7 @@ struct PaletteKdTreeNode { uint8_t valuesBegin; uint8_t valuesEndInclusive; - [[maybe_unused]] void toGraphvizDot(size_t id, std::span, 256> values, std::string &dot) const - { - StrAppend(dot, " node_", id, R"( [shape=plain label=< - - )"); - const std::pair *const end = values.data() + valuesEndInclusive; - for (const std::pair *it = values.data() + valuesBegin; it <= end; ++it) { - const auto &[rgb, paletteIndex] = *it; - char hexColor[6]; - fmt::format_to(hexColor, "{:02x}{:02x}{:02x}", rgb[0], rgb[1], rgb[2]); - StrAppend(dot, R"("); - } - if (valuesBegin > valuesEndInclusive) StrAppend(dot, ""); - StrAppend(dot, "\n
"); - const bool useWhiteText = rgb[0] + rgb[1] + rgb[2] < 350; - if (useWhiteText) StrAppend(dot, R"()"); - StrAppend(dot, - static_cast(rgb[0]), " ", - static_cast(rgb[1]), " ", - static_cast(rgb[2]), R"(
)", - static_cast(paletteIndex)); - if (useWhiteText) StrAppend(dot, "
"); - StrAppend(dot, "
>]\n"); - } + void toGraphvizDot(size_t id, std::span, 256> values, std::string &dot) const; }; /** @@ -179,171 +95,73 @@ public: * The palette is used as points in the tree. * Colors between skipFrom and skipTo (inclusive) are skipped. */ - explicit PaletteKdTree(const SDL_Color palette[256], int skipFrom, int skipTo) - { - populatePivots(palette, skipFrom, skipTo); - StaticVector leafValues[NumLeaves]; - for (int i = 0; i < 256; ++i) { - if (i >= skipFrom && i <= skipTo) continue; - leafValues[tree_.leafIndexForColor(palette[i])].emplace_back(i); - } - - size_t totalLen = 0; - for (uint8_t leafIndex = 0; leafIndex < NumLeaves; ++leafIndex) { - PaletteKdTreeNode<0> &leaf = tree_.leafByIndex(leafIndex); - const std::span values = leafValues[leafIndex]; - if (values.empty()) { - leaf.valuesBegin = 1; - leaf.valuesEndInclusive = 0; - } else { - leaf.valuesBegin = static_cast(totalLen); - leaf.valuesEndInclusive = static_cast(totalLen - 1 + values.size()); - - for (size_t i = 0; i < values.size(); ++i) { - const uint8_t value = values[i]; - values_[totalLen + i] = std::make_pair(RGB { palette[value].r, palette[value].g, palette[value].b }, value); - } - totalLen += values.size(); - } - } - -#if DEVILUTIONX_PRINT_PALETTE_BLENDING_TREE_GRAPHVIZ - // To generate palette.dot.svg, run: - // dot -O -Tsvg palette.dot - FILE *out = std::fopen("palette.dot", "w"); - std::string dot = toGraphvizDot(); - std::fwrite(dot.data(), dot.size(), 1, out); - std::fclose(out); -#endif - } + PaletteKdTree(const SDL_Color palette[256], int skipFrom, int skipTo); - [[nodiscard]] uint8_t findNearestNeighbor(const RGB &rgb) const - { + struct VisitState { uint8_t best; - uint32_t bestDiff = std::numeric_limits::max(); - findNearestNeighborVisit(tree_, rgb, bestDiff, best); - return values_[best].second; - } - - [[maybe_unused]] [[nodiscard]] std::string toGraphvizDot() const - { - std::string dot = "graph palette_tree {\n rankdir=LR\n"; - tree_.toGraphvizDot(0, values_, dot); - dot.append("}\n"); - return dot; - } - -private: - struct MedianInfo { - std::array counts = {}; - uint16_t numValues = 0; + uint32_t bestDiff; }; - [[nodiscard]] static uint8_t getMedian(const MedianInfo &medianInfo) - { - const std::span counts = medianInfo.counts; - const uint_fast16_t numValues = medianInfo.numValues; - const auto medianTarget = static_cast((medianInfo.numValues + 1) / 2); - uint_fast16_t partialSum = 0; - uint_fast16_t i = 0; - for (; partialSum < medianTarget && partialSum != numValues; ++i) { - partialSum += counts[i]; - } - - // Special cases: - // 1. If the elements are empty, this will return 0. - // 2. If all the elements are the same, this will be `value + 1` (rolling over to 0 if value is 256). - // This means all the elements will be on one side of the pivot (left unless the value is 255). - return static_cast(i); - } - - template - static void maybeAddToSubdivisionForMedian( - const PaletteKdTreeNode &node, - const SDL_Color palette[256], unsigned paletteIndex, - std::span medianInfos) + [[nodiscard]] uint8_t findNearestNeighbor(const RGB &rgb) const { - const uint8_t color = node.getColorCoordinate(palette[paletteIndex]); - if constexpr (N == 1) { - MedianInfo &medianInfo = medianInfos[0]; - ++medianInfo.counts[color]; - ++medianInfo.numValues; - } else { - const bool isLeft = color < node.pivot; - maybeAddToSubdivisionForMedian(node.child(isLeft), - palette, - paletteIndex, - isLeft - ? medianInfos.template subspan<0, N / 2>() - : medianInfos.template subspan()); - } + VisitState visitState; + visitState.bestDiff = std::numeric_limits::max(); + findNearestNeighborVisit(tree_, rgb, visitState); + return visitState.best; } - template - static void setPivotsRecursively( - PaletteKdTreeNode &node, - std::span medianInfos) - { - if constexpr (N == 1) { - node.pivot = getMedian(medianInfos[0]); - } else { - setPivotsRecursively(node.left, medianInfos.template subspan<0, N / 2>()); - setPivotsRecursively(node.right, medianInfos.template subspan()); - } - } + [[maybe_unused]] [[nodiscard]] std::string toGraphvizDot() const; - template - void populatePivotsForTargetDepth(const SDL_Color palette[256], int skipFrom, int skipTo) - { - constexpr size_t NumSubdivisions = 1U << TargetDepth; - std::array subdivisions = {}; - const std::span subdivisionsSpan { subdivisions }; - for (int i = 0; i < 256; ++i) { - if (i >= skipFrom && i <= skipTo) continue; - maybeAddToSubdivisionForMedian(tree_, palette, i, subdivisionsSpan); - } - setPivotsRecursively(tree_, subdivisionsSpan); - } - - template - void populatePivotsImpl(const SDL_Color palette[256], int skipFrom, int skipTo, std::index_sequence intSeq) // NOLINT(misc-unused-parameters) +private: + [[nodiscard]] static constexpr uint32_t getColorDistance(const std::array &a, const std::array &b) { - (populatePivotsForTargetDepth(palette, skipFrom, skipTo), ...); + const int diffr = a[0] - b[0]; + const int diffg = a[1] - b[1]; + const int diffb = a[2] - b[2]; + return (diffr * diffr) + (diffg * diffg) + (diffb * diffb); } - void populatePivots(const SDL_Color palette[256], int skipFrom, int skipTo) + [[nodiscard]] static constexpr uint32_t getColorDistanceToPlane(int x1, int x2) { - populatePivotsImpl(palette, skipFrom, skipTo, std::make_index_sequence {}); + // Our planes are axis-aligned, so a distance from a point to a plane + // can be calculated based on just the axis coordinate. + const int delta = x1 - x2; + return static_cast(delta * delta); } template void findNearestNeighborVisit(const PaletteKdTreeNode &node, const RGB &rgb, - uint32_t &bestDiff, uint8_t &best) const + VisitState &visitState) const { const uint8_t coord = rgb[PaletteKdTreeNode::Coord]; - findNearestNeighborVisit(node.child(coord < node.pivot), rgb, bestDiff, best); + findNearestNeighborVisit(node.child(coord < node.pivot), rgb, visitState); // To see if we need to check a node's subtree, we compare the distance from the query // to the current best candidate vs the distance to the edge of the half-space represented // by the node. - if (GetColorDistanceToPlane(node.pivot, coord) < bestDiff) { - findNearestNeighborVisit(node.child(coord >= node.pivot), rgb, bestDiff, best); + if (getColorDistanceToPlane(node.pivot, coord) < visitState.bestDiff) { + findNearestNeighborVisit(node.child(coord >= node.pivot), rgb, visitState); } } void findNearestNeighborVisit(const PaletteKdTreeNode<0> &node, const RGB &rgb, - uint32_t &bestDiff, uint8_t &best) const + VisitState &visitState) const { + // Nodes are almost never empty. + // Separating the empty check from the loop makes this faster, + // probaly because of better branch prediction. + if (node.valuesBegin > node.valuesEndInclusive) return; + const std::pair *it = values_.data() + node.valuesBegin; const std::pair *const end = values_.data() + node.valuesEndInclusive; - while (it <= end) { + do { const auto &[paletteColor, paletteIndex] = *it++; - const uint32_t diff = GetColorDistance(paletteColor, rgb); - if (diff < bestDiff) { - best = static_cast(it - values_.data() - 1); - bestDiff = diff; + const uint32_t diff = getColorDistance(paletteColor, rgb); + if (diff < visitState.bestDiff) { + visitState.best = paletteIndex; + visitState.bestDiff = diff; } - } + } while (it <= end); } PaletteKdTreeNode tree_; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b0822551e..e490a0ac2 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -119,7 +119,13 @@ target_link_dependencies(format_int_test PRIVATE libdevilutionx_format_int langu target_link_dependencies(ini_test PRIVATE libdevilutionx_ini app_fatal_for_testing) target_link_dependencies(light_render_benchmark PRIVATE libdevilutionx_light_render DevilutionX::SDL libdevilutionx_surface libdevilutionx_paths app_fatal_for_testing) target_link_dependencies(palette_blending_test PRIVATE libdevilutionx_palette_blending DevilutionX::SDL libdevilutionx_strings GTest::gmock app_fatal_for_testing) -target_link_dependencies(palette_blending_benchmark PRIVATE libdevilutionx_palette_blending DevilutionX::SDL app_fatal_for_testing) +target_link_dependencies(palette_blending_benchmark + PRIVATE + DevilutionX::SDL + libdevilutionx_palette_blending + libdevilutionx_palette_kd_tree + app_fatal_for_testing +) target_link_dependencies(parse_int_test PRIVATE libdevilutionx_parse_int) target_link_dependencies(path_test PRIVATE libdevilutionx_pathfinding libdevilutionx_direction app_fatal_for_testing) target_link_dependencies(vision_test PRIVATE libdevilutionx_vision)