diff --git a/convert.py b/convert.py index 518d742..d96b6f2 100644 --- a/convert.py +++ b/convert.py @@ -1,7 +1,9 @@ """Image converter to Apple II Double Hi-Res format.""" import argparse +from collections import defaultdict import os.path +import random from typing import Tuple, List from PIL import Image @@ -72,8 +74,13 @@ class ClusterPalette: # Delta applied to palette split in previous iteration self._palette_mutate_delta = (0, 0) + self._palette_lines = defaultdict(list) + for i, lh in enumerate(self._palette_splits): + l, h = lh + self._palette_lines[i].extend(list(range(l, h))) + def _image_colours_cam(self, image: Image): - colours_rgb = np.asarray(image).reshape((-1, 3)) + colours_rgb = np.asarray(image) # .reshape((-1, 3)) with colour.utilities.suppress_warnings(colour_usage_warnings=True): colours_cam = colour.convert(colours_rgb, "RGB", "CAM16UCS").astype(np.float32) @@ -109,11 +116,18 @@ class ClusterPalette: palettes_linear_rgb = colour.convert( palettes_cam, "CAM16UCS", "RGB").astype(np.float32) - output_4bit, line_to_palette, total_image_error = \ + output_4bit, line_to_palette, total_image_error, palette_line_errors = \ dither_pyx.dither_shr( self._image_rgb, palettes_cam, palettes_linear_rgb, self._rgb24_to_cam16ucs, float(penalty)) + palette_lines = defaultdict(list) + for line, palette in enumerate(line_to_palette): + palette_lines[palette].append(line) + self._palette_lines = palette_lines + + self._palette_line_errors = palette_line_errors + return (output_4bit, line_to_palette, palettes_linear_rgb, total_image_error) @@ -121,7 +135,7 @@ class ClusterPalette: iterations_since_improvement = 0 total_image_error = 1e9 last_good_splits = self._palette_splits - + self._fit_global_palette() while iterations_since_improvement < max_iterations: # print("Iterations %d" % iterations_since_improvement) new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = ( @@ -136,6 +150,7 @@ class ClusterPalette: self._reassign_unused_palettes(line_to_palette, last_good_splits) + print(total_image_error, new_total_image_error) if new_total_image_error >= total_image_error: iterations_since_improvement += 1 continue @@ -177,13 +192,15 @@ class ClusterPalette: # Compute a new 16-colour global palette for the entire image, # used as the starting center positions for k-means clustering of the # individual palettes - self._fit_global_palette() + # self._fit_global_palette() self._mutate_palette_splits() for palette_idx in range(16): - palette_lower, palette_upper = self._palette_splits[palette_idx] - palette_pixels = self._colours_cam[ - palette_lower * 320:palette_upper * 320, :] + # print(palette_idx, self._palette_lines[palette_idx]) + # palette_lower, palette_upper = self._palette_splits[palette_idx] + palette_pixels = ( + self._colours_cam[ + self._palette_lines[palette_idx], :, :].reshape(-1, 3)) palettes_rgb12_iigs, palette_error = \ dither_pyx.k_means_with_fixed_centroids( @@ -220,7 +237,7 @@ class ClusterPalette: same colours.""" clusters = cluster.MiniBatchKMeans(n_clusters=16, max_iter=10000) - clusters.fit_predict(self._colours_cam) + clusters.fit_predict(self._colours_cam.reshape(-1, 3)) # Dict of {palette idx : frequency count} palette_freq = {idx: 0 for idx in range(16)} @@ -286,25 +303,39 @@ class ClusterPalette: if palette_used: continue print("Reassigning palette %d" % palette_idx) - max_width = 0 + + worst_average_palette_error = 0 split_palette_idx = -1 idx = 0 - for lower, upper in last_good_splits: - width = upper - lower - if width > max_width: + for idx, lines in self._palette_lines.items(): + if len(lines) < 10: + continue + average_palette_error = np.sum(self._palette_line_errors[ + lines]) / len(lines) + print(idx, average_palette_error) + if average_palette_error > worst_average_palette_error: + worst_average_palette_error = average_palette_error split_palette_idx = idx - idx += 1 - lower, upper = last_good_splits[split_palette_idx] - if upper - lower > 20: - mid = (lower + upper) // 2 - self._palette_splits[split_palette_idx] = ( - lower, mid - 1) - self._palette_splits[palette_idx] = (mid, upper) - else: - lower = np.random.randint(0, 199) - upper = np.random.randint(lower + 1, 200) - self._palette_splits[palette_idx] = (lower, upper) + print("Picked %d with avg error %f" % (split_palette_idx, worst_average_palette_error)) + # TODO: split off lines with largest error + + palette_line_errors = self._palette_line_errors[ + self._palette_lines[split_palette_idx]] + + print(sorted( + list(zip(palette_line_errors, self._palette_lines[ + split_palette_idx])), reverse=True)) + best_palette_lines = [v for k, v in sorted( + list(zip(palette_line_errors, self._palette_lines[ + split_palette_idx])))] + num_max_lines = len(self._palette_lines[split_palette_idx]) + + self._palette_lines[split_palette_idx] = best_palette_lines[ + :num_max_lines // 2] + # Move worst half to new palette + self._palette_lines[palette_idx] = best_palette_lines[ + num_max_lines // 2:] def main(): diff --git a/dither.pyx b/dither.pyx index 59b11b3..22f0e46 100644 --- a/dither.pyx +++ b/dither.pyx @@ -352,6 +352,8 @@ def dither_shr( cdef float[:, ::1] line_cam = np.zeros((320, 3), dtype=np.float32) cdef int[::1] line_to_palette = np.zeros(200, dtype=np.int32) + cdef double[::1] palette_line_errors = np.zeros(200, dtype=np.float64) + cdef PaletteSelection palette_line best_palette = -1 total_image_error = 0.0 @@ -360,7 +362,10 @@ def dither_shr( line_cam[x, :] = convert_rgb_to_cam16ucs( rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2]) - best_palette = best_palette_for_line(line_cam, palettes_cam, best_palette, penalty) + palette_line = best_palette_for_line(line_cam, palettes_cam, best_palette, penalty) + best_palette = palette_line.palette_idx + palette_line_errors[y] = palette_line.total_error + palette_rgb = palettes_rgb[best_palette, :, :] palette_cam = palettes_cam[best_palette, :, :] line_to_palette[y] = best_palette @@ -449,12 +454,16 @@ def dither_shr( # working_image[y + 2, x + 2, i] + quant_error * (1 / 48), # 0, 1) - return np.array(output_4bit, dtype=np.uint8), line_to_palette, total_image_error + return np.array(output_4bit, dtype=np.uint8), line_to_palette, total_image_error, np.array(palette_line_errors, dtype=np.float64) +cdef struct PaletteSelection: + int palette_idx + double total_error + @cython.boundscheck(False) @cython.wraparound(False) -cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palettes_cam, int last_palette_idx, float last_penalty) nogil: +cdef PaletteSelection best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palettes_cam, int last_palette_idx, float last_penalty) nogil: cdef int palette_idx, best_palette_idx, palette_entry_idx, pixel_idx cdef double best_total_dist, total_dist, best_pixel_dist, pixel_dist cdef float[:, ::1] palette_cam @@ -479,7 +488,11 @@ cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palette if total_dist < best_total_dist: best_total_dist = total_dist best_palette_idx = palette_idx - return best_palette_idx + + cdef PaletteSelection res + res.palette_idx = best_palette_idx + res.total_error = best_total_dist + return res @cython.boundscheck(False)