Tidy
This commit is contained in:
parent
c36de2b76b
commit
d645cc5964
157
convert.py
157
convert.py
|
@ -32,65 +32,40 @@ class ClusterPalette:
|
|||
def __init__(
|
||||
self, image: Image, rgb12_iigs_to_cam16ucs, rgb24_to_cam16ucs,
|
||||
reserved_colours=0):
|
||||
|
||||
# Source image in 24-bit linear RGB colour space
|
||||
self._image_rgb = image
|
||||
|
||||
# Source image in CAM16UCS colour space
|
||||
self._colours_cam = self._image_colours_cam(image)
|
||||
|
||||
self._errors = [1e9] * 16
|
||||
|
||||
# We fit a 16-colour palette against the entire image which is used
|
||||
# as starting values for fitting the 16 SHR palettes. This helps to
|
||||
# provide better global consistency of colours across the palettes,
|
||||
# e.g. for large blocks of colour. Otherwise these can take a while
|
||||
# to converge.
|
||||
self._global_palette = np.empty((16, 3), dtype=np.uint8)
|
||||
|
||||
# How many image colours to fix identically across all 16 SHR
|
||||
# palettes. These are taken to be the most prevalent colours from
|
||||
# _global_palette.
|
||||
self._reserved_colours = reserved_colours
|
||||
|
||||
# 16 SHR palettes each of 16 colours, in CAM16UCS format
|
||||
# We fit a 16-colour palette against the entire image which is used
|
||||
# as starting values for fitting the reserved colours in the 16 SHR
|
||||
# palettes.
|
||||
self._global_palette = np.empty((16, 3), dtype=np.uint8)
|
||||
|
||||
# 16 SHR palettes each of 16 colours, in CAM16UCS colour space
|
||||
self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
|
||||
|
||||
# 16 SHR palettes each of 16 colours, in //gs 4-bit RGB format
|
||||
# 16 SHR palettes each of 16 colours, in //gs 4-bit RGB colour space
|
||||
self._palettes_rgb = np.empty((16, 16, 3), dtype=np.uint8)
|
||||
|
||||
# defaultdict(list) mapping palette index to lines using this palette
|
||||
self._palette_lines = self._init_palette_lines()
|
||||
|
||||
# Conversion matrix from 12-bit //gs RGB colour space to CAM16UCS
|
||||
# colour space
|
||||
self._rgb12_iigs_to_cam16ucs = rgb12_iigs_to_cam16ucs
|
||||
|
||||
# Conversion matrix from 24-bit linear RGB colour space to CAM16UCS
|
||||
# colour space
|
||||
self._rgb24_to_cam16ucs = rgb24_to_cam16ucs
|
||||
|
||||
self._palette_lines = defaultdict(list)
|
||||
|
||||
# List of line ranges used to train the 16 SHR palettes
|
||||
# [(lower_0, upper_0), ...]
|
||||
# self._palette_splits = self._equal_palette_splits()
|
||||
self._init_palette_lines()
|
||||
|
||||
# Whether the previous iteration of proposed palettes was accepted
|
||||
self._palettes_accepted = False
|
||||
|
||||
# Which palette index's line ranges did we mutate in previous iteration
|
||||
self._palette_mutate_idx = 0
|
||||
|
||||
# Delta applied to palette split in previous iteration
|
||||
self._palette_mutate_delta = (0, 0)
|
||||
|
||||
def _init_palette_lines(self):
|
||||
palette_splits = self._equal_palette_splits()
|
||||
for i, lh in enumerate(palette_splits):
|
||||
l, h = lh
|
||||
self._palette_lines[i].extend(list(range(l, h)))
|
||||
|
||||
# lines = list(range(200))
|
||||
# random.shuffle(lines)
|
||||
# idx = 0
|
||||
# while lines:
|
||||
# self._palette_lines[idx].append(lines.pop())
|
||||
# idx += 1
|
||||
|
||||
|
||||
def _image_colours_cam(self, image: Image):
|
||||
colours_rgb = np.asarray(image) # .reshape((-1, 3))
|
||||
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
|
||||
|
@ -98,6 +73,23 @@ class ClusterPalette:
|
|||
"CAM16UCS").astype(np.float32)
|
||||
return colours_cam
|
||||
|
||||
def _init_palette_lines(self, init_random = False):
|
||||
palette_lines = defaultdict(list)
|
||||
|
||||
if init_random:
|
||||
lines = list(range(200))
|
||||
random.shuffle(lines)
|
||||
idx = 0
|
||||
while lines:
|
||||
palette_lines[idx].append(lines.pop())
|
||||
idx += 1
|
||||
else:
|
||||
palette_splits = self._equal_palette_splits()
|
||||
for i, lh in enumerate(palette_splits):
|
||||
l, h = lh
|
||||
palette_lines[i].extend(list(range(l, h)))
|
||||
return palette_lines
|
||||
|
||||
def _equal_palette_splits(self, palette_height=35):
|
||||
# The 16 palettes are striped across consecutive (overlapping) line
|
||||
# ranges. Since nearby lines tend to have similar colours, this has
|
||||
|
@ -133,6 +125,8 @@ class ClusterPalette:
|
|||
self._image_rgb, palettes_cam, palettes_linear_rgb,
|
||||
self._rgb24_to_cam16ucs, float(penalty))
|
||||
|
||||
# Update map of palettes to image lines for which the palette was the
|
||||
# best match
|
||||
palette_lines = defaultdict(list)
|
||||
for line, palette in enumerate(line_to_palette):
|
||||
palette_lines[palette].append(line)
|
||||
|
@ -146,23 +140,21 @@ class ClusterPalette:
|
|||
def iterate(self, penalty: float, max_inner_iterations: int,
|
||||
max_outer_iterations: int):
|
||||
total_image_error = 1e9
|
||||
# last_good_splits = self._palette_splits
|
||||
|
||||
outer_iterations_since_improvement = 0
|
||||
while outer_iterations_since_improvement < max_outer_iterations:
|
||||
print("New iteration")
|
||||
inner_iterations_since_improvement = 0
|
||||
# self._palette_splits = self._equal_palette_splits()
|
||||
self._init_palette_lines()
|
||||
self._palette_lines = self._init_palette_lines()
|
||||
|
||||
self._fit_global_palette()
|
||||
while inner_iterations_since_improvement < max_inner_iterations:
|
||||
# print("Iterations %d" % inner_iterations_since_improvement)
|
||||
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
|
||||
new_palettes_cam, new_palettes_rgb12_iigs = (
|
||||
self._propose_palettes())
|
||||
|
||||
# Recompute image with proposed palettes and check whether it has
|
||||
# lower total image error than our previous best.
|
||||
# Recompute image with proposed palettes and check whether it
|
||||
# has lower total image error than our previous best.
|
||||
(output_4bit, line_to_palette, palettes_linear_rgb,
|
||||
new_total_image_error) = self._dither_image(
|
||||
new_palettes_cam, penalty)
|
||||
|
@ -171,27 +163,24 @@ class ClusterPalette:
|
|||
# within a palette
|
||||
self._reassign_unused_palettes(line_to_palette)
|
||||
|
||||
# print(total_image_error, new_total_image_error)
|
||||
if new_total_image_error >= total_image_error:
|
||||
inner_iterations_since_improvement += 1
|
||||
continue
|
||||
|
||||
# We found a globally better set of palettes
|
||||
# We found a globally better set of palettes, so restart the
|
||||
# clocks
|
||||
inner_iterations_since_improvement = 0
|
||||
outer_iterations_since_improvement = -1
|
||||
# last_good_splits = self._palette_splits
|
||||
total_image_error = new_total_image_error
|
||||
|
||||
self._palettes_cam = new_palettes_cam
|
||||
self._palettes_rgb = new_palettes_rgb12_iigs
|
||||
self._errors = new_palette_errors
|
||||
self._palettes_accepted = True
|
||||
|
||||
yield (new_total_image_error, output_4bit, line_to_palette,
|
||||
new_palettes_rgb12_iigs, palettes_linear_rgb)
|
||||
outer_iterations_since_improvement += 1
|
||||
|
||||
def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
|
||||
def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Attempt to find new palettes that locally improve image quality.
|
||||
|
||||
Re-fit a set of 16 palettes from (overlapping) line ranges of the
|
||||
|
@ -208,7 +197,6 @@ class ClusterPalette:
|
|||
The current (locally) best palettes are returned and can be applied
|
||||
using accept_palettes().
|
||||
"""
|
||||
new_errors = list(self._errors)
|
||||
new_palettes_cam = np.empty_like(self._palettes_cam)
|
||||
new_palettes_rgb12_iigs = np.empty_like(self._palettes_rgb)
|
||||
|
||||
|
@ -217,14 +205,15 @@ class ClusterPalette:
|
|||
# individual palettes
|
||||
self._fit_global_palette()
|
||||
|
||||
# self._mutate_palette_splits()
|
||||
for palette_idx in range(16):
|
||||
# print(palette_idx, self._palette_lines[palette_idx])
|
||||
# palette_lower, palette_upper = self._palette_splits[palette_idx]
|
||||
palette_pixels = (
|
||||
self._colours_cam[
|
||||
self._palette_lines[palette_idx], :, :].reshape(-1, 3))
|
||||
|
||||
# Fix reserved colours from the global palette and pick unique
|
||||
# random colours from the sample points for the remaining initial
|
||||
# centroids. This tends to increase the number of colours in the
|
||||
# resulting image, and improves quality.
|
||||
initial_centroids = self._global_palette
|
||||
pixels_rgb_iigs = dither_pyx.convert_cam16ucs_to_rgb12_iigs(
|
||||
palette_pixels)
|
||||
|
@ -236,10 +225,8 @@ class ClusterPalette:
|
|||
0])
|
||||
new_colour = pixels_rgb_iigs[choice, :]
|
||||
if tuple(new_colour) in seen_colours:
|
||||
# print("Skipping")
|
||||
continue
|
||||
seen_colours.add(tuple(new_colour))
|
||||
# print(i, choice)
|
||||
initial_centroids[i, :] = new_colour
|
||||
|
||||
palettes_rgb12_iigs, palette_error = \
|
||||
|
@ -251,13 +238,6 @@ class ClusterPalette:
|
|||
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs
|
||||
)
|
||||
|
||||
# if (palette_error >= self._errors[palette_idx] and not
|
||||
# self._reserved_colours):
|
||||
# # Not a local improvement to the existing palette, so ignore it.
|
||||
# # We can't take this shortcut when we're reserving colours
|
||||
# # because it would break the invariant that all palettes must
|
||||
# # share colours.
|
||||
# continue
|
||||
for i in range(16):
|
||||
new_palettes_cam[palette_idx, i, :] = (
|
||||
np.array(dither_pyx.convert_rgb12_iigs_to_cam(
|
||||
|
@ -265,10 +245,9 @@ class ClusterPalette:
|
|||
i]), dtype=np.float32))
|
||||
|
||||
new_palettes_rgb12_iigs[palette_idx, :, :] = palettes_rgb12_iigs
|
||||
new_errors[palette_idx] = palette_error
|
||||
|
||||
self._palettes_accepted = False
|
||||
return new_palettes_cam, new_palettes_rgb12_iigs, new_errors
|
||||
return new_palettes_cam, new_palettes_rgb12_iigs
|
||||
|
||||
def _fit_global_palette(self):
|
||||
"""Compute a 16-colour palette for the entire image to use as
|
||||
|
@ -299,7 +278,6 @@ class ClusterPalette:
|
|||
best_palette_lines = [v for k, v in sorted(list(zip(
|
||||
self._palette_line_errors, range(200))))]
|
||||
|
||||
# print(self._palette_lines)
|
||||
for palette_idx, palette_used in enumerate(palettes_used):
|
||||
if palette_used:
|
||||
continue
|
||||
|
@ -309,44 +287,6 @@ class ClusterPalette:
|
|||
worst_line = best_palette_lines.pop()
|
||||
self._palette_lines[palette_idx] = [worst_line]
|
||||
|
||||
# print("Picked line %d with error %f" % (worst_line,
|
||||
# self._palette_line_errors[worst_line]))
|
||||
|
||||
|
||||
#
|
||||
# worst_average_palette_error = 0
|
||||
# split_palette_idx = -1
|
||||
# idx = 0
|
||||
# for idx, lines in self._palette_lines.items():
|
||||
# if len(lines) < 10:
|
||||
# continue
|
||||
# average_palette_error = np.sum(self._palette_line_errors[
|
||||
# lines]) / len(lines)
|
||||
# print(idx, average_palette_error)
|
||||
# if average_palette_error > worst_average_palette_error:
|
||||
# worst_average_palette_error = average_palette_error
|
||||
# split_palette_idx = idx
|
||||
#
|
||||
# print("Picked %d with avg error %f" % (split_palette_idx, worst_average_palette_error))
|
||||
# # TODO: split off lines with largest error
|
||||
#
|
||||
# palette_line_errors = self._palette_line_errors[
|
||||
# self._palette_lines[split_palette_idx]]
|
||||
#
|
||||
# print(sorted(
|
||||
# list(zip(palette_line_errors, self._palette_lines[
|
||||
# split_palette_idx])), reverse=True))
|
||||
# best_palette_lines = [v for k, v in sorted(
|
||||
# list(zip(palette_line_errors, self._palette_lines[
|
||||
# split_palette_idx])))]
|
||||
# num_max_lines = len(self._palette_lines[split_palette_idx])
|
||||
#
|
||||
# self._palette_lines[split_palette_idx] = best_palette_lines[
|
||||
# :num_max_lines // 2]
|
||||
# # Move worst half to new palette
|
||||
# self._palette_lines[palette_idx] = best_palette_lines[
|
||||
# num_max_lines // 2:]
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -410,7 +350,7 @@ def main():
|
|||
|
||||
# TODO: flags
|
||||
penalty = 1 # 1e18 # TODO: is this needed any more?
|
||||
inner_iterations = 10 # 20
|
||||
inner_iterations = 10
|
||||
outer_iterations = 20
|
||||
|
||||
pygame.init()
|
||||
|
@ -476,7 +416,6 @@ def main():
|
|||
canvas.blit(surface, (0, 0))
|
||||
pygame.display.flip()
|
||||
|
||||
# print((palettes_rgb * 255).astype(np.uint8))
|
||||
unique_colours = np.unique(
|
||||
palettes_rgb12_iigs.reshape(-1, 3), axis=0).shape[0]
|
||||
print("%d unique colours" % unique_colours)
|
||||
|
|
Loading…
Reference in New Issue