Each run seems to converge fairly quickly but there is a lot of variation across runs. Run in a loop and keep the running best.

This commit is contained in:
kris 2021-11-24 11:47:39 +00:00
parent de8a303de2
commit 3b8767782b
1 changed files with 105 additions and 79 deletions

View File

@ -61,9 +61,12 @@ class ClusterPalette:
self._rgb24_to_cam16ucs = rgb24_to_cam16ucs
self._palette_lines = defaultdict(list)
# List of line ranges used to train the 16 SHR palettes
# [(lower_0, upper_0), ...]
self._palette_splits = self._equal_palette_splits()
self._init_palette_lines()
# Whether the previous iteration of proposed palettes was accepted
self._palettes_accepted = False
@ -74,7 +77,7 @@ class ClusterPalette:
# Delta applied to palette split in previous iteration
self._palette_mutate_delta = (0, 0)
self._palette_lines = defaultdict(list)
def _init_palette_lines(self):
for i, lh in enumerate(self._palette_splits):
l, h = lh
self._palette_lines[i].extend(list(range(l, h)))
@ -132,41 +135,46 @@ class ClusterPalette:
total_image_error)
def iterate(self, penalty: float, max_iterations: int):
iterations_since_improvement = 0
total_image_error = 1e9
last_good_splits = self._palette_splits
self._fit_global_palette()
while iterations_since_improvement < max_iterations:
# print("Iterations %d" % iterations_since_improvement)
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
self._propose_palettes())
# Recompute image with proposed palettes and check whether it has
# lower total image error than our previous best.
(output_4bit, line_to_palette, palettes_linear_rgb,
new_total_image_error) = self._dither_image(
new_palettes_cam, penalty)
self._reassign_unused_palettes(line_to_palette,
last_good_splits)
print(total_image_error, new_total_image_error)
if new_total_image_error >= total_image_error:
iterations_since_improvement += 1
continue
# We found a globally better set of palettes
while True:
print("New iteration")
iterations_since_improvement = 0
last_good_splits = self._palette_splits
total_image_error = new_total_image_error
self._palette_splits = self._equal_palette_splits()
self._init_palette_lines()
self._palettes_cam = new_palettes_cam
self._palettes_rgb = new_palettes_rgb12_iigs
self._errors = new_palette_errors
self._palettes_accepted = True
self._fit_global_palette()
while iterations_since_improvement < max_iterations:
# print("Iterations %d" % iterations_since_improvement)
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
self._propose_palettes())
yield (new_total_image_error, output_4bit, line_to_palette,
new_palettes_rgb12_iigs, palettes_linear_rgb)
# Recompute image with proposed palettes and check whether it has
# lower total image error than our previous best.
(output_4bit, line_to_palette, palettes_linear_rgb,
new_total_image_error) = self._dither_image(
new_palettes_cam, penalty)
self._reassign_unused_palettes(line_to_palette,
last_good_splits)
# print(total_image_error, new_total_image_error)
if new_total_image_error >= total_image_error:
iterations_since_improvement += 1
continue
# We found a globally better set of palettes
iterations_since_improvement = 0
last_good_splits = self._palette_splits
total_image_error = new_total_image_error
self._palettes_cam = new_palettes_cam
self._palettes_rgb = new_palettes_rgb12_iigs
self._errors = new_palette_errors
self._palettes_accepted = True
yield (new_total_image_error, output_4bit, line_to_palette,
new_palettes_rgb12_iigs, palettes_linear_rgb)
def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
"""Attempt to find new palettes that locally improve image quality.
@ -192,7 +200,7 @@ class ClusterPalette:
# Compute a new 16-colour global palette for the entire image,
# used as the starting center positions for k-means clustering of the
# individual palettes
# self._fit_global_palette()
self._fit_global_palette()
self._mutate_palette_splits()
for palette_idx in range(16):
@ -211,13 +219,13 @@ class ClusterPalette:
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs
)
if (palette_error >= self._errors[palette_idx] and not
self._reserved_colours):
# Not a local improvement to the existing palette, so ignore it.
# We can't take this shortcut when we're reserving colours
# because it would break the invariant that all palettes must
# share colours.
continue
# if (palette_error >= self._errors[palette_idx] and not
# self._reserved_colours):
# # Not a local improvement to the existing palette, so ignore it.
# # We can't take this shortcut when we're reserving colours
# # because it would break the invariant that all palettes must
# # share colours.
# continue
for i in range(16):
new_palettes_cam[palette_idx, i, :] = (
np.array(dither_pyx.convert_rgb12_iigs_to_cam(
@ -299,43 +307,56 @@ class ClusterPalette:
palettes_used = [False] * 16
for palette in new_line_to_palette:
palettes_used[palette] = True
best_palette_lines = [v for k, v in sorted(list(zip(
self._palette_line_errors, range(200))))]
# print(self._palette_lines)
for palette_idx, palette_used in enumerate(palettes_used):
if palette_used:
continue
print("Reassigning palette %d" % palette_idx)
# print("Reassigning palette %d" % palette_idx)
worst_average_palette_error = 0
split_palette_idx = -1
idx = 0
for idx, lines in self._palette_lines.items():
if len(lines) < 10:
continue
average_palette_error = np.sum(self._palette_line_errors[
lines]) / len(lines)
print(idx, average_palette_error)
if average_palette_error > worst_average_palette_error:
worst_average_palette_error = average_palette_error
split_palette_idx = idx
# TODO: also remove from old entry
worst_line = best_palette_lines.pop()
self._palette_lines[palette_idx] = [worst_line]
print("Picked %d with avg error %f" % (split_palette_idx, worst_average_palette_error))
# TODO: split off lines with largest error
# print("Picked line %d with error %f" % (worst_line,
# self._palette_line_errors[worst_line]))
palette_line_errors = self._palette_line_errors[
self._palette_lines[split_palette_idx]]
print(sorted(
list(zip(palette_line_errors, self._palette_lines[
split_palette_idx])), reverse=True))
best_palette_lines = [v for k, v in sorted(
list(zip(palette_line_errors, self._palette_lines[
split_palette_idx])))]
num_max_lines = len(self._palette_lines[split_palette_idx])
self._palette_lines[split_palette_idx] = best_palette_lines[
:num_max_lines // 2]
# Move worst half to new palette
self._palette_lines[palette_idx] = best_palette_lines[
num_max_lines // 2:]
#
# worst_average_palette_error = 0
# split_palette_idx = -1
# idx = 0
# for idx, lines in self._palette_lines.items():
# if len(lines) < 10:
# continue
# average_palette_error = np.sum(self._palette_line_errors[
# lines]) / len(lines)
# print(idx, average_palette_error)
# if average_palette_error > worst_average_palette_error:
# worst_average_palette_error = average_palette_error
# split_palette_idx = idx
#
# print("Picked %d with avg error %f" % (split_palette_idx, worst_average_palette_error))
# # TODO: split off lines with largest error
#
# palette_line_errors = self._palette_line_errors[
# self._palette_lines[split_palette_idx]]
#
# print(sorted(
# list(zip(palette_line_errors, self._palette_lines[
# split_palette_idx])), reverse=True))
# best_palette_lines = [v for k, v in sorted(
# list(zip(palette_line_errors, self._palette_lines[
# split_palette_idx])))]
# num_max_lines = len(self._palette_lines[split_palette_idx])
#
# self._palette_lines[split_palette_idx] = best_palette_lines[
# :num_max_lines // 2]
# # Move worst half to new palette
# self._palette_lines[palette_idx] = best_palette_lines[
# num_max_lines // 2:]
def main():
@ -400,7 +421,7 @@ def main():
# TODO: flags
penalty = 1 # 1e18 # TODO: is this needed any more?
iterations = 200
iterations = 10 # 20
pygame.init()
# TODO: for some reason I need to execute this twice - the first time
@ -417,6 +438,7 @@ def main():
rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs,
rgb24_to_cam16ucs=rgb24_to_cam16ucs)
seq = 0
for (new_total_image_error, output_4bit, line_to_palette,
palettes_rgb12_iigs, palettes_linear_rgb) in cluster_palette.iterate(
penalty, iterations):
@ -463,20 +485,24 @@ def main():
np.asarray(out_image).transpose((1, 0, 2))) # flip y/x axes
canvas.blit(surface, (0, 0))
pygame.display.flip()
seq += 1
# Save Double hi-res image
outfile = os.path.join(
os.path.splitext(args.output)[0] + "-%d-preview.png" % seq)
out_image.save(outfile, "PNG")
screen.pack()
# with open(args.output, "wb") as f:
# f.write(bytes(screen.aux))
# f.write(bytes(screen.main))
with open(args.output, "wb") as f:
f.write(bytes(screen.memory))
# print((palettes_rgb * 255).astype(np.uint8))
unique_colours = np.unique(
palettes_rgb12_iigs.reshape(-1, 3), axis=0).shape[0]
print("%d unique colours" % unique_colours)
# Save Double hi-res image
outfile = os.path.join(os.path.splitext(args.output)[0] + "-preview.png")
out_image.save(outfile, "PNG")
screen.pack()
# with open(args.output, "wb") as f:
# f.write(bytes(screen.aux))
# f.write(bytes(screen.main))
with open(args.output, "wb") as f:
f.write(bytes(screen.memory))
if __name__ == "__main__":