Fix a bug where _fit_global_palette would crash if there were fewer

than 16 global colours computed.
This commit is contained in:
kris 2021-11-23 13:59:48 +00:00
parent 6e52680cf1
commit 1ce5c25764
1 changed files with 17 additions and 11 deletions

View File

@ -83,17 +83,22 @@ class ClusterPalette:
clusters = cluster.MiniBatchKMeans(n_clusters=16, max_iter=10000)
clusters.fit_predict(self._colours_cam)
num_colours = clusters.n_clusters
labels = clusters.labels_
# Dict of {palette idx : frequency count}
palette_freq = {idx: 0 for idx in range(16)}
for idx, freq in zip(*np.unique(labels, return_counts=True)):
palette_freq[idx] = freq
frequency_order = [
k for k, v in sorted(
# List of (palette idx, frequency count)
list(zip(*np.unique(labels, return_counts=True))),
key=lambda kv: kv[1], reverse=True)]
list(palette_freq.items()), key=lambda kv: kv[1], reverse=True)]
return dither_pyx.convert_cam16ucs_to_rgb12_iigs(
clusters.cluster_centers_[frequency_order].astype(
np.float32))
self._global_palette = (
dither_pyx.convert_cam16ucs_to_rgb12_iigs(
clusters.cluster_centers_[frequency_order].astype(
np.float32)))
def _palette_splits(self, palette_height=35):
# The 16 palettes are striped across consecutive (overlapping) line
@ -126,7 +131,7 @@ class ClusterPalette:
new_lower = np.clip(new_lower, 0, np.clip(new_upper, 1, 200) - 1)
new_upper = np.clip(new_upper, new_lower + 1, 200)
assert new_lower >= 0, new_upper-1
assert new_lower >= 0, new_upper - 1
self._palette_splits[palette_to_mutate] = (new_lower, new_upper)
self._palette_mutate_idx = palette_to_mutate
@ -185,7 +190,7 @@ class ClusterPalette:
# Compute a new 16-colour global palette for the entire image,
# used as the starting center positions for k-means clustering of the
# individual palettes
self._global_palette = self._fit_global_palette()
self._fit_global_palette()
dynamic_colours = 16 - self._reserved_colours
@ -338,8 +343,8 @@ def main():
palettes_used = [False] * 16
for palette in new_line_to_palette:
palettes_used[palette] = True
for palette_idx, palette in enumerate(palettes_used):
if palette:
for palette_idx, palette_used in enumerate(palettes_used):
if palette_used:
continue
print("Reassigning palette %d" % palette_idx)
max_width = 0
@ -354,7 +359,8 @@ def main():
lower, upper = last_good_splits[split_palette_idx]
if upper - lower > 20:
mid = (lower + upper) // 2
cluster_palette._palette_splits[split_palette_idx] = (lower, mid)
cluster_palette._palette_splits[split_palette_idx] = (
lower, mid - 1)
cluster_palette._palette_splits[palette_idx] = (mid, upper)
else:
lower = np.random.randint(0, 199)