You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
171 lines
5.6 KiB
171 lines
5.6 KiB
#pragma once |
|
|
|
#include <algorithm> |
|
#include <array> |
|
#include <cstdint> |
|
#include <limits> |
|
#include <span> |
|
#include <utility> |
|
|
|
#include <SDL.h> |
|
|
|
#include "utils/algorithm/container.hpp" |
|
#include "utils/static_vector.hpp" |
|
|
|
namespace devilution { |
|
|
|
[[nodiscard]] inline uint32_t GetColorDistance(const SDL_Color &a, const std::array<uint8_t, 3> &b) |
|
{ |
|
const int diffr = a.r - b[0]; |
|
const int diffg = a.g - b[1]; |
|
const int diffb = a.b - b[2]; |
|
return (diffr * diffr) + (diffg * diffg) + (diffb * diffb); |
|
} |
|
|
|
/** |
|
* @brief A 3-level kd-tree used to find the nearest neighbor in the color space. |
|
* |
|
* Each level splits the space in half by red, green, and blue respectively. |
|
*/ |
|
class PaletteKdTree { |
|
using RGB = std::array<uint8_t, 3>; |
|
|
|
public: |
|
explicit PaletteKdTree(const SDL_Color palette[256]) |
|
: palette_(palette) |
|
, pivots_(getPivots(palette)) |
|
{ |
|
for (unsigned i = 0; i < 256; ++i) { |
|
const SDL_Color &color = palette[i]; |
|
auto &level1 = color.r < pivots_[0] ? tree_.first : tree_.second; |
|
auto &level2 = color.g < pivots_[1] ? level1.first : level1.second; |
|
auto &level3 = color.b < pivots_[2] ? level2.first : level2.second; |
|
level3.emplace_back(i); |
|
} |
|
|
|
// Uncomment the loop below to print the node distribution: |
|
// for (const bool r : { false, true }) { |
|
// for (const bool g : { false, true }) { |
|
// for (const bool b : { false, true }) { |
|
// printf("r%d.g%d.b%d: %d\n", |
|
// static_cast<int>(r), static_cast<int>(g), static_cast<int>(b), |
|
// static_cast<int>(getLeaf(r, g, b).size())); |
|
// } |
|
// } |
|
// } |
|
} |
|
|
|
[[nodiscard]] uint8_t findNearestNeighbor(const RGB &rgb) const |
|
{ |
|
const bool compR = rgb[0] < pivots_[0]; |
|
const bool compG = rgb[1] < pivots_[1]; |
|
const bool compB = rgb[2] < pivots_[2]; |
|
|
|
// Conceptually, we visit the tree recursively. |
|
// As the tree only has 3 levels, we fully unroll |
|
// the recursion here. |
|
uint8_t best; |
|
uint32_t bestDiff = std::numeric_limits<uint32_t>::max(); |
|
checkLeaf(compR, compG, compB, rgb, best, bestDiff); |
|
if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) { |
|
checkLeaf(compR, compG, !compB, rgb, best, bestDiff); |
|
} |
|
if (shouldCheckNode(best, bestDiff, /*coord=*/1, rgb)) { |
|
checkLeaf(compR, !compG, compB, rgb, best, bestDiff); |
|
if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) { |
|
checkLeaf(compR, !compG, !compB, rgb, best, bestDiff); |
|
} |
|
} |
|
if (shouldCheckNode(best, bestDiff, /*coord=*/0, rgb)) { |
|
checkLeaf(!compR, compG, compB, rgb, best, bestDiff); |
|
if (shouldCheckNode(best, bestDiff, /*coord=*/1, rgb)) { |
|
checkLeaf(!compR, !compG, compB, rgb, best, bestDiff); |
|
if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) { |
|
checkLeaf(!compR, !compG, !compB, rgb, best, bestDiff); |
|
} |
|
} |
|
if (shouldCheckNode(best, bestDiff, /*coord=*/2, rgb)) { |
|
checkLeaf(!compR, compG, !compB, rgb, best, bestDiff); |
|
} |
|
} |
|
return best; |
|
} |
|
|
|
private: |
|
static uint8_t getMedian(std::span<uint8_t, 256> elements) |
|
{ |
|
const auto middleItr = elements.begin() + (elements.size() / 2); |
|
std::nth_element(elements.begin(), middleItr, elements.end()); |
|
if (elements.size() % 2 == 0) { |
|
const auto leftMiddleItr = std::max_element(elements.begin(), middleItr); |
|
return (*leftMiddleItr + *middleItr) / 2; |
|
} |
|
return *middleItr; |
|
} |
|
|
|
static std::array<uint8_t, 3> getPivots(const SDL_Color palette[256]) |
|
{ |
|
std::array<std::array<uint8_t, 256>, 3> coords; |
|
for (unsigned i = 0; i < 256; ++i) { |
|
coords[0][i] = palette[i].r; |
|
coords[1][i] = palette[i].g; |
|
coords[2][i] = palette[i].b; |
|
} |
|
return { getMedian(coords[0]), getMedian(coords[1]), getMedian(coords[2]) }; |
|
} |
|
|
|
void checkLeaf(bool compR, bool compG, bool compB, const RGB &rgb, uint8_t &best, uint32_t &bestDiff) const |
|
{ |
|
const std::span<const uint8_t> leaf = getLeaf(compR, compG, compB); |
|
uint8_t leafBest; |
|
uint32_t leafBestDiff = bestDiff; |
|
for (const uint8_t paletteIndex : leaf) { |
|
const uint32_t diff = GetColorDistance(palette_[paletteIndex], rgb); |
|
if (diff < leafBestDiff) { |
|
leafBest = paletteIndex; |
|
leafBestDiff = diff; |
|
} |
|
} |
|
if (leafBestDiff < bestDiff) { |
|
best = leafBest; |
|
bestDiff = leafBestDiff; |
|
} |
|
} |
|
|
|
[[nodiscard]] bool shouldCheckNode(uint8_t best, uint32_t bestDiff, unsigned coord, const RGB &rgb) const |
|
{ |
|
// To see if we need to check a node's subtree, we compare the distance from the query |
|
// to the current best candidate vs the distance to the edge of the half-space represented |
|
// by the node. |
|
if (bestDiff == std::numeric_limits<uint32_t>::max()) return true; |
|
const int delta = static_cast<int>(pivots_[coord]) - static_cast<int>(rgb[coord]); |
|
return delta * delta < GetColorDistance(palette_[best], rgb); |
|
} |
|
|
|
[[nodiscard]] std::span<const uint8_t> getLeaf(bool r, bool g, bool b) const |
|
{ |
|
const auto &level1 = r ? tree_.first : tree_.second; |
|
const auto &level2 = g ? level1.first : level1.second; |
|
const auto &level3 = b ? level2.first : level2.second; |
|
return { level3 }; |
|
} |
|
|
|
const SDL_Color *palette_; |
|
std::array<uint8_t, 3> pivots_; |
|
std::pair< |
|
// r0 |
|
std::pair< |
|
// r0.g0.b{0, 1} |
|
std::pair<StaticVector<uint8_t, 256>, StaticVector<uint8_t, 256>>, |
|
// r0.g1.b{0, 1} |
|
std::pair<StaticVector<uint8_t, 256>, StaticVector<uint8_t, 256>>>, |
|
// r1 |
|
std::pair< |
|
// r1.g0.b{0, 1} |
|
std::pair<StaticVector<uint8_t, 256>, StaticVector<uint8_t, 256>>, |
|
// r1.g1.b{0, 1} |
|
std::pair<StaticVector<uint8_t, 256>, StaticVector<uint8_t, 256>>>> |
|
tree_; |
|
}; |
|
|
|
} // namespace devilution
|
|
|