mirror of
https://github.com/KrisKennaway/ii-pix.git
synced 2025-01-30 15:36:53 +00:00
Each run seems to converge fairly quickly but there is a lot of variation across runs. Run in a loop and keep the running best.
This commit is contained in:
parent
de8a303de2
commit
3b8767782b
184
convert.py
184
convert.py
@ -61,9 +61,12 @@ class ClusterPalette:
|
||||
|
||||
self._rgb24_to_cam16ucs = rgb24_to_cam16ucs
|
||||
|
||||
self._palette_lines = defaultdict(list)
|
||||
|
||||
# List of line ranges used to train the 16 SHR palettes
|
||||
# [(lower_0, upper_0), ...]
|
||||
self._palette_splits = self._equal_palette_splits()
|
||||
self._init_palette_lines()
|
||||
|
||||
# Whether the previous iteration of proposed palettes was accepted
|
||||
self._palettes_accepted = False
|
||||
@ -74,7 +77,7 @@ class ClusterPalette:
|
||||
# Delta applied to palette split in previous iteration
|
||||
self._palette_mutate_delta = (0, 0)
|
||||
|
||||
self._palette_lines = defaultdict(list)
|
||||
def _init_palette_lines(self):
|
||||
for i, lh in enumerate(self._palette_splits):
|
||||
l, h = lh
|
||||
self._palette_lines[i].extend(list(range(l, h)))
|
||||
@ -132,41 +135,46 @@ class ClusterPalette:
|
||||
total_image_error)
|
||||
|
||||
def iterate(self, penalty: float, max_iterations: int):
|
||||
iterations_since_improvement = 0
|
||||
total_image_error = 1e9
|
||||
last_good_splits = self._palette_splits
|
||||
self._fit_global_palette()
|
||||
while iterations_since_improvement < max_iterations:
|
||||
# print("Iterations %d" % iterations_since_improvement)
|
||||
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
|
||||
self._propose_palettes())
|
||||
|
||||
# Recompute image with proposed palettes and check whether it has
|
||||
# lower total image error than our previous best.
|
||||
(output_4bit, line_to_palette, palettes_linear_rgb,
|
||||
new_total_image_error) = self._dither_image(
|
||||
new_palettes_cam, penalty)
|
||||
|
||||
self._reassign_unused_palettes(line_to_palette,
|
||||
last_good_splits)
|
||||
|
||||
print(total_image_error, new_total_image_error)
|
||||
if new_total_image_error >= total_image_error:
|
||||
iterations_since_improvement += 1
|
||||
continue
|
||||
|
||||
# We found a globally better set of palettes
|
||||
while True:
|
||||
print("New iteration")
|
||||
iterations_since_improvement = 0
|
||||
last_good_splits = self._palette_splits
|
||||
total_image_error = new_total_image_error
|
||||
self._palette_splits = self._equal_palette_splits()
|
||||
self._init_palette_lines()
|
||||
|
||||
self._palettes_cam = new_palettes_cam
|
||||
self._palettes_rgb = new_palettes_rgb12_iigs
|
||||
self._errors = new_palette_errors
|
||||
self._palettes_accepted = True
|
||||
self._fit_global_palette()
|
||||
while iterations_since_improvement < max_iterations:
|
||||
# print("Iterations %d" % iterations_since_improvement)
|
||||
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
|
||||
self._propose_palettes())
|
||||
|
||||
yield (new_total_image_error, output_4bit, line_to_palette,
|
||||
new_palettes_rgb12_iigs, palettes_linear_rgb)
|
||||
# Recompute image with proposed palettes and check whether it has
|
||||
# lower total image error than our previous best.
|
||||
(output_4bit, line_to_palette, palettes_linear_rgb,
|
||||
new_total_image_error) = self._dither_image(
|
||||
new_palettes_cam, penalty)
|
||||
|
||||
self._reassign_unused_palettes(line_to_palette,
|
||||
last_good_splits)
|
||||
|
||||
# print(total_image_error, new_total_image_error)
|
||||
if new_total_image_error >= total_image_error:
|
||||
iterations_since_improvement += 1
|
||||
continue
|
||||
|
||||
# We found a globally better set of palettes
|
||||
iterations_since_improvement = 0
|
||||
last_good_splits = self._palette_splits
|
||||
total_image_error = new_total_image_error
|
||||
|
||||
self._palettes_cam = new_palettes_cam
|
||||
self._palettes_rgb = new_palettes_rgb12_iigs
|
||||
self._errors = new_palette_errors
|
||||
self._palettes_accepted = True
|
||||
|
||||
yield (new_total_image_error, output_4bit, line_to_palette,
|
||||
new_palettes_rgb12_iigs, palettes_linear_rgb)
|
||||
|
||||
def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
|
||||
"""Attempt to find new palettes that locally improve image quality.
|
||||
@ -192,7 +200,7 @@ class ClusterPalette:
|
||||
# Compute a new 16-colour global palette for the entire image,
|
||||
# used as the starting center positions for k-means clustering of the
|
||||
# individual palettes
|
||||
# self._fit_global_palette()
|
||||
self._fit_global_palette()
|
||||
|
||||
self._mutate_palette_splits()
|
||||
for palette_idx in range(16):
|
||||
@ -211,13 +219,13 @@ class ClusterPalette:
|
||||
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs
|
||||
)
|
||||
|
||||
if (palette_error >= self._errors[palette_idx] and not
|
||||
self._reserved_colours):
|
||||
# Not a local improvement to the existing palette, so ignore it.
|
||||
# We can't take this shortcut when we're reserving colours
|
||||
# because it would break the invariant that all palettes must
|
||||
# share colours.
|
||||
continue
|
||||
# if (palette_error >= self._errors[palette_idx] and not
|
||||
# self._reserved_colours):
|
||||
# # Not a local improvement to the existing palette, so ignore it.
|
||||
# # We can't take this shortcut when we're reserving colours
|
||||
# # because it would break the invariant that all palettes must
|
||||
# # share colours.
|
||||
# continue
|
||||
for i in range(16):
|
||||
new_palettes_cam[palette_idx, i, :] = (
|
||||
np.array(dither_pyx.convert_rgb12_iigs_to_cam(
|
||||
@ -299,43 +307,56 @@ class ClusterPalette:
|
||||
palettes_used = [False] * 16
|
||||
for palette in new_line_to_palette:
|
||||
palettes_used[palette] = True
|
||||
best_palette_lines = [v for k, v in sorted(list(zip(
|
||||
self._palette_line_errors, range(200))))]
|
||||
|
||||
# print(self._palette_lines)
|
||||
for palette_idx, palette_used in enumerate(palettes_used):
|
||||
if palette_used:
|
||||
continue
|
||||
print("Reassigning palette %d" % palette_idx)
|
||||
# print("Reassigning palette %d" % palette_idx)
|
||||
|
||||
worst_average_palette_error = 0
|
||||
split_palette_idx = -1
|
||||
idx = 0
|
||||
for idx, lines in self._palette_lines.items():
|
||||
if len(lines) < 10:
|
||||
continue
|
||||
average_palette_error = np.sum(self._palette_line_errors[
|
||||
lines]) / len(lines)
|
||||
print(idx, average_palette_error)
|
||||
if average_palette_error > worst_average_palette_error:
|
||||
worst_average_palette_error = average_palette_error
|
||||
split_palette_idx = idx
|
||||
# TODO: also remove from old entry
|
||||
worst_line = best_palette_lines.pop()
|
||||
self._palette_lines[palette_idx] = [worst_line]
|
||||
|
||||
print("Picked %d with avg error %f" % (split_palette_idx, worst_average_palette_error))
|
||||
# TODO: split off lines with largest error
|
||||
# print("Picked line %d with error %f" % (worst_line,
|
||||
# self._palette_line_errors[worst_line]))
|
||||
|
||||
palette_line_errors = self._palette_line_errors[
|
||||
self._palette_lines[split_palette_idx]]
|
||||
|
||||
print(sorted(
|
||||
list(zip(palette_line_errors, self._palette_lines[
|
||||
split_palette_idx])), reverse=True))
|
||||
best_palette_lines = [v for k, v in sorted(
|
||||
list(zip(palette_line_errors, self._palette_lines[
|
||||
split_palette_idx])))]
|
||||
num_max_lines = len(self._palette_lines[split_palette_idx])
|
||||
|
||||
self._palette_lines[split_palette_idx] = best_palette_lines[
|
||||
:num_max_lines // 2]
|
||||
# Move worst half to new palette
|
||||
self._palette_lines[palette_idx] = best_palette_lines[
|
||||
num_max_lines // 2:]
|
||||
#
|
||||
# worst_average_palette_error = 0
|
||||
# split_palette_idx = -1
|
||||
# idx = 0
|
||||
# for idx, lines in self._palette_lines.items():
|
||||
# if len(lines) < 10:
|
||||
# continue
|
||||
# average_palette_error = np.sum(self._palette_line_errors[
|
||||
# lines]) / len(lines)
|
||||
# print(idx, average_palette_error)
|
||||
# if average_palette_error > worst_average_palette_error:
|
||||
# worst_average_palette_error = average_palette_error
|
||||
# split_palette_idx = idx
|
||||
#
|
||||
# print("Picked %d with avg error %f" % (split_palette_idx, worst_average_palette_error))
|
||||
# # TODO: split off lines with largest error
|
||||
#
|
||||
# palette_line_errors = self._palette_line_errors[
|
||||
# self._palette_lines[split_palette_idx]]
|
||||
#
|
||||
# print(sorted(
|
||||
# list(zip(palette_line_errors, self._palette_lines[
|
||||
# split_palette_idx])), reverse=True))
|
||||
# best_palette_lines = [v for k, v in sorted(
|
||||
# list(zip(palette_line_errors, self._palette_lines[
|
||||
# split_palette_idx])))]
|
||||
# num_max_lines = len(self._palette_lines[split_palette_idx])
|
||||
#
|
||||
# self._palette_lines[split_palette_idx] = best_palette_lines[
|
||||
# :num_max_lines // 2]
|
||||
# # Move worst half to new palette
|
||||
# self._palette_lines[palette_idx] = best_palette_lines[
|
||||
# num_max_lines // 2:]
|
||||
|
||||
|
||||
def main():
|
||||
@ -400,7 +421,7 @@ def main():
|
||||
|
||||
# TODO: flags
|
||||
penalty = 1 # 1e18 # TODO: is this needed any more?
|
||||
iterations = 200
|
||||
iterations = 10 # 20
|
||||
|
||||
pygame.init()
|
||||
# TODO: for some reason I need to execute this twice - the first time
|
||||
@ -417,6 +438,7 @@ def main():
|
||||
rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs,
|
||||
rgb24_to_cam16ucs=rgb24_to_cam16ucs)
|
||||
|
||||
seq = 0
|
||||
for (new_total_image_error, output_4bit, line_to_palette,
|
||||
palettes_rgb12_iigs, palettes_linear_rgb) in cluster_palette.iterate(
|
||||
penalty, iterations):
|
||||
@ -463,20 +485,24 @@ def main():
|
||||
np.asarray(out_image).transpose((1, 0, 2))) # flip y/x axes
|
||||
canvas.blit(surface, (0, 0))
|
||||
pygame.display.flip()
|
||||
|
||||
seq += 1
|
||||
# Save Double hi-res image
|
||||
outfile = os.path.join(
|
||||
os.path.splitext(args.output)[0] + "-%d-preview.png" % seq)
|
||||
out_image.save(outfile, "PNG")
|
||||
screen.pack()
|
||||
# with open(args.output, "wb") as f:
|
||||
# f.write(bytes(screen.aux))
|
||||
# f.write(bytes(screen.main))
|
||||
with open(args.output, "wb") as f:
|
||||
f.write(bytes(screen.memory))
|
||||
|
||||
# print((palettes_rgb * 255).astype(np.uint8))
|
||||
unique_colours = np.unique(
|
||||
palettes_rgb12_iigs.reshape(-1, 3), axis=0).shape[0]
|
||||
print("%d unique colours" % unique_colours)
|
||||
|
||||
# Save Double hi-res image
|
||||
outfile = os.path.join(os.path.splitext(args.output)[0] + "-preview.png")
|
||||
out_image.save(outfile, "PNG")
|
||||
screen.pack()
|
||||
# with open(args.output, "wb") as f:
|
||||
# f.write(bytes(screen.aux))
|
||||
# f.write(bytes(screen.main))
|
||||
with open(args.output, "wb") as f:
|
||||
f.write(bytes(screen.memory))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user