When initializing centroids for fitting the SHR palettes, only use the

reserved colours from the global palette, and pick unique random
points from the samples for the rest.  This encourages a larger range
of colours in the resulting images and may improve quality.

Iterate a max number of times without improvement in the outer loop as
well.

Save intermediate preview outputs.
This commit is contained in:
kris 2021-11-24 14:57:24 +00:00
parent 3b8767782b
commit c36de2b76b
1 changed files with 57 additions and 67 deletions

View File

@ -65,7 +65,7 @@ class ClusterPalette:
# List of line ranges used to train the 16 SHR palettes
# [(lower_0, upper_0), ...]
self._palette_splits = self._equal_palette_splits()
# self._palette_splits = self._equal_palette_splits()
self._init_palette_lines()
# Whether the previous iteration of proposed palettes was accepted
@ -78,10 +78,19 @@ class ClusterPalette:
self._palette_mutate_delta = (0, 0)
def _init_palette_lines(self):
for i, lh in enumerate(self._palette_splits):
palette_splits = self._equal_palette_splits()
for i, lh in enumerate(palette_splits):
l, h = lh
self._palette_lines[i].extend(list(range(l, h)))
# lines = list(range(200))
# random.shuffle(lines)
# idx = 0
# while lines:
# self._palette_lines[idx].append(lines.pop())
# idx += 1
def _image_colours_cam(self, image: Image):
colours_rgb = np.asarray(image) # .reshape((-1, 3))
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
@ -134,18 +143,21 @@ class ClusterPalette:
return (output_4bit, line_to_palette, palettes_linear_rgb,
total_image_error)
def iterate(self, penalty: float, max_iterations: int):
def iterate(self, penalty: float, max_inner_iterations: int,
max_outer_iterations: int):
total_image_error = 1e9
last_good_splits = self._palette_splits
while True:
# last_good_splits = self._palette_splits
outer_iterations_since_improvement = 0
while outer_iterations_since_improvement < max_outer_iterations:
print("New iteration")
iterations_since_improvement = 0
self._palette_splits = self._equal_palette_splits()
inner_iterations_since_improvement = 0
# self._palette_splits = self._equal_palette_splits()
self._init_palette_lines()
self._fit_global_palette()
while iterations_since_improvement < max_iterations:
# print("Iterations %d" % iterations_since_improvement)
while inner_iterations_since_improvement < max_inner_iterations:
# print("Iterations %d" % inner_iterations_since_improvement)
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
self._propose_palettes())
@ -155,17 +167,19 @@ class ClusterPalette:
new_total_image_error) = self._dither_image(
new_palettes_cam, penalty)
self._reassign_unused_palettes(line_to_palette,
last_good_splits)
# TODO: check for duplicate palettes and unused colours
# within a palette
self._reassign_unused_palettes(line_to_palette)
# print(total_image_error, new_total_image_error)
if new_total_image_error >= total_image_error:
iterations_since_improvement += 1
inner_iterations_since_improvement += 1
continue
# We found a globally better set of palettes
iterations_since_improvement = 0
last_good_splits = self._palette_splits
inner_iterations_since_improvement = 0
outer_iterations_since_improvement = -1
# last_good_splits = self._palette_splits
total_image_error = new_total_image_error
self._palettes_cam = new_palettes_cam
@ -175,6 +189,7 @@ class ClusterPalette:
yield (new_total_image_error, output_4bit, line_to_palette,
new_palettes_rgb12_iigs, palettes_linear_rgb)
outer_iterations_since_improvement += 1
def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
"""Attempt to find new palettes that locally improve image quality.
@ -202,7 +217,7 @@ class ClusterPalette:
# individual palettes
self._fit_global_palette()
self._mutate_palette_splits()
# self._mutate_palette_splits()
for palette_idx in range(16):
# print(palette_idx, self._palette_lines[palette_idx])
# palette_lower, palette_upper = self._palette_splits[palette_idx]
@ -210,11 +225,28 @@ class ClusterPalette:
self._colours_cam[
self._palette_lines[palette_idx], :, :].reshape(-1, 3))
initial_centroids = self._global_palette
pixels_rgb_iigs = dither_pyx.convert_cam16ucs_to_rgb12_iigs(
palette_pixels)
seen_colours = set()
for i in range(self._reserved_colours):
seen_colours.add(tuple(initial_centroids[i, :]))
for i in range(self._reserved_colours, 16):
choice = np.random.randint(0, pixels_rgb_iigs.shape[
0])
new_colour = pixels_rgb_iigs[choice, :]
if tuple(new_colour) in seen_colours:
# print("Skipping")
continue
seen_colours.add(tuple(new_colour))
# print(i, choice)
initial_centroids[i, :] = new_colour
palettes_rgb12_iigs, palette_error = \
dither_pyx.k_means_with_fixed_centroids(
n_clusters=16, n_fixed=self._reserved_colours,
samples=palette_pixels,
initial_centroids=self._global_palette,
initial_centroids=initial_centroids,
max_iterations=1000, tolerance=0.05,
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs
)
@ -260,50 +292,7 @@ class ClusterPalette:
clusters.cluster_centers_[frequency_order].astype(
np.float32)))
def _mutate_palette_splits(self):
if self._palettes_accepted:
# Last time was good, keep going
self._apply_palette_delta(self._palette_mutate_idx,
self._palette_mutate_delta[0],
self._palette_mutate_delta[1])
else:
# undo last mutation
self._apply_palette_delta(self._palette_mutate_idx,
-self._palette_mutate_delta[0],
-self._palette_mutate_delta[1])
# Pick a palette endpoint to move up or down
palette_to_mutate = np.random.randint(0, 16)
while True:
if palette_to_mutate > 0:
palette_lower_delta = np.random.randint(-20, 21)
else:
palette_lower_delta = 0
if palette_to_mutate < 15:
palette_upper_delta = np.random.randint(-20, 21)
else:
palette_upper_delta = 0
if palette_lower_delta != 0 or palette_upper_delta != 0:
break
self._apply_palette_delta(palette_to_mutate, palette_lower_delta,
palette_upper_delta)
def _apply_palette_delta(
self, palette_to_mutate, palette_lower_delta, palette_upper_delta):
old_lower, old_upper = self._palette_splits[palette_to_mutate]
new_lower = old_lower + palette_lower_delta
new_upper = old_upper + palette_upper_delta
new_lower = np.clip(new_lower, 0, np.clip(new_upper, 1, 200) - 1)
new_upper = np.clip(new_upper, new_lower + 1, 200)
assert new_lower >= 0, new_upper - 1
self._palette_splits[palette_to_mutate] = (new_lower, new_upper)
self._palette_mutate_idx = palette_to_mutate
self._palette_mutate_delta = (palette_lower_delta, palette_upper_delta)
def _reassign_unused_palettes(self, new_line_to_palette, last_good_splits):
def _reassign_unused_palettes(self, new_line_to_palette):
palettes_used = [False] * 16
for palette in new_line_to_palette:
palettes_used[palette] = True
@ -421,7 +410,8 @@ def main():
# TODO: flags
penalty = 1 # 1e18 # TODO: is this needed any more?
iterations = 10 # 20
inner_iterations = 10 # 20
outer_iterations = 20
pygame.init()
# TODO: for some reason I need to execute this twice - the first time
@ -441,7 +431,7 @@ def main():
seq = 0
for (new_total_image_error, output_4bit, line_to_palette,
palettes_rgb12_iigs, palettes_linear_rgb) in cluster_palette.iterate(
penalty, iterations):
penalty, inner_iterations, outer_iterations):
if total_image_error is not None:
print("Improved quality +%f%% (%f)" % (
@ -486,6 +476,11 @@ def main():
canvas.blit(surface, (0, 0))
pygame.display.flip()
# print((palettes_rgb * 255).astype(np.uint8))
unique_colours = np.unique(
palettes_rgb12_iigs.reshape(-1, 3), axis=0).shape[0]
print("%d unique colours" % unique_colours)
seq += 1
# Save Double hi-res image
outfile = os.path.join(
@ -498,11 +493,6 @@ def main():
with open(args.output, "wb") as f:
f.write(bytes(screen.memory))
# print((palettes_rgb * 255).astype(np.uint8))
unique_colours = np.unique(
palettes_rgb12_iigs.reshape(-1, 3), axis=0).shape[0]
print("%d unique colours" % unique_colours)
if __name__ == "__main__":