Initial attempt at fitting palettes to arbitrary lines instead of line ranges.

Works OK but isn't converging as well as I hoped.
This commit is contained in:
kris 2021-11-24 10:41:25 +00:00
parent 50c71d3a35
commit de8a303de2
2 changed files with 71 additions and 27 deletions

View File

@ -1,7 +1,9 @@
"""Image converter to Apple II Double Hi-Res format."""
import argparse
from collections import defaultdict
import os.path
import random
from typing import Tuple, List
from PIL import Image
@ -72,8 +74,13 @@ class ClusterPalette:
# Delta applied to palette split in previous iteration
self._palette_mutate_delta = (0, 0)
self._palette_lines = defaultdict(list)
for i, lh in enumerate(self._palette_splits):
l, h = lh
self._palette_lines[i].extend(list(range(l, h)))
def _image_colours_cam(self, image: Image):
colours_rgb = np.asarray(image).reshape((-1, 3))
colours_rgb = np.asarray(image) # .reshape((-1, 3))
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
colours_cam = colour.convert(colours_rgb, "RGB",
"CAM16UCS").astype(np.float32)
@ -109,11 +116,18 @@ class ClusterPalette:
palettes_linear_rgb = colour.convert(
palettes_cam, "CAM16UCS", "RGB").astype(np.float32)
output_4bit, line_to_palette, total_image_error = \
output_4bit, line_to_palette, total_image_error, palette_line_errors = \
dither_pyx.dither_shr(
self._image_rgb, palettes_cam, palettes_linear_rgb,
self._rgb24_to_cam16ucs, float(penalty))
palette_lines = defaultdict(list)
for line, palette in enumerate(line_to_palette):
palette_lines[palette].append(line)
self._palette_lines = palette_lines
self._palette_line_errors = palette_line_errors
return (output_4bit, line_to_palette, palettes_linear_rgb,
total_image_error)
@ -121,7 +135,7 @@ class ClusterPalette:
iterations_since_improvement = 0
total_image_error = 1e9
last_good_splits = self._palette_splits
self._fit_global_palette()
while iterations_since_improvement < max_iterations:
# print("Iterations %d" % iterations_since_improvement)
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
@ -136,6 +150,7 @@ class ClusterPalette:
self._reassign_unused_palettes(line_to_palette,
last_good_splits)
print(total_image_error, new_total_image_error)
if new_total_image_error >= total_image_error:
iterations_since_improvement += 1
continue
@ -177,13 +192,15 @@ class ClusterPalette:
# Compute a new 16-colour global palette for the entire image,
# used as the starting center positions for k-means clustering of the
# individual palettes
self._fit_global_palette()
# self._fit_global_palette()
self._mutate_palette_splits()
for palette_idx in range(16):
palette_lower, palette_upper = self._palette_splits[palette_idx]
palette_pixels = self._colours_cam[
palette_lower * 320:palette_upper * 320, :]
# print(palette_idx, self._palette_lines[palette_idx])
# palette_lower, palette_upper = self._palette_splits[palette_idx]
palette_pixels = (
self._colours_cam[
self._palette_lines[palette_idx], :, :].reshape(-1, 3))
palettes_rgb12_iigs, palette_error = \
dither_pyx.k_means_with_fixed_centroids(
@ -220,7 +237,7 @@ class ClusterPalette:
same colours."""
clusters = cluster.MiniBatchKMeans(n_clusters=16, max_iter=10000)
clusters.fit_predict(self._colours_cam)
clusters.fit_predict(self._colours_cam.reshape(-1, 3))
# Dict of {palette idx : frequency count}
palette_freq = {idx: 0 for idx in range(16)}
@ -286,25 +303,39 @@ class ClusterPalette:
if palette_used:
continue
print("Reassigning palette %d" % palette_idx)
max_width = 0
worst_average_palette_error = 0
split_palette_idx = -1
idx = 0
for lower, upper in last_good_splits:
width = upper - lower
if width > max_width:
for idx, lines in self._palette_lines.items():
if len(lines) < 10:
continue
average_palette_error = np.sum(self._palette_line_errors[
lines]) / len(lines)
print(idx, average_palette_error)
if average_palette_error > worst_average_palette_error:
worst_average_palette_error = average_palette_error
split_palette_idx = idx
idx += 1
lower, upper = last_good_splits[split_palette_idx]
if upper - lower > 20:
mid = (lower + upper) // 2
self._palette_splits[split_palette_idx] = (
lower, mid - 1)
self._palette_splits[palette_idx] = (mid, upper)
else:
lower = np.random.randint(0, 199)
upper = np.random.randint(lower + 1, 200)
self._palette_splits[palette_idx] = (lower, upper)
print("Picked %d with avg error %f" % (split_palette_idx, worst_average_palette_error))
# TODO: split off lines with largest error
palette_line_errors = self._palette_line_errors[
self._palette_lines[split_palette_idx]]
print(sorted(
list(zip(palette_line_errors, self._palette_lines[
split_palette_idx])), reverse=True))
best_palette_lines = [v for k, v in sorted(
list(zip(palette_line_errors, self._palette_lines[
split_palette_idx])))]
num_max_lines = len(self._palette_lines[split_palette_idx])
self._palette_lines[split_palette_idx] = best_palette_lines[
:num_max_lines // 2]
# Move worst half to new palette
self._palette_lines[palette_idx] = best_palette_lines[
num_max_lines // 2:]
def main():

View File

@ -352,6 +352,8 @@ def dither_shr(
cdef float[:, ::1] line_cam = np.zeros((320, 3), dtype=np.float32)
cdef int[::1] line_to_palette = np.zeros(200, dtype=np.int32)
cdef double[::1] palette_line_errors = np.zeros(200, dtype=np.float64)
cdef PaletteSelection palette_line
best_palette = -1
total_image_error = 0.0
@ -360,7 +362,10 @@ def dither_shr(
line_cam[x, :] = convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2])
best_palette = best_palette_for_line(line_cam, palettes_cam, best_palette, penalty)
palette_line = best_palette_for_line(line_cam, palettes_cam, best_palette, penalty)
best_palette = palette_line.palette_idx
palette_line_errors[y] = palette_line.total_error
palette_rgb = palettes_rgb[best_palette, :, :]
palette_cam = palettes_cam[best_palette, :, :]
line_to_palette[y] = best_palette
@ -449,12 +454,16 @@ def dither_shr(
# working_image[y + 2, x + 2, i] + quant_error * (1 / 48),
# 0, 1)
return np.array(output_4bit, dtype=np.uint8), line_to_palette, total_image_error
return np.array(output_4bit, dtype=np.uint8), line_to_palette, total_image_error, np.array(palette_line_errors, dtype=np.float64)
# Result returned by best_palette_for_line(): which palette was selected for a
# line, together with the total distance recorded for that selection.
cdef struct PaletteSelection:
# Index of the selected palette (filled from best_palette_idx).
int palette_idx
# Smallest total distance found across the candidate palettes (filled from
# best_total_dist); dither_shr stores it per line in palette_line_errors.
double total_error
@cython.boundscheck(False)
@cython.wraparound(False)
cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palettes_cam, int last_palette_idx, float last_penalty) nogil:
cdef PaletteSelection best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palettes_cam, int last_palette_idx, float last_penalty) nogil:
cdef int palette_idx, best_palette_idx, palette_entry_idx, pixel_idx
cdef double best_total_dist, total_dist, best_pixel_dist, pixel_dist
cdef float[:, ::1] palette_cam
@ -479,7 +488,11 @@ cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palette
if total_dist < best_total_dist:
best_total_dist = total_dist
best_palette_idx = palette_idx
return best_palette_idx
cdef PaletteSelection res
res.palette_idx = best_palette_idx
res.total_error = best_total_dist
return res
@cython.boundscheck(False)