Tidy a bit and add a --save-intermediate flag

This commit is contained in:
kris 2022-07-18 10:00:19 +01:00
parent 1ffb2c9110
commit 3196369b7d
2 changed files with 46 additions and 28 deletions

View File

@ -91,6 +91,11 @@ def main():
default=False, help='Whether to output the final image quality score ' default=False, help='Whether to output the final image quality score '
'(default: False)' '(default: False)'
) )
shr_parser.add_argument(
'--save-intermediate', action=argparse.BooleanOptionalAction,
default=False, help='Whether to save each intermediate iteration, '
'or just the final image (default: False)'
)
shr_parser.set_defaults(func=convert_shr) shr_parser.set_defaults(func=convert_shr)
args = parser.parse_args() args = parser.parse_args()
args.func(args) args.func(args)

View File

@ -41,26 +41,28 @@ class ClusterPalette:
# Preprocessed source image in CAM16UCS colour space # Preprocessed source image in CAM16UCS colour space
self._colours_cam = self._image_colours_cam(self._image_rgb) self._colours_cam = self._image_colours_cam(self._image_rgb)
# How many image colours to fix identically across all 16 SHR
# palettes. These are taken to be the most prevalent colours from
# _global_palette.
self._fixed_colours = fixed_colours
# We fit a 16-colour palette against the entire image which is used # We fit a 16-colour palette against the entire image which is used
# as starting values for fitting the reserved colours in the 16 SHR # as starting values for fitting the reserved colours in the 16 SHR
# palettes. # palettes.
self._global_palette = np.empty((16, 3), dtype=np.uint8) self._global_palette = np.empty((16, 3), dtype=np.uint8)
# How many image colours to fix identically across all 16 SHR
# palettes. These are taken to be the most prevalent colours from
# _global_palette.
self._fixed_colours = fixed_colours
# 16 SHR palettes each of 16 colours, in CAM16UCS colour space # 16 SHR palettes each of 16 colours, in CAM16UCS colour space
self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32) self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
# 16 SHR palettes each of 16 colours, in //gs 4-bit RGB colour space # 16 SHR palettes each of 16 colours, in //gs 4-bit RGB colour space
self._palettes_rgb = np.empty((16, 16, 3), dtype=np.uint8) self._palettes_rgb = np.empty((16, 16, 3), dtype=np.uint8)
# defaultdict(list) mapping palette index to lines using this palette # defaultdict(list) mapping palette index to the lines that use this
# palette
self._palette_lines = self._init_palette_lines() self._palette_lines = self._init_palette_lines()
def _image_colours_cam(self, image: Image): @staticmethod
def _image_colours_cam(image: Image):
colours_rgb = np.asarray(image) # .reshape((-1, 3)) colours_rgb = np.asarray(image) # .reshape((-1, 3))
with colour.utilities.suppress_warnings(colour_usage_warnings=True): with colour.utilities.suppress_warnings(colour_usage_warnings=True):
colours_cam = colour.convert(colours_rgb, "RGB", colours_cam = colour.convert(colours_rgb, "RGB",
@ -84,7 +86,8 @@ class ClusterPalette:
palette_lines[i].extend(list(range(l, h))) palette_lines[i].extend(list(range(l, h)))
return palette_lines return palette_lines
def _equal_palette_splits(self, palette_height=35): @staticmethod
def _equal_palette_splits(palette_height=35):
# The 16 palettes are striped across consecutive (overlapping) line # The 16 palettes are striped across consecutive (overlapping) line
# ranges. Since nearby lines tend to have similar colours, this has # ranges. Since nearby lines tend to have similar colours, this has
# the effect of smoothing out the colour transitions across palettes. # the effect of smoothing out the colour transitions across palettes.
@ -202,7 +205,9 @@ class ClusterPalette:
when dithering. i.e. they would reduce the overall image quality. when dithering. i.e. they would reduce the overall image quality.
The current (locally) best palettes are returned and can be applied The current (locally) best palettes are returned and can be applied
using accept_palettes(). using accept_palettes()
XXX update
""" """
new_palettes_cam = np.empty_like(self._palettes_cam) new_palettes_cam = np.empty_like(self._palettes_cam)
new_palettes_rgb12_iigs = np.empty_like(self._palettes_rgb) new_palettes_rgb12_iigs = np.empty_like(self._palettes_rgb)
@ -214,8 +219,8 @@ class ClusterPalette:
for palette_idx in range(16): for palette_idx in range(16):
palette_pixels = ( palette_pixels = (
self._colours_cam[ self._colours_cam[self._palette_lines[
self._palette_lines[palette_idx], :, :].reshape(-1, 3)) palette_idx], :, :].reshape(-1, 3))
# Fix reserved colours from the global palette. # Fix reserved colours from the global palette.
initial_centroids = np.copy(self._global_palette) initial_centroids = np.copy(self._global_palette)
@ -244,14 +249,14 @@ class ClusterPalette:
*np.unique(pixels_rgb_iigs, return_counts=True, axis=0))), *np.unique(pixels_rgb_iigs, return_counts=True, axis=0))),
key=lambda kv: kv[1], reverse=True) key=lambda kv: kv[1], reverse=True)
fixed_colours = self._fixed_colours fixed_colours = self._fixed_colours
for colour, freq in most_frequent_colours: for palette_colour, freq in most_frequent_colours:
if (freq < (palette_pixels.shape[0] * if (freq < (palette_pixels.shape[0] *
fixed_colour_fraction_threshold)) or ( fixed_colour_fraction_threshold)) or (
fixed_colours == 16): fixed_colours == 16):
break break
if tuple(colour) not in seen_colours: if tuple(palette_colour) not in seen_colours:
seen_colours.add(tuple(colour)) seen_colours.add(tuple(palette_colour))
initial_centroids[fixed_colours, :] = colour initial_centroids[fixed_colours, :] = palette_colour
fixed_colours += 1 fixed_colours += 1
palette_rgb12_iigs = dither_shr_pyx.k_means_with_fixed_centroids( palette_rgb12_iigs = dither_shr_pyx.k_means_with_fixed_centroids(
@ -300,7 +305,8 @@ class ClusterPalette:
clusters.cluster_centers_[frequency_order].astype( clusters.cluster_centers_[frequency_order].astype(
np.float32))) np.float32)))
def _fill_short_palette(self, palette_iigs_rgb, most_frequent_colours): @staticmethod
def _fill_short_palette(palette_iigs_rgb, most_frequent_colours):
"""Fill out the palette to 16 unique entries.""" """Fill out the palette to 16 unique entries."""
# We want to maintain order of insertion so that we respect the # We want to maintain order of insertion so that we respect the
@ -313,11 +319,10 @@ class ClusterPalette:
return palette_iigs_rgb return palette_iigs_rgb
# Add most frequent image colours that are not yet in the palette # Add most frequent image colours that are not yet in the palette
for colour, freq in most_frequent_colours: for palette_colour, freq in most_frequent_colours:
if tuple(colour) in palette_set: if tuple(palette_colour) in palette_set:
continue continue
palette_set[tuple(colour)] = True palette_set[tuple(palette_colour)] = True
# print("Added freq %d" % freq)
if len(palette_set) == 16: if len(palette_set) == 16:
break break
@ -348,7 +353,6 @@ class ClusterPalette:
for palette_idx, palette_used in enumerate(palettes_used): for palette_idx, palette_used in enumerate(palettes_used):
if palette_used: if palette_used:
continue continue
# print("Reassigning palette %d" % palette_idx)
# TODO: also remove from old entry # TODO: also remove from old entry
worst_line = best_palette_lines.pop() worst_line = best_palette_lines.pop()
@ -373,7 +377,7 @@ def convert(screen, rgb: np.ndarray, args):
canvas = pygame.display.set_mode((640, 400)) canvas = pygame.display.set_mode((640, 400))
canvas.fill((0, 0, 0)) canvas.fill((0, 0, 0))
pygame.display.set_caption("][-Pix image preview") pygame.display.set_caption("][-Pix image preview")
pygame.event.pump() pygame.event.pump() # Update caption
pygame.display.flip() pygame.display.flip()
total_image_error = None total_image_error = None
@ -382,6 +386,8 @@ def convert(screen, rgb: np.ndarray, args):
rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs, rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs,
rgb24_to_cam16ucs=rgb24_to_cam16ucs) rgb24_to_cam16ucs=rgb24_to_cam16ucs)
output_base, output_ext = os.path.splitext(args.output)
seq = 0 seq = 0
for ( for (
new_total_image_error, output_4bit, line_to_palette, new_total_image_error, output_4bit, line_to_palette,
@ -418,7 +424,7 @@ def convert(screen, rgb: np.ndarray, args):
canvas.blit(surface, (0, 0)) canvas.blit(surface, (0, 0))
pygame.display.set_caption("][-Pix image preview [Iteration %d]" pygame.display.set_caption("][-Pix image preview [Iteration %d]"
% seq) % seq)
pygame.event.pump() pygame.event.pump() # Update caption
pygame.display.flip() pygame.display.flip()
unique_colours = np.unique( unique_colours = np.unique(
@ -426,16 +432,23 @@ def convert(screen, rgb: np.ndarray, args):
if args.verbose: if args.verbose:
print("%d unique colours" % unique_colours) print("%d unique colours" % unique_colours)
seq += 1
if args.save_preview: if args.save_preview:
# Save super hi-res image # Save super hi-res image
outfile = os.path.join( if args.save_intermediate:
os.path.splitext(args.output)[0] + "-preview.png") outfile = "%s-%d-preview.png" % (output_base, seq)
else:
outfile = "%s-preview.png" % output_base
out_image.save(outfile, "PNG") out_image.save(outfile, "PNG")
screen.pack() screen.pack()
with open(args.output + ".%d" % seq, "wb") as f:
if args.save_intermediate:
outfile = "%s-%d%s" % (output_base, seq, output_ext)
else:
outfile = "%s%s" % (output_base, output_ext)
with open(outfile, "wb") as f:
f.write(bytes(screen.memory)) f.write(bytes(screen.memory))
seq += 1
if args.show_final_score: if args.show_final_score:
print("FINAL_SCORE:", total_image_error) print("FINAL_SCORE:", total_image_error)