Checkpoint

- Repeatedly refit palettes since k-means is only a local
  optimization.  This can produce incremental improvements in image
  quality but may also overfit, especially on complex images.
- use pygame to render incremental images
- Fix off-by-one in palette striping
- When fitting palettes, first cluster a 16-colour palette for the
  entire image and use this to initialize the centroids for individual
  palettes.  This improves quality when fitting images with large
  blocks of colour, since they will otherwise be fit separately and
  may have slight differences.  With a global initializer these will
  tend to be the same.  This also improves performance.
This commit is contained in:
kris 2021-11-16 11:21:53 +00:00
parent b363d60754
commit 10c829906b
1 changed files with 129 additions and 127 deletions

View File

@ -6,6 +6,7 @@ import os.path
import time
import collections
import random
import pygame
import colour
from PIL import Image
@ -14,6 +15,7 @@ from pyclustering.cluster.kmedians import kmedians
from pyclustering.cluster.kmeans import kmeans
from pyclustering.utils.metric import distance_metric, type_metric
from pyclustering.cluster.center_initializer import kmeans_plusplus_initializer
from sklearn import cluster
import dither as dither_pyx
import dither_pattern
@ -26,117 +28,103 @@ import screen as screen_py
# - support LR/DLR
# - support HGR
def cluster_palette(image: Image):
# line_to_palette = {}
# shuffle_lines = liprint(st(range(200))
# random.shuffle(shuffle_lines)
# for idx, line in enumerate(shuffle_lines):
# line_to_palette[line] = idx % 16
# for line in range(200):
# if line % 3 == 0:
# line_to_palette[line] = int(line / (200 / 16))
# elif line % 3 == 1:
# line_to_palette[line] = np.clip(int(line / (200 / 16)) + 1, 0, 15)
# else:
# line_to_palette[line] = np.clip(int(line / (200 / 16)) + 2, 0, 15)
# for line in range(200):
# if line % 3 == 0:
# line_to_palette[line] = int(line / (200 / 16))
# elif line % 3 == 1:
# line_to_palette[line] = np.clip(int(line / (200 / 16)) + 1, 0, 15)
# else:
# line_to_palette[line] = np.clip(int(line / (200 / 16)) + 2, 0, 15)
colours_rgb = np.asarray(image).reshape((-1, 3))
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
colours_cam = colour.convert(colours_rgb, "RGB",
"CAM16UCS").astype(np.float32)
palettes_rgb = np.empty((16, 16, 3), dtype=np.float32)
palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
for palette_idx in range(16):
print("Fitting palette %d" % palette_idx)
p_lower = max(palette_idx - 2, 0)
p_upper = min(palette_idx + 2, 16)
palette_pixels = colours_cam[
int(p_lower * (200 / 16)) * 320:int(p_upper * (
200 / 16)) * 320, :]
# kmeans = KMeans(n_clusters=16, max_iter=10000)
# kmeans.fit_predict(palette_pixels)
# palettes_cam[palette_idx] = kmeans.cluster_centers_
# fixed_centroids = None
# print(np.array(line_colours), fixed_centroids)
# palettes_cam[palette_idx] = dither_pyx.k_means_with_fixed_centroids(
# 16, palette_pixels, fixed_centroids=fixed_centroids,
# tolerance=1e-6)
best_wce = 1e9
best_medians = None
for i in range(500):
# print(i)
initial_centers = kmeans_plusplus_initializer(
palette_pixels, 16).initialize()
kmedians_instance = kmedians(
palette_pixels, initial_centers, tolerance=0.1, itermax=100,
metric=distance_metric(type_metric.MANHATTAN))
kmedians_instance.process()
if kmedians_instance.get_total_wce() < best_wce:
best_wce = kmedians_instance.get_total_wce()
print(i, best_wce)
best_medians = kmedians_instance
print("Best %f" % best_wce)
palettes_cam[palette_idx, :, :] = np.array(
best_medians.get_medians()).astype(np.float32)
# palette_colours = collections.defaultdict(list)
# for line in range(200):
# palette = line_to_palette[line]
# palette_colours[palette].extend(
# colours_cam[line * 320:(line + 1) * 320])
# For each line grouping, find big palette entries with minimal total
# distance
# palette_cam = None
# for palette_idx in range(16):
# line_colours = palette_colours[palette_idx]
# #if palette_idx < 15:
# # line_colours += palette_colours[palette_idx + 1]
# # if palette_idx < 14:
# # line_colours += palette_colours[palette_idx + 2]
# # if palette_idx > 0:
# # fixed_centroids = palette_cam[:8, :]
# # else:
# fixed_centroids = None
# # print(np.array(line_colours), fixed_centroids)
# palette_cam = dither_pyx.k_means_with_fixed_centroids(16, np.array(
# line_colours), fixed_centroids=fixed_centroids, tolerance=1e-6)
# kmeans = KMeans(n_clusters=16, max_iter=10000)
# kmeans.fit_predict(line_colours)
# palette_cam = kmeans.cluster_centers_
class ClusterPalette:
def __init__(self, image: Image):
self._colours_cam = self._image_colours_cam(image)
self._best_palette_distances = {i: (1e9, None) for i in range(16)}
self._iterations = 0
self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
self._palettes_rgb = np.empty((16, 16, 3), dtype=np.float32)
self._global_palette = self._fit_global_palette()
def _image_colours_cam(self, image: Image):
colours_rgb = np.asarray(image).reshape((-1, 3))
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
palette_rgb = colour.convert(palettes_cam[palette_idx], "CAM16UCS",
"RGB")
# SHR colour palette only uses 4-bit values
palette_rgb = np.round(palette_rgb * 15) / 15
palettes_rgb[palette_idx, :, :] = palette_rgb.astype(np.float32)
# print(palettes_rgb)
colours_cam = colour.convert(colours_rgb, "RGB",
"CAM16UCS").astype(np.float32)
return colours_cam
# For each line, pick the palette with lowest total distance
# best_palette = 15
# for line in range(200):
# line_pixels = colours_cam[line*320:(line+1)*320]
# best_palette = dither_pyx.best_palette_for_line(
# line_pixels, palettes_cam, best_palette)
# line_to_palette[line] = best_palette
# print(line, line_to_palette[line])
return palettes_cam, palettes_rgb
def _fit_global_palette(self):
"""Compute a 16-colour palette for the entire image to use as
starting point for the sub-palettes. This should help when the image
has large blocks of colour since the sub-palettes will tend to pick the same coloursx."""
clusters = cluster.MiniBatchKMeans(n_clusters=16, max_iter=10000)
# tol=0.0000000001, algorithm="elkan")
clusters.fit_predict(self._colours_cam)
return clusters.cluster_centers_
def iterate(self):
self._iterations += 1
print("Iteration %d" % self._iterations)
for palette_idx in range(16):
# i=5: 3 * (200/16) : 7 * (200/16)
# print("Fitting palette %d" % palette_idx)
p_lower2 = max(palette_idx - 1.5, 0)
p_lower1 = max(palette_idx - 1, 0)
p_lower0 = palette_idx
p_upper0 = max(palette_idx + 1, 16)
p_upper1 = max(palette_idx + 2, 16)
p_upper2 = min(palette_idx + 2.5, 16)
# TODO: weight +/-1 and 0 bands higher
# TODO: dynamically tune palette cuts
palette_pixels = np.concatenate(
[
self._colours_cam[
int(p_lower2 * (200 / 16)) * 320:int(p_upper2 * (
200 / 16)) * 320, :],
# self._colours_cam[
# int(p_lower1 * (200 / 16)) * 320:int(p_upper1 * (
# 200 / 16)) * 320, :],
# self._colours_cam[
# int(p_lower0 * (200 / 16)) * 320:int(p_upper0 * (
# 200 / 16)) * 320, :],
], axis=0)
best_wce, best_medians = self._best_palette_distances[palette_idx]
# if palette_idx == 0:
# initial_centers = kmeans_plusplus_initializer(
# palette_pixels, 16).initialize()
# else:
# initial_centers = kmedians_instance.get_medians()
# kmedians_instance = kmeans(
# palette_pixels, initial_centers, tolerance=0.0000000001,
# itermax=100,
# metric=distance_metric(type_metric.EUCLIDEAN_SQUARE))
# kmedians_instance.process()
# TODO: tolerance
clusters = cluster.MiniBatchKMeans(
n_clusters=16, max_iter=10000, init=self._global_palette,
n_init=1)
# tol=0.0000000001, algorithm="elkan")
clusters.fit_predict(palette_pixels)
# if kmedians_instance.get_total_wce() < best_wce:
# best_wce = kmedians_instance.get_total_wce()
# best_medians = kmedians_instance
if clusters.inertia_ < (best_wce * 0.99):
best_wce = clusters.inertia_
print("Improved palette %d: %f" % (palette_idx, best_wce))
# self._palettes_cam[palette_idx, :, :] = np.array(
# best_medians.get_centers()).astype(np.float32)
self._palettes_cam[palette_idx, :, :] = np.array(
clusters.cluster_centers_).astype(np.float32)
self._best_palette_distances[palette_idx] = (
best_wce, best_medians)
with colour.utilities.suppress_warnings(
colour_usage_warnings=True):
palette_rgb = colour.convert(
self._palettes_cam[palette_idx], "CAM16UCS", "RGB")
# SHR colour palette only uses 4-bit values
palette_rgb = np.round(palette_rgb * 15) / 15
self._palettes_rgb[palette_idx, :, :] = palette_rgb.astype(
np.float32)
return self._palettes_cam, self._palettes_rgb
def main():
@ -199,15 +187,24 @@ def main():
gamma=args.gamma_correct, srgb_output=True)).astype(
np.float32) / 255
palettes_cam, palettes_rgb = cluster_palette(rgb)
# print(palette_rgb)
# screen.set_palette(0, (image_py.linear_to_srgb_array(palette_rgb) *
# 15).astype(np.uint8))
for i in range(16):
screen.set_palette(i, (np.round(palettes_rgb[i, :, :] * 15)).astype(
np.uint8))
penalty = 10 # 1e9
iterations = 50
pygame.init()
canvas = pygame.display.set_mode((640, 400))
canvas = pygame.display.set_mode((640, 400))
canvas.fill((0, 0, 0))
pygame.display.flip()
# print("Foo")
cluster_palette = ClusterPalette(rgb)
for iteration in range(iterations):
palettes_cam, palettes_rgb = cluster_palette.iterate()
# print((palettes_rgb*255).astype(np.uint8))
for i in range(16):
screen.set_palette(i, (np.round(palettes_rgb[i, :, :] * 15)).astype(
np.uint8))
for penalty in [1,2,3,4,5,6,7,8,9,10,1e9]:
output_4bit, line_to_palette = dither_pyx.dither_shr(
rgb, palettes_cam, palettes_rgb, rgb_to_cam16, float(penalty))
screen.set_pixels(output_4bit)
@ -233,21 +230,26 @@ def main():
# output_srgb = image_py.linear_to_srgb(
# output_screen.bitmap_to_image_rgb(bitmap)).astype(np.uint8)
out_image = image_py.resize(
Image.fromarray(output_srgb), screen.X_RES, screen.Y_RES,
Image.fromarray(output_srgb), screen.X_RES * 2, screen.Y_RES * 2,
srgb_output=False) # XXX true
if args.show_output:
out_image.show()
surface = pygame.surfarray.make_surface(np.asarray(
out_image).transpose((1, 0, 2)))
canvas.blit(surface, (0, 0))
pygame.display.flip()
# Save Double hi-res image
outfile = os.path.join(os.path.splitext(args.output)[0] + "-preview.png")
out_image.save(outfile, "PNG")
screen.pack()
# with open(args.output, "wb") as f:
# f.write(bytes(screen.aux))
# f.write(bytes(screen.main))
with open(args.output, "wb") as f:
f.write(bytes(screen.memory))
# Save Double hi-res image
outfile = os.path.join(os.path.splitext(args.output)[0] +
"-%d-preview.png" % cluster_palette._iterations)
out_image.save(outfile, "PNG")
screen.pack()
# with open(args.output, "wb") as f:
# f.write(bytes(screen.aux))
# f.write(bytes(screen.main))
with open("%s-%s" % (args.output, cluster_palette._iterations),
"wb") as f:
f.write(bytes(screen.memory))
if __name__ == "__main__":