Tidy a bit and add a --save-intermediate flag

This commit is contained in:
kris 2022-07-18 10:00:19 +01:00
parent 1ffb2c9110
commit 3196369b7d
2 changed files with 46 additions and 28 deletions

View File

@ -91,6 +91,11 @@ def main():
default=False, help='Whether to output the final image quality score '
'(default: False)'
)
shr_parser.add_argument(
'--save-intermediate', action=argparse.BooleanOptionalAction,
default=False, help='Whether to save each intermediate iteration, '
'or just the final image (default: False)'
)
shr_parser.set_defaults(func=convert_shr)
args = parser.parse_args()
args.func(args)

View File

@ -41,26 +41,28 @@ class ClusterPalette:
# Preprocessed source image in CAM16UCS colour space
self._colours_cam = self._image_colours_cam(self._image_rgb)
# How many image colours to fix identically across all 16 SHR
# palettes. These are taken to be the most prevalent colours from
# _global_palette.
self._fixed_colours = fixed_colours
# We fit a 16-colour palette against the entire image which is used
# as starting values for fitting the reserved colours in the 16 SHR
# palettes.
self._global_palette = np.empty((16, 3), dtype=np.uint8)
# How many image colours to fix identically across all 16 SHR
# palettes. These are taken to be the most prevalent colours from
# _global_palette.
self._fixed_colours = fixed_colours
# 16 SHR palettes each of 16 colours, in CAM16UCS colour space
self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
# 16 SHR palettes each of 16 colours, in //gs 4-bit RGB colour space
self._palettes_rgb = np.empty((16, 16, 3), dtype=np.uint8)
# defaultdict(list) mapping palette index to lines using this palette
# defaultdict(list) mapping palette index to the lines that use this
# palette
self._palette_lines = self._init_palette_lines()
def _image_colours_cam(self, image: Image):
@staticmethod
def _image_colours_cam(image: Image):
colours_rgb = np.asarray(image) # .reshape((-1, 3))
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
colours_cam = colour.convert(colours_rgb, "RGB",
@ -84,7 +86,8 @@ class ClusterPalette:
palette_lines[i].extend(list(range(l, h)))
return palette_lines
def _equal_palette_splits(self, palette_height=35):
@staticmethod
def _equal_palette_splits(palette_height=35):
# The 16 palettes are striped across consecutive (overlapping) line
# ranges. Since nearby lines tend to have similar colours, this has
# the effect of smoothing out the colour transitions across palettes.
@ -202,7 +205,9 @@ class ClusterPalette:
when dithering. i.e. they would reduce the overall image quality.
The current (locally) best palettes are returned and can be applied
using accept_palettes().
using accept_palettes()
XXX update
"""
new_palettes_cam = np.empty_like(self._palettes_cam)
new_palettes_rgb12_iigs = np.empty_like(self._palettes_rgb)
@ -214,8 +219,8 @@ class ClusterPalette:
for palette_idx in range(16):
palette_pixels = (
self._colours_cam[
self._palette_lines[palette_idx], :, :].reshape(-1, 3))
self._colours_cam[self._palette_lines[
palette_idx], :, :].reshape(-1, 3))
# Fix reserved colours from the global palette.
initial_centroids = np.copy(self._global_palette)
@ -244,14 +249,14 @@ class ClusterPalette:
*np.unique(pixels_rgb_iigs, return_counts=True, axis=0))),
key=lambda kv: kv[1], reverse=True)
fixed_colours = self._fixed_colours
for colour, freq in most_frequent_colours:
for palette_colour, freq in most_frequent_colours:
if (freq < (palette_pixels.shape[0] *
fixed_colour_fraction_threshold)) or (
fixed_colours == 16):
break
if tuple(colour) not in seen_colours:
seen_colours.add(tuple(colour))
initial_centroids[fixed_colours, :] = colour
if tuple(palette_colour) not in seen_colours:
seen_colours.add(tuple(palette_colour))
initial_centroids[fixed_colours, :] = palette_colour
fixed_colours += 1
palette_rgb12_iigs = dither_shr_pyx.k_means_with_fixed_centroids(
@ -300,7 +305,8 @@ class ClusterPalette:
clusters.cluster_centers_[frequency_order].astype(
np.float32)))
def _fill_short_palette(self, palette_iigs_rgb, most_frequent_colours):
@staticmethod
def _fill_short_palette(palette_iigs_rgb, most_frequent_colours):
"""Fill out the palette to 16 unique entries."""
# We want to maintain order of insertion so that we respect the
@ -313,11 +319,10 @@ class ClusterPalette:
return palette_iigs_rgb
# Add most frequent image colours that are not yet in the palette
for colour, freq in most_frequent_colours:
if tuple(colour) in palette_set:
for palette_colour, freq in most_frequent_colours:
if tuple(palette_colour) in palette_set:
continue
palette_set[tuple(colour)] = True
# print("Added freq %d" % freq)
palette_set[tuple(palette_colour)] = True
if len(palette_set) == 16:
break
@ -348,7 +353,6 @@ class ClusterPalette:
for palette_idx, palette_used in enumerate(palettes_used):
if palette_used:
continue
# print("Reassigning palette %d" % palette_idx)
# TODO: also remove from old entry
worst_line = best_palette_lines.pop()
@ -373,7 +377,7 @@ def convert(screen, rgb: np.ndarray, args):
canvas = pygame.display.set_mode((640, 400))
canvas.fill((0, 0, 0))
pygame.display.set_caption("][-Pix image preview")
pygame.event.pump()
pygame.event.pump() # Update caption
pygame.display.flip()
total_image_error = None
@ -382,6 +386,8 @@ def convert(screen, rgb: np.ndarray, args):
rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs,
rgb24_to_cam16ucs=rgb24_to_cam16ucs)
output_base, output_ext = os.path.splitext(args.output)
seq = 0
for (
new_total_image_error, output_4bit, line_to_palette,
@ -418,7 +424,7 @@ def convert(screen, rgb: np.ndarray, args):
canvas.blit(surface, (0, 0))
pygame.display.set_caption("][-Pix image preview [Iteration %d]"
% seq)
pygame.event.pump()
pygame.event.pump() # Update caption
pygame.display.flip()
unique_colours = np.unique(
@ -426,16 +432,23 @@ def convert(screen, rgb: np.ndarray, args):
if args.verbose:
print("%d unique colours" % unique_colours)
seq += 1
if args.save_preview:
# Save super hi-res image
outfile = os.path.join(
os.path.splitext(args.output)[0] + "-preview.png")
if args.save_intermediate:
outfile = "%s-%d-preview.png" % (output_base, seq)
else:
outfile = "%s-preview.png" % output_base
out_image.save(outfile, "PNG")
screen.pack()
with open(args.output + ".%d" % seq, "wb") as f:
if args.save_intermediate:
outfile = "%s-%d%s" % (output_base, seq, output_ext)
else:
outfile = "%s%s" % (output_base, output_ext)
with open(outfile, "wb") as f:
f.write(bytes(screen.memory))
seq += 1
if args.show_final_score:
print("FINAL_SCORE:", total_image_error)