Browse Source

Use a k-d tree for palette blending

Uses a k-d tree to quickly find the best match
for a color when generating the palette blending
lookup table.

https://en.wikipedia.org/wiki/K-d_tree

This is 3x faster than the previous naive approach:

```
Benchmark                                              Time             CPU      Time Old      Time New       CPU Old       CPU New
-----------------------------------------------------------------------------------------------------------------------------------
BM_GenerateBlendedLookupTable_pvalue                 0.0002          0.0002      U Test, Repetitions: 10 vs 10
BM_GenerateBlendedLookupTable_mean                  -0.7153         -0.7153      18402641       5239051      18399111       5238025
BM_GenerateBlendedLookupTable_median                -0.7153         -0.7153      18403261       5239042      18398841       5237497
BM_GenerateBlendedLookupTable_stddev                -0.2775         +0.3858          2257          1631          1347          1867
BM_GenerateBlendedLookupTable_cv                    +1.5379         +3.8677             0             0             0             0
OVERALL_GEOMEAN                                     -0.7153         -0.7153             0             0             0             0
```

The distribution is somewhat poor with just 3 levels, so this can be improved further.
For example, here is the leaf size distribution in the cathedral:

```
r0.g0.b0: 88
r0.g0.b1: 10
r0.g1.b0: 2
r0.g1.b1: 32
r1.g0.b0: 27
r1.g0.b1: 4
r1.g1.b0: 12
r1.g1.b1: 81
```
pull/8028/head
Gleb Mazovetskiy 10 months ago
parent
commit
cd38ca7631
  1. 25
      Source/utils/palette_blending.cpp
  2. 171
      Source/utils/palette_kd_tree.hpp
  3. 4
      test/CMakeLists.txt

25
Source/utils/palette_blending.cpp

@ -1,10 +1,13 @@
#include "utils/palette_blending.hpp" #include "utils/palette_blending.hpp"
#include <array>
#include <cstdint> #include <cstdint>
#include <limits> #include <limits>
#include <SDL.h> #include <SDL.h>
#include "utils/palette_kd_tree.hpp"
namespace devilution { namespace devilution {
// This array is read from a lot on every frame. // This array is read from a lot on every frame.
@ -18,11 +21,7 @@ uint16_t paletteTransparencyLookupBlack16[65536];
namespace { namespace {
struct RGB { using RGB = std::array<uint8_t, 3>;
uint8_t r;
uint8_t g;
uint8_t b;
};
uint8_t FindBestMatchForColor(SDL_Color palette[256], RGB color, int skipFrom, int skipTo) uint8_t FindBestMatchForColor(SDL_Color palette[256], RGB color, int skipFrom, int skipTo)
{ {
@ -31,11 +30,7 @@ uint8_t FindBestMatchForColor(SDL_Color palette[256], RGB color, int skipFrom, i
for (int i = 0; i < 256; i++) { for (int i = 0; i < 256; i++) {
if (i >= skipFrom && i <= skipTo) if (i >= skipFrom && i <= skipTo)
continue; continue;
const int diffr = palette[i].r - color.r; const uint32_t diff = GetColorDistance(palette[i], color);
const int diffg = palette[i].g - color.g;
const int diffb = palette[i].b - color.b;
const uint32_t diff = diffr * diffr + diffg * diffg + diffb * diffb;
if (bestDiff > diff) { if (bestDiff > diff) {
best = i; best = i;
bestDiff = diff; bestDiff = diff;
@ -47,9 +42,9 @@ uint8_t FindBestMatchForColor(SDL_Color palette[256], RGB color, int skipFrom, i
RGB BlendColors(const SDL_Color &a, const SDL_Color &b) RGB BlendColors(const SDL_Color &a, const SDL_Color &b)
{ {
return RGB { return RGB {
.r = static_cast<uint8_t>((static_cast<int>(a.r) + static_cast<int>(b.r)) / 2), static_cast<uint8_t>((static_cast<int>(a.r) + static_cast<int>(b.r)) / 2),
.g = static_cast<uint8_t>((static_cast<int>(a.g) + static_cast<int>(b.g)) / 2), static_cast<uint8_t>((static_cast<int>(a.g) + static_cast<int>(b.g)) / 2),
.b = static_cast<uint8_t>((static_cast<int>(a.b) + static_cast<int>(b.b)) / 2), static_cast<uint8_t>((static_cast<int>(a.b) + static_cast<int>(b.b)) / 2),
}; };
} }
@ -64,6 +59,7 @@ void SetPaletteTransparencyLookupBlack16(unsigned i, unsigned j)
void GenerateBlendedLookupTable(SDL_Color palette[256], int skipFrom, int skipTo) void GenerateBlendedLookupTable(SDL_Color palette[256], int skipFrom, int skipTo)
{ {
const PaletteKdTree kdTree { palette };
for (unsigned i = 0; i < 256; i++) { for (unsigned i = 0; i < 256; i++) {
paletteTransparencyLookup[i][i] = i; paletteTransparencyLookup[i][i] = i;
unsigned j = 0; unsigned j = 0;
@ -72,8 +68,7 @@ void GenerateBlendedLookupTable(SDL_Color palette[256], int skipFrom, int skipTo
} }
++j; ++j;
for (; j < 256; j++) { for (; j < 256; j++) {
const uint8_t best = FindBestMatchForColor(palette, BlendColors(palette[i], palette[j]), skipFrom, skipTo); paletteTransparencyLookup[i][j] = kdTree.findNearestNeighbor(BlendColors(palette[i], palette[j]));
paletteTransparencyLookup[i][j] = best;
} }
} }

171
Source/utils/palette_kd_tree.hpp

@ -0,0 +1,171 @@
#pragma once
#include <algorithm>
#include <array>
#include <cstdint>
#include <limits>
#include <span>
#include <utility>
#include <SDL.h>
#include "utils/algorithm/container.hpp"
#include "utils/static_vector.hpp"
namespace devilution {
[[nodiscard]] inline uint32_t GetColorDistance(const SDL_Color &a, const std::array<uint8_t, 3> &b)
{
const int diffr = a.r - b[0];
const int diffg = a.g - b[1];
const int diffb = a.b - b[2];
return (diffr * diffr) + (diffg * diffg) + (diffb * diffb);
}
/**
* @brief A 3-level kd-tree used to find the nearest neighbor in the color space.
*
* Each level splits the space in half by red, green, and blue respectively.
*/
class PaletteKdTree {
using RGB = std::array<uint8_t, 3>;
public:
explicit PaletteKdTree(const SDL_Color palette[256])
: palette_(palette)
, pivots_(getPivots(palette))
{
for (unsigned i = 0; i < 256; ++i) {
const SDL_Color &color = palette[i];
auto &level1 = color.r < pivots_[0] ? tree_.first : tree_.second;
auto &level2 = color.g < pivots_[1] ? level1.first : level1.second;
auto &level3 = color.b < pivots_[2] ? level2.first : level2.second;
level3.emplace_back(i);
}
// Uncomment the loop below to print the node distribution:
// for (const bool r : { false, true }) {
// for (const bool g : { false, true }) {
// for (const bool b : { false, true }) {
// printf("r%d.g%d.b%d: %d\n",
// static_cast<int>(r), static_cast<int>(g), static_cast<int>(b),
// static_cast<int>(getLeaf(r, g, b).size()));
// }
// }
// }
}
[[nodiscard]] uint8_t findNearestNeighbor(const RGB &rgb) const
{
const bool compR = rgb[0] < pivots_[0];
const bool compG = rgb[1] < pivots_[1];
const bool compB = rgb[2] < pivots_[2];
// Conceptually, we visit the tree recursively.
// As the tree only has 3 levels, we fully unroll
// the recursion here.
uint8_t best;
uint32_t bestDiff = std::numeric_limits<uint32_t>::max();
checkLeaf(compR, compG, compB, rgb, best, bestDiff);
if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) {
checkLeaf(compR, compG, !compB, rgb, best, bestDiff);
}
if (shouldCheckNode(best, bestDiff, /*coord=*/1, rgb)) {
checkLeaf(compR, !compG, compB, rgb, best, bestDiff);
if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) {
checkLeaf(compR, !compG, !compB, rgb, best, bestDiff);
}
}
if (shouldCheckNode(best, bestDiff, /*coord=*/0, rgb)) {
checkLeaf(!compR, compG, compB, rgb, best, bestDiff);
if (shouldCheckNode(best, bestDiff, /*coord=*/1, rgb)) {
checkLeaf(!compR, !compG, compB, rgb, best, bestDiff);
if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) {
checkLeaf(!compR, !compG, !compB, rgb, best, bestDiff);
}
}
if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) {
checkLeaf(!compR, compG, !compB, rgb, best, bestDiff);
}
}
return best;
}
private:
static uint8_t getMedian(std::span<uint8_t, 256> elements)
{
const auto middleItr = elements.begin() + (elements.size() / 2);
std::nth_element(elements.begin(), middleItr, elements.end());
if (elements.size() % 2 == 0) {
const auto leftMiddleItr = std::max_element(elements.begin(), middleItr);
return (*leftMiddleItr + *middleItr) / 2;
}
return *middleItr;
}
static std::array<uint8_t, 3> getPivots(const SDL_Color palette[256])
{
std::array<std::array<uint8_t, 256>, 3> coords;
for (unsigned i = 0; i < 256; ++i) {
coords[0][i] = palette[i].r;
coords[1][i] = palette[i].g;
coords[2][i] = palette[i].b;
}
return { getMedian(coords[0]), getMedian(coords[1]), getMedian(coords[2]) };
}
void checkLeaf(bool compR, bool compG, bool compB, const RGB &rgb, uint8_t &best, uint32_t &bestDiff) const
{
const std::span<const uint8_t> leaf = getLeaf(compR, compG, compB);
uint8_t leafBest;
uint32_t leafBestDiff = bestDiff;
for (const uint8_t paletteIndex : leaf) {
const uint32_t diff = GetColorDistance(palette_[paletteIndex], rgb);
if (diff < leafBestDiff) {
leafBest = paletteIndex;
leafBestDiff = diff;
}
}
if (leafBestDiff < bestDiff) {
best = leafBest;
bestDiff = leafBestDiff;
}
}
[[nodiscard]] bool shouldCheckNode(uint8_t best, uint32_t bestDiff, unsigned coord, const RGB &rgb) const
{
// To see if we need to check a node's subtree, we compare the distance from the query
// to the current best candidate vs the distance to the edge of the half-space represented
// by the node.
if (bestDiff == std::numeric_limits<uint32_t>::max()) return true;
const int delta = static_cast<int>(pivots_[coord]) - static_cast<int>(rgb[coord]);
return delta * delta < GetColorDistance(palette_[best], rgb);
}
[[nodiscard]] std::span<const uint8_t> getLeaf(bool r, bool g, bool b) const
{
const auto &level1 = r ? tree_.first : tree_.second;
const auto &level2 = g ? level1.first : level1.second;
const auto &level3 = b ? level2.first : level2.second;
return { level3 };
}
const SDL_Color *palette_;
std::array<uint8_t, 3> pivots_;
std::pair<
// r0
std::pair<
// r0.g0.b{0, 1}
std::pair<StaticVector<uint8_t, 256>, StaticVector<uint8_t, 256>>,
// r0.g1.b{0, 1}
std::pair<StaticVector<uint8_t, 256>, StaticVector<uint8_t, 256>>>,
// r1
std::pair<
// r1.g0.b{0, 1}
std::pair<StaticVector<uint8_t, 256>, StaticVector<uint8_t, 256>>,
// r1.g1.b{0, 1}
std::pair<StaticVector<uint8_t, 256>, StaticVector<uint8_t, 256>>>>
tree_;
};
} // namespace devilution

4
test/CMakeLists.txt

@ -106,8 +106,8 @@ target_link_dependencies(dun_render_benchmark PRIVATE libdevilutionx_so)
target_link_dependencies(file_util_test PRIVATE libdevilutionx_file_util app_fatal_for_testing) target_link_dependencies(file_util_test PRIVATE libdevilutionx_file_util app_fatal_for_testing)
target_link_dependencies(format_int_test PRIVATE libdevilutionx_format_int language_for_testing) target_link_dependencies(format_int_test PRIVATE libdevilutionx_format_int language_for_testing)
target_link_dependencies(ini_test PRIVATE libdevilutionx_ini app_fatal_for_testing) target_link_dependencies(ini_test PRIVATE libdevilutionx_ini app_fatal_for_testing)
target_link_dependencies(palette_blending_test PRIVATE libdevilutionx_palette_blending DevilutionX::SDL libdevilutionx_strings GTest::gmock) target_link_dependencies(palette_blending_test PRIVATE libdevilutionx_palette_blending DevilutionX::SDL libdevilutionx_strings GTest::gmock app_fatal_for_testing)
target_link_dependencies(palette_blending_benchmark PRIVATE libdevilutionx_palette_blending DevilutionX::SDL) target_link_dependencies(palette_blending_benchmark PRIVATE libdevilutionx_palette_blending DevilutionX::SDL app_fatal_for_testing)
target_link_dependencies(parse_int_test PRIVATE libdevilutionx_parse_int) target_link_dependencies(parse_int_test PRIVATE libdevilutionx_parse_int)
target_link_dependencies(path_test PRIVATE libdevilutionx_pathfinding libdevilutionx_direction app_fatal_for_testing) target_link_dependencies(path_test PRIVATE libdevilutionx_pathfinding libdevilutionx_direction app_fatal_for_testing)
target_link_dependencies(vision_test PRIVATE libdevilutionx_vision) target_link_dependencies(vision_test PRIVATE libdevilutionx_vision)

Loading…
Cancel
Save