diff --git a/Source/utils/palette_blending.cpp b/Source/utils/palette_blending.cpp index e6af27090..3c176b5a3 100644 --- a/Source/utils/palette_blending.cpp +++ b/Source/utils/palette_blending.cpp @@ -1,10 +1,13 @@ #include "utils/palette_blending.hpp" +#include #include #include #include +#include "utils/palette_kd_tree.hpp" + namespace devilution { // This array is read from a lot on every frame. @@ -18,11 +21,7 @@ uint16_t paletteTransparencyLookupBlack16[65536]; namespace { -struct RGB { - uint8_t r; - uint8_t g; - uint8_t b; -}; +using RGB = std::array; uint8_t FindBestMatchForColor(SDL_Color palette[256], RGB color, int skipFrom, int skipTo) { @@ -31,11 +30,7 @@ uint8_t FindBestMatchForColor(SDL_Color palette[256], RGB color, int skipFrom, i for (int i = 0; i < 256; i++) { if (i >= skipFrom && i <= skipTo) continue; - const int diffr = palette[i].r - color.r; - const int diffg = palette[i].g - color.g; - const int diffb = palette[i].b - color.b; - const uint32_t diff = diffr * diffr + diffg * diffg + diffb * diffb; - + const uint32_t diff = GetColorDistance(palette[i], color); if (bestDiff > diff) { best = i; bestDiff = diff; @@ -47,9 +42,9 @@ uint8_t FindBestMatchForColor(SDL_Color palette[256], RGB color, int skipFrom, i RGB BlendColors(const SDL_Color &a, const SDL_Color &b) { return RGB { - .r = static_cast((static_cast(a.r) + static_cast(b.r)) / 2), - .g = static_cast((static_cast(a.g) + static_cast(b.g)) / 2), - .b = static_cast((static_cast(a.b) + static_cast(b.b)) / 2), + static_cast((static_cast(a.r) + static_cast(b.r)) / 2), + static_cast((static_cast(a.g) + static_cast(b.g)) / 2), + static_cast((static_cast(a.b) + static_cast(b.b)) / 2), }; } @@ -64,6 +59,7 @@ void SetPaletteTransparencyLookupBlack16(unsigned i, unsigned j) void GenerateBlendedLookupTable(SDL_Color palette[256], int skipFrom, int skipTo) { + const PaletteKdTree kdTree { palette }; for (unsigned i = 0; i < 256; i++) { paletteTransparencyLookup[i][i] = i; unsigned j = 0; @@ -72,8 +68,7 @@ void GenerateBlendedLookupTable(SDL_Color palette[256], int skipFrom, int skipTo } ++j; for (; j < 256; j++) { - const uint8_t best = FindBestMatchForColor(palette, BlendColors(palette[i], palette[j]), skipFrom, skipTo); - paletteTransparencyLookup[i][j] = best; + paletteTransparencyLookup[i][j] = kdTree.findNearestNeighbor(BlendColors(palette[i], palette[j])); } } diff --git a/Source/utils/palette_kd_tree.hpp b/Source/utils/palette_kd_tree.hpp new file mode 100644 index 000000000..f263e5dc7 --- /dev/null +++ b/Source/utils/palette_kd_tree.hpp @@ -0,0 +1,171 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include "utils/algorithm/container.hpp" +#include "utils/static_vector.hpp" + +namespace devilution { + +[[nodiscard]] inline uint32_t GetColorDistance(const SDL_Color &a, const std::array &b) +{ + const int diffr = a.r - b[0]; + const int diffg = a.g - b[1]; + const int diffb = a.b - b[2]; + return (diffr * diffr) + (diffg * diffg) + (diffb * diffb); +} + +/** + * @brief A 3-level kd-tree used to find the nearest neighbor in the color space. + * + * Each level splits the space in half by red, green, and blue respectively. + */ +class PaletteKdTree { + using RGB = std::array; + +public: + explicit PaletteKdTree(const SDL_Color palette[256]) + : palette_(palette) + , pivots_(getPivots(palette)) + { + for (unsigned i = 0; i < 256; ++i) { + const SDL_Color &color = palette[i]; + auto &level1 = color.r < pivots_[0] ? tree_.first : tree_.second; + auto &level2 = color.g < pivots_[1] ? level1.first : level1.second; + auto &level3 = color.b < pivots_[2] ? level2.first : level2.second; + level3.emplace_back(i); + } + + // Uncomment the loop below to print the node distribution: + // for (const bool r : { false, true }) { + // for (const bool g : { false, true }) { + // for (const bool b : { false, true }) { + // printf("r%d.g%d.b%d: %d\n", + // static_cast(r), static_cast(g), static_cast(b), + // static_cast(getLeaf(r, g, b).size())); + // } + // } + // } + } + + [[nodiscard]] uint8_t findNearestNeighbor(const RGB &rgb) const + { + const bool compR = rgb[0] < pivots_[0]; + const bool compG = rgb[1] < pivots_[1]; + const bool compB = rgb[2] < pivots_[2]; + + // Conceptually, we visit the tree recursively. + // As the tree only has 3 levels, we fully unroll + // the recursion here. + uint8_t best; + uint32_t bestDiff = std::numeric_limits::max(); + checkLeaf(compR, compG, compB, rgb, best, bestDiff); + if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) { + checkLeaf(compR, compG, !compB, rgb, best, bestDiff); + } + if (shouldCheckNode(best, bestDiff, /*coord=*/1, rgb)) { + checkLeaf(compR, !compG, compB, rgb, best, bestDiff); + if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) { + checkLeaf(compR, !compG, !compB, rgb, best, bestDiff); + } + } + if (shouldCheckNode(best, bestDiff, /*coord=*/0, rgb)) { + checkLeaf(!compR, compG, compB, rgb, best, bestDiff); + if (shouldCheckNode(best, bestDiff, /*coord=*/1, rgb)) { + checkLeaf(!compR, !compG, compB, rgb, best, bestDiff); + if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) { + checkLeaf(!compR, !compG, !compB, rgb, best, bestDiff); + } + } + if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) { + checkLeaf(!compR, compG, !compB, rgb, best, bestDiff); + } + } + return best; + } + +private: + static uint8_t getMedian(std::span elements) + { + const auto middleItr = elements.begin() + (elements.size() / 2); + std::nth_element(elements.begin(), middleItr, elements.end()); + if (elements.size() % 2 == 0) { + const auto leftMiddleItr = std::max_element(elements.begin(), middleItr); + return (*leftMiddleItr + *middleItr) / 2; + } + return *middleItr; + } + + static std::array getPivots(const SDL_Color palette[256]) + { + std::array, 3> coords; + for (unsigned i = 0; i < 256; ++i) { + coords[0][i] = palette[i].r; + coords[1][i] = palette[i].g; + coords[2][i] = palette[i].b; + } + return { getMedian(coords[0]), getMedian(coords[1]), getMedian(coords[2]) }; + } + + void checkLeaf(bool compR, bool compG, bool compB, const RGB &rgb, uint8_t &best, uint32_t &bestDiff) const + { + const std::span leaf = getLeaf(compR, compG, compB); + uint8_t leafBest; + uint32_t leafBestDiff = bestDiff; + for (const uint8_t paletteIndex : leaf) { + const uint32_t diff = GetColorDistance(palette_[paletteIndex], rgb); + if (diff < leafBestDiff) { + leafBest = paletteIndex; + leafBestDiff = diff; + } + } + if (leafBestDiff < bestDiff) { + best = leafBest; + bestDiff = leafBestDiff; + } + } + + [[nodiscard]] bool shouldCheckNode(uint8_t best, uint32_t bestDiff, unsigned coord, const RGB &rgb) const + { + // To see if we need to check a node's subtree, we compare the distance from the query + // to the current best candidate vs the distance to the edge of the half-space represented + // by the node. + if (bestDiff == std::numeric_limits::max()) return true; + const int delta = static_cast(pivots_[coord]) - static_cast(rgb[coord]); + return delta * delta < GetColorDistance(palette_[best], rgb); + } + + [[nodiscard]] std::span getLeaf(bool r, bool g, bool b) const + { + const auto &level1 = r ? tree_.first : tree_.second; + const auto &level2 = g ? level1.first : level1.second; + const auto &level3 = b ? level2.first : level2.second; + return { level3 }; + } + + const SDL_Color *palette_; + std::array pivots_; + std::pair< + // r0 + std::pair< + // r0.g0.b{0, 1} + std::pair, StaticVector>, + // r0.g1.b{0, 1} + std::pair, StaticVector>>, + // r1 + std::pair< + // r1.g0.b{0, 1} + std::pair, StaticVector>, + // r1.g1.b{0, 1} + std::pair, StaticVector>>> + tree_; +}; + +} // namespace devilution diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a758e5eb0..cc37ca113 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -106,8 +106,8 @@ target_link_dependencies(dun_render_benchmark PRIVATE libdevilutionx_so) target_link_dependencies(file_util_test PRIVATE libdevilutionx_file_util app_fatal_for_testing) target_link_dependencies(format_int_test PRIVATE libdevilutionx_format_int language_for_testing) target_link_dependencies(ini_test PRIVATE libdevilutionx_ini app_fatal_for_testing) -target_link_dependencies(palette_blending_test PRIVATE libdevilutionx_palette_blending DevilutionX::SDL libdevilutionx_strings GTest::gmock) -target_link_dependencies(palette_blending_benchmark PRIVATE libdevilutionx_palette_blending DevilutionX::SDL) +target_link_dependencies(palette_blending_test PRIVATE libdevilutionx_palette_blending DevilutionX::SDL libdevilutionx_strings GTest::gmock app_fatal_for_testing) +target_link_dependencies(palette_blending_benchmark PRIVATE libdevilutionx_palette_blending DevilutionX::SDL app_fatal_for_testing) target_link_dependencies(parse_int_test PRIVATE libdevilutionx_parse_int) target_link_dependencies(path_test PRIVATE libdevilutionx_pathfinding libdevilutionx_direction app_fatal_for_testing) target_link_dependencies(vision_test PRIVATE libdevilutionx_vision)