This commit is contained in:
kris 2021-11-23 14:51:04 +00:00
parent 6988b19b43
commit 0323b80e68
1 changed files with 85 additions and 73 deletions

View File

@ -28,7 +28,9 @@ import screen as screen_py
class ClusterPalette:
def __init__(
self, image: Image, rgb12_iigs_to_cam16ucs, reserved_colours=0):
self, image: Image, rgb12_iigs_to_cam16ucs, rgb24_to_cam16ucs,
reserved_colours=0):
self._image_rgb = image
self._colours_cam = self._image_colours_cam(image)
self._errors = [1e9] * 16
@ -55,6 +57,8 @@ class ClusterPalette:
# colour space
self._rgb12_iigs_to_cam16ucs = rgb12_iigs_to_cam16ucs
self._rgb24_to_cam16ucs = rgb24_to_cam16ucs
# List of line ranges used to train the 16 SHR palettes
# [(lower_0, upper_0), ...]
self._palette_splits = self._palette_splits()
@ -68,6 +72,75 @@ class ClusterPalette:
# Delta applied to palette split in previous iteration
self._palette_mutate_delta = (0, 0)
def iterate(self, penalty: float, max_iterations: int):
iterations_since_improvement = 0
total_image_error = 1e9
last_good_splits = self._palette_splits
while iterations_since_improvement < max_iterations:
# print("Iterations %d" % iterations_since_improvement)
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
self._propose_palettes())
# Suppress divide by zero warning,
# https://github.com/colour-science/colour/issues/900
with colour.utilities.suppress_warnings(python_warnings=True):
new_palettes_linear_rgb = colour.convert(
new_palettes_cam, "CAM16UCS", "RGB").astype(np.float32)
# Recompute image with proposed palettes and check whether it has
# lower total image error than our previous best.
new_output_4bit, new_line_to_palette, new_total_image_error = \
dither_pyx.dither_shr(
self._image_rgb, new_palettes_cam, new_palettes_linear_rgb,
self._rgb24_to_cam16ucs, float(penalty))
# print(total_image_error, new_total_image_error,
# self._palette_splits)
# TODO: extract this into a function
palettes_used = [False] * 16
for palette in new_line_to_palette:
palettes_used[palette] = True
for palette_idx, palette_used in enumerate(palettes_used):
if palette_used:
continue
print("Reassigning palette %d" % palette_idx)
max_width = 0
split_palette_idx = -1
idx = 0
for lower, upper in last_good_splits:
width = upper - lower
if width > max_width:
split_palette_idx = idx
idx += 1
lower, upper = last_good_splits[split_palette_idx]
if upper - lower > 20:
mid = (lower + upper) // 2
self._palette_splits[split_palette_idx] = (
lower, mid - 1)
self._palette_splits[palette_idx] = (mid, upper)
else:
lower = np.random.randint(0, 199)
upper = np.random.randint(lower, 200)
self._palette_splits[palette_idx] = (lower, upper)
if new_total_image_error >= total_image_error:
iterations_since_improvement += 1
continue
# We found a globally better set of palettes
iterations_since_improvement = 0
last_good_splits = self._palette_splits
total_image_error = new_total_image_error
self._accept_palettes(
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors)
yield (new_total_image_error, new_output_4bit, new_line_to_palette,
new_palettes_rgb12_iigs, new_palettes_linear_rgb)
def _image_colours_cam(self, image: Image):
colours_rgb = np.asarray(image).reshape((-1, 3))
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
@ -163,7 +236,7 @@ class ClusterPalette:
self._apply_palette_delta(palette_to_mutate, palette_lower_delta,
palette_upper_delta)
def propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
"""Attempt to find new palettes that locally improve image quality.
Re-fit a set of 16 palettes from (overlapping) line ranges of the
@ -226,7 +299,7 @@ class ClusterPalette:
self._palettes_accepted = False
return new_palettes_cam, new_palettes_rgb12_iigs, new_errors
def accept_palettes(
def _accept_palettes(
self, new_palettes_cam: np.ndarray,
new_palettes_rgb: np.ndarray, new_errors: List[float]):
self._palettes_cam = np.copy(new_palettes_cam)
@ -297,7 +370,7 @@ def main():
# TODO: flags
penalty = 1 # 1e18 # TODO: is this needed any more?
iterations = 200
iterations = 20# 0
pygame.init()
# TODO: for some reason I need to execute this twice - the first time
@ -307,83 +380,22 @@ def main():
canvas.fill((0, 0, 0))
pygame.display.flip()
total_image_error = 1e9
iterations_since_improvement = 0
total_image_error = None
# TODO: reserved_colours should be a flag
cluster_palette = ClusterPalette(
rgb, reserved_colours=1, rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs)
last_good_splits = cluster_palette._palette_splits
rgb, reserved_colours=1,
rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs,
rgb24_to_cam16ucs=rgb24_to_cam16ucs)
while iterations_since_improvement < iterations:
# print("Iterations %d" % iterations_since_improvement)
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
cluster_palette.propose_palettes())
for (new_total_image_error, output_4bit, line_to_palette,
palettes_rgb12_iigs, palettes_linear_rgb) in cluster_palette.iterate(
penalty, iterations):
# Suppress divide by zero warning,
# https://github.com/colour-science/colour/issues/900
with colour.utilities.suppress_warnings(python_warnings=True):
new_palettes_linear_rgb = colour.convert(
new_palettes_cam, "CAM16UCS", "RGB").astype(np.float32)
# Recompute image with proposed palettes and check whether it has
# lower total image error than our previous best.
new_output_4bit, new_line_to_palette, new_total_image_error = \
dither_pyx.dither_shr(
rgb, new_palettes_cam, new_palettes_linear_rgb,
rgb24_to_cam16ucs, float(penalty))
# print(total_image_error, new_total_image_error,
# cluster_palette._palette_splits)
# TODO: move this into ClusterPalettes
palettes_used = [False] * 16
for palette in new_line_to_palette:
palettes_used[palette] = True
for palette_idx, palette_used in enumerate(palettes_used):
if palette_used:
continue
print("Reassigning palette %d" % palette_idx)
max_width = 0
split_palette_idx = -1
idx = 0
for lower, upper in last_good_splits:
width = upper - lower
if width > max_width:
split_palette_idx = idx
idx += 1
lower, upper = last_good_splits[split_palette_idx]
if upper - lower > 20:
mid = (lower + upper) // 2
cluster_palette._palette_splits[split_palette_idx] = (
lower, mid - 1)
cluster_palette._palette_splits[palette_idx] = (mid, upper)
else:
lower = np.random.randint(0, 199)
upper = np.random.randint(lower, 200)
cluster_palette._palette_splits[palette_idx] = (lower, upper)
if new_total_image_error >= total_image_error:
iterations_since_improvement += 1
continue
# We found a globally better set of palettes
iterations_since_improvement = 0
cluster_palette.accept_palettes(
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors)
last_good_splits = cluster_palette._palette_splits
if total_image_error < 1e9:
if total_image_error is not None:
print("Improved quality +%f%% (%f)" % (
(1 - new_total_image_error / total_image_error) * 100,
new_total_image_error))
# print(cluster_palette._palette_splits)
output_4bit = new_output_4bit
line_to_palette = new_line_to_palette
total_image_error = new_total_image_error
palettes_rgb12_iigs = new_palettes_rgb12_iigs
palettes_linear_rgb = new_palettes_linear_rgb
for i in range(16):
screen.set_palette(i, palettes_rgb12_iigs[i, :, :])