This commit is contained in:
kris 2021-11-24 15:21:50 +00:00
parent c36de2b76b
commit d645cc5964
1 changed files with 48 additions and 109 deletions

View File

@ -32,65 +32,40 @@ class ClusterPalette:
def __init__(
self, image: Image, rgb12_iigs_to_cam16ucs, rgb24_to_cam16ucs,
reserved_colours=0):
# Source image in 24-bit linear RGB colour space
self._image_rgb = image
# Source image in CAM16UCS colour space
self._colours_cam = self._image_colours_cam(image)
self._errors = [1e9] * 16
# We fit a 16-colour palette against the entire image which is used
# as starting values for fitting the 16 SHR palettes. This helps to
# provide better global consistency of colours across the palettes,
# e.g. for large blocks of colour. Otherwise these can take a while
# to converge.
self._global_palette = np.empty((16, 3), dtype=np.uint8)
# How many image colours to fix identically across all 16 SHR
# palettes. These are taken to be the most prevalent colours from
# _global_palette.
self._reserved_colours = reserved_colours
# 16 SHR palettes each of 16 colours, in CAM16UCS format
# We fit a 16-colour palette against the entire image which is used
# as starting values for fitting the reserved colours in the 16 SHR
# palettes.
self._global_palette = np.empty((16, 3), dtype=np.uint8)
# 16 SHR palettes each of 16 colours, in CAM16UCS colour space
self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
# 16 SHR palettes each of 16 colours, in //gs 4-bit RGB format
# 16 SHR palettes each of 16 colours, in //gs 4-bit RGB colour space
self._palettes_rgb = np.empty((16, 16, 3), dtype=np.uint8)
# defaultdict(list) mapping palette index to lines using this palette
self._palette_lines = self._init_palette_lines()
# Conversion matrix from 12-bit //gs RGB colour space to CAM16UCS
# colour space
self._rgb12_iigs_to_cam16ucs = rgb12_iigs_to_cam16ucs
# Conversion matrix from 24-bit linear RGB colour space to CAM16UCS
# colour space
self._rgb24_to_cam16ucs = rgb24_to_cam16ucs
self._palette_lines = defaultdict(list)
# List of line ranges used to train the 16 SHR palettes
# [(lower_0, upper_0), ...]
# self._palette_splits = self._equal_palette_splits()
self._init_palette_lines()
# Whether the previous iteration of proposed palettes was accepted
self._palettes_accepted = False
# Which palette index's line ranges did we mutate in previous iteration
self._palette_mutate_idx = 0
# Delta applied to palette split in previous iteration
self._palette_mutate_delta = (0, 0)
def _init_palette_lines(self):
palette_splits = self._equal_palette_splits()
for i, lh in enumerate(palette_splits):
l, h = lh
self._palette_lines[i].extend(list(range(l, h)))
# lines = list(range(200))
# random.shuffle(lines)
# idx = 0
# while lines:
# self._palette_lines[idx].append(lines.pop())
# idx += 1
def _image_colours_cam(self, image: Image):
colours_rgb = np.asarray(image) # .reshape((-1, 3))
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
@ -98,6 +73,23 @@ class ClusterPalette:
"CAM16UCS").astype(np.float32)
return colours_cam
def _init_palette_lines(self, init_random = False):
palette_lines = defaultdict(list)
if init_random:
lines = list(range(200))
random.shuffle(lines)
idx = 0
while lines:
palette_lines[idx].append(lines.pop())
idx += 1
else:
palette_splits = self._equal_palette_splits()
for i, lh in enumerate(palette_splits):
l, h = lh
palette_lines[i].extend(list(range(l, h)))
return palette_lines
def _equal_palette_splits(self, palette_height=35):
# The 16 palettes are striped across consecutive (overlapping) line
# ranges. Since nearby lines tend to have similar colours, this has
@ -133,6 +125,8 @@ class ClusterPalette:
self._image_rgb, palettes_cam, palettes_linear_rgb,
self._rgb24_to_cam16ucs, float(penalty))
# Update map of palettes to image lines for which the palette was the
# best match
palette_lines = defaultdict(list)
for line, palette in enumerate(line_to_palette):
palette_lines[palette].append(line)
@ -146,23 +140,21 @@ class ClusterPalette:
def iterate(self, penalty: float, max_inner_iterations: int,
max_outer_iterations: int):
total_image_error = 1e9
# last_good_splits = self._palette_splits
outer_iterations_since_improvement = 0
while outer_iterations_since_improvement < max_outer_iterations:
print("New iteration")
inner_iterations_since_improvement = 0
# self._palette_splits = self._equal_palette_splits()
self._init_palette_lines()
self._palette_lines = self._init_palette_lines()
self._fit_global_palette()
while inner_iterations_since_improvement < max_inner_iterations:
# print("Iterations %d" % inner_iterations_since_improvement)
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
new_palettes_cam, new_palettes_rgb12_iigs = (
self._propose_palettes())
# Recompute image with proposed palettes and check whether it has
# lower total image error than our previous best.
# Recompute image with proposed palettes and check whether it
# has lower total image error than our previous best.
(output_4bit, line_to_palette, palettes_linear_rgb,
new_total_image_error) = self._dither_image(
new_palettes_cam, penalty)
@ -171,27 +163,24 @@ class ClusterPalette:
# within a palette
self._reassign_unused_palettes(line_to_palette)
# print(total_image_error, new_total_image_error)
if new_total_image_error >= total_image_error:
inner_iterations_since_improvement += 1
continue
# We found a globally better set of palettes
# We found a globally better set of palettes, so restart the
# clocks
inner_iterations_since_improvement = 0
outer_iterations_since_improvement = -1
# last_good_splits = self._palette_splits
total_image_error = new_total_image_error
self._palettes_cam = new_palettes_cam
self._palettes_rgb = new_palettes_rgb12_iigs
self._errors = new_palette_errors
self._palettes_accepted = True
yield (new_total_image_error, output_4bit, line_to_palette,
new_palettes_rgb12_iigs, palettes_linear_rgb)
outer_iterations_since_improvement += 1
def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray]:
"""Attempt to find new palettes that locally improve image quality.
Re-fit a set of 16 palettes from (overlapping) line ranges of the
@ -208,7 +197,6 @@ class ClusterPalette:
The current (locally) best palettes are returned and can be applied
using accept_palettes().
"""
new_errors = list(self._errors)
new_palettes_cam = np.empty_like(self._palettes_cam)
new_palettes_rgb12_iigs = np.empty_like(self._palettes_rgb)
@ -217,14 +205,15 @@ class ClusterPalette:
# individual palettes
self._fit_global_palette()
# self._mutate_palette_splits()
for palette_idx in range(16):
# print(palette_idx, self._palette_lines[palette_idx])
# palette_lower, palette_upper = self._palette_splits[palette_idx]
palette_pixels = (
self._colours_cam[
self._palette_lines[palette_idx], :, :].reshape(-1, 3))
# Fix reserved colours from the global palette and pick unique
# random colours from the sample points for the remaining initial
# centroids. This tends to increase the number of colours in the
# resulting image, and improves quality.
initial_centroids = self._global_palette
pixels_rgb_iigs = dither_pyx.convert_cam16ucs_to_rgb12_iigs(
palette_pixels)
@ -236,10 +225,8 @@ class ClusterPalette:
0])
new_colour = pixels_rgb_iigs[choice, :]
if tuple(new_colour) in seen_colours:
# print("Skipping")
continue
seen_colours.add(tuple(new_colour))
# print(i, choice)
initial_centroids[i, :] = new_colour
palettes_rgb12_iigs, palette_error = \
@ -251,13 +238,6 @@ class ClusterPalette:
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs
)
# if (palette_error >= self._errors[palette_idx] and not
# self._reserved_colours):
# # Not a local improvement to the existing palette, so ignore it.
# # We can't take this shortcut when we're reserving colours
# # because it would break the invariant that all palettes must
# # share colours.
# continue
for i in range(16):
new_palettes_cam[palette_idx, i, :] = (
np.array(dither_pyx.convert_rgb12_iigs_to_cam(
@ -265,10 +245,9 @@ class ClusterPalette:
i]), dtype=np.float32))
new_palettes_rgb12_iigs[palette_idx, :, :] = palettes_rgb12_iigs
new_errors[palette_idx] = palette_error
self._palettes_accepted = False
return new_palettes_cam, new_palettes_rgb12_iigs, new_errors
return new_palettes_cam, new_palettes_rgb12_iigs
def _fit_global_palette(self):
"""Compute a 16-colour palette for the entire image to use as
@ -299,7 +278,6 @@ class ClusterPalette:
best_palette_lines = [v for k, v in sorted(list(zip(
self._palette_line_errors, range(200))))]
# print(self._palette_lines)
for palette_idx, palette_used in enumerate(palettes_used):
if palette_used:
continue
@ -309,44 +287,6 @@ class ClusterPalette:
worst_line = best_palette_lines.pop()
self._palette_lines[palette_idx] = [worst_line]
# print("Picked line %d with error %f" % (worst_line,
# self._palette_line_errors[worst_line]))
#
# worst_average_palette_error = 0
# split_palette_idx = -1
# idx = 0
# for idx, lines in self._palette_lines.items():
# if len(lines) < 10:
# continue
# average_palette_error = np.sum(self._palette_line_errors[
# lines]) / len(lines)
# print(idx, average_palette_error)
# if average_palette_error > worst_average_palette_error:
# worst_average_palette_error = average_palette_error
# split_palette_idx = idx
#
# print("Picked %d with avg error %f" % (split_palette_idx, worst_average_palette_error))
# # TODO: split off lines with largest error
#
# palette_line_errors = self._palette_line_errors[
# self._palette_lines[split_palette_idx]]
#
# print(sorted(
# list(zip(palette_line_errors, self._palette_lines[
# split_palette_idx])), reverse=True))
# best_palette_lines = [v for k, v in sorted(
# list(zip(palette_line_errors, self._palette_lines[
# split_palette_idx])))]
# num_max_lines = len(self._palette_lines[split_palette_idx])
#
# self._palette_lines[split_palette_idx] = best_palette_lines[
# :num_max_lines // 2]
# # Move worst half to new palette
# self._palette_lines[palette_idx] = best_palette_lines[
# num_max_lines // 2:]
def main():
parser = argparse.ArgumentParser()
@ -410,7 +350,7 @@ def main():
# TODO: flags
penalty = 1 # 1e18 # TODO: is this needed any more?
inner_iterations = 10 # 20
inner_iterations = 10
outer_iterations = 20
pygame.init()
@ -476,7 +416,6 @@ def main():
canvas.blit(surface, (0, 0))
pygame.display.flip()
# print((palettes_rgb * 255).astype(np.uint8))
unique_colours = np.unique(
palettes_rgb12_iigs.reshape(-1, 3), axis=0).shape[0]
print("%d unique colours" % unique_colours)