Dynamically tune the line ranges used to fit the 16 SHR palettes:

- start with an equal split
- with each iteration, pick a palette and adjust its line ranges by a small random amount
- if the proposed palette is accepted, continue to apply the same delta
- if not, revert the adjustment and pick a different one

In addition, often there will be palettes that are entirely unused by
the image.  For such palettes:

- find the palette with the largest line range.  If > 20, then
  subdivide this range and assign half each to both palettes
- if not, then pick a random line range for the unused palette

This helps to refine and explore more of the parameter space.
This commit is contained in:
kris 2021-11-23 13:01:50 +00:00
parent 189b4655ad
commit 6e52680cf1
1 changed files with 120 additions and 10 deletions

View File

@ -30,13 +30,44 @@ class ClusterPalette:
def __init__(
self, image: Image, rgb12_iigs_to_cam16ucs, reserved_colours=0):
self._colours_cam = self._image_colours_cam(image)
self._reserved_colours = reserved_colours
self._errors = [1e9] * 16
# We fit a 16-colour palette against the entire image which is used
# as starting values for fitting the 16 SHR palettes. This helps to
# provide better global consistency of colours across the palettes,
# e.g. for large blocks of colour. Otherwise these can take a while
# to converge.
self._global_palette = np.empty((16, 3), dtype=np.uint8)
# How many image colours to fix identically across all 16 SHR
# palettes. These are taken to be the most prevalent colours from
# _global_palette.
self._reserved_colours = reserved_colours
# 16 SHR palettes each of 16 colours, in CAM16UCS format
self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
# 16 SHR palettes each of 16 colours, in //gs 4-bit RGB format
self._palettes_rgb = np.empty((16, 16, 3), dtype=np.uint8)
self._global_palette = np.empty((16, 16, 3), dtype=np.float32)
# Conversion matrix from 12-bit //gs RGB colour space to CAM16UCS
# colour space
self._rgb12_iigs_to_cam16ucs = rgb12_iigs_to_cam16ucs
# List of line ranges used to train the 16 SHR palettes
# [(lower_0, upper_0), ...]
self._palette_splits = self._palette_splits()
# Whether the previous iteration of proposed palettes was accepted
self._palettes_accepted = False
# Which palette index's line ranges did we mutate in previous iteration
self._palette_mutate_idx = 0
# Delta applied to palette split in previous iteration
self._palette_mutate_delta = (0, 0)
def _image_colours_cam(self, image: Image):
colours_rgb = np.asarray(image).reshape((-1, 3))
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
@ -60,10 +91,9 @@ class ClusterPalette:
list(zip(*np.unique(labels, return_counts=True))),
key=lambda kv: kv[1], reverse=True)]
res = dither_pyx.convert_cam16ucs_to_rgb12_iigs(
return dither_pyx.convert_cam16ucs_to_rgb12_iigs(
clusters.cluster_centers_[frequency_order].astype(
np.float32))
return res
def _palette_splits(self, palette_height=35):
# The 16 palettes are striped across consecutive (overlapping) line
@ -74,7 +104,7 @@ class ClusterPalette:
# has height H and overlaps the previous one by L lines, then the
# boundaries are at lines:
# (0, H), (H-L, 2H-L), (2H-2L, 3H-2L), ..., (15H-15L, 16H - 15L)
# i.e. 16H - 15L = 200, sofor a given palette height H we need to
# i.e. 16H - 15L = 200, so for a given palette height H we need to
# overlap by:
# L = (16H - 200)/15
@ -86,9 +116,51 @@ class ClusterPalette:
palette_upper = palette_lower + palette_height
palette_ranges.append((int(np.round(palette_lower)),
int(np.round(palette_upper))))
# print(palette_ranges)
return palette_ranges
def _apply_palette_delta(
self, palette_to_mutate, palette_lower_delta, palette_upper_delta):
old_lower, old_upper = self._palette_splits[palette_to_mutate]
new_lower = old_lower + palette_lower_delta
new_upper = old_upper + palette_upper_delta
new_lower = np.clip(new_lower, 0, np.clip(new_upper, 1, 200) - 1)
new_upper = np.clip(new_upper, new_lower + 1, 200)
assert new_lower >= 0, new_upper-1
self._palette_splits[palette_to_mutate] = (new_lower, new_upper)
self._palette_mutate_idx = palette_to_mutate
self._palette_mutate_delta = (palette_lower_delta, palette_upper_delta)
def _mutate_palette_splits(self):
if self._palettes_accepted:
# Last time was good, keep going
self._apply_palette_delta(self._palette_mutate_idx,
self._palette_mutate_delta[0],
self._palette_mutate_delta[1])
else:
# undo last mutation
self._apply_palette_delta(self._palette_mutate_idx,
-self._palette_mutate_delta[0],
-self._palette_mutate_delta[1])
# Pick a palette endpoint to move up or down
palette_to_mutate = np.random.randint(0, 16)
while True:
if palette_to_mutate > 0:
palette_lower_delta = np.random.randint(-20, 21)
else:
palette_lower_delta = 0
if palette_to_mutate < 15:
palette_upper_delta = np.random.randint(-20, 21)
else:
palette_upper_delta = 0
if palette_lower_delta != 0 or palette_upper_delta != 0:
break
self._apply_palette_delta(palette_to_mutate, palette_lower_delta,
palette_upper_delta)
def propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
"""Attempt to find new palettes that locally improve image quality.
@ -117,9 +189,9 @@ class ClusterPalette:
dynamic_colours = 16 - self._reserved_colours
palette_splits = self._palette_splits()
self._mutate_palette_splits()
for palette_idx in range(16):
palette_lower, palette_upper = palette_splits[palette_idx]
palette_lower, palette_upper = self._palette_splits[palette_idx]
# TODO: dynamically tune palette cuts
palette_pixels = self._colours_cam[
palette_lower * 320:palette_upper * 320, :]
@ -149,6 +221,7 @@ class ClusterPalette:
new_palettes_rgb12_iigs[palette_idx, :, :] = palettes_rgb12_iigs
new_errors[palette_idx] = palette_error
self._palettes_accepted = False
return new_palettes_cam, new_palettes_rgb12_iigs, new_errors
def accept_palettes(
@ -157,6 +230,7 @@ class ClusterPalette:
self._palettes_cam = np.copy(new_palettes_cam)
self._palettes_rgb = np.copy(new_palettes_rgb)
self._errors = list(new_errors)
self._palettes_accepted = True
def main():
@ -220,8 +294,8 @@ def main():
gamma=args.gamma_correct)).astype(np.float32) / 255
# TODO: flags
penalty = 1e9
iterations = 50
penalty = 1 # 1e18 # TODO: is this needed any more?
iterations = 200
pygame.init()
# TODO: for some reason I need to execute this twice - the first time
@ -234,10 +308,13 @@ def main():
total_image_error = 1e9
iterations_since_improvement = 0
# TODO: reserved_colours should be a flag
cluster_palette = ClusterPalette(
rgb, reserved_colours=1, rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs)
last_good_splits = cluster_palette._palette_splits
while iterations_since_improvement < iterations:
# print("Iterations %d" % iterations_since_improvement)
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
cluster_palette.propose_palettes())
@ -253,6 +330,37 @@ def main():
dither_pyx.dither_shr(
rgb, new_palettes_cam, new_palettes_linear_rgb,
rgb24_to_cam16ucs, float(penalty))
# print(total_image_error, new_total_image_error,
# cluster_palette._palette_splits)
# TODO: move this into ClusterPalettes
palettes_used = [False] * 16
for palette in new_line_to_palette:
palettes_used[palette] = True
for palette_idx, palette in enumerate(palettes_used):
if palette:
continue
print("Reassigning palette %d" % palette_idx)
max_width = 0
split_palette_idx = -1
idx = 0
for lower, upper in last_good_splits:
width = upper - lower
if width > max_width:
split_palette_idx = idx
idx += 1
lower, upper = last_good_splits[split_palette_idx]
if upper - lower > 20:
mid = (lower + upper) // 2
cluster_palette._palette_splits[split_palette_idx] = (lower, mid)
cluster_palette._palette_splits[palette_idx] = (mid, upper)
else:
lower = np.random.randint(0, 199)
upper = np.random.randint(lower, 200)
cluster_palette._palette_splits[palette_idx] = (lower, upper)
if new_total_image_error >= total_image_error:
iterations_since_improvement += 1
continue
@ -261,11 +369,13 @@ def main():
iterations_since_improvement = 0
cluster_palette.accept_palettes(
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors)
last_good_splits = cluster_palette._palette_splits
if total_image_error < 1e9:
print("Improved quality +%f%% (%f)" % (
(1 - new_total_image_error / total_image_error) * 100,
new_total_image_error))
# print(cluster_palette._palette_splits)
output_4bit = new_output_4bit
line_to_palette = new_line_to_palette
total_image_error = new_total_image_error