Checkpoint
- Repeatedly refit palettes since k-means is only a local optimization. This can produce incremental improvements in image quality but may also overfit, especially on complex images. - use pygame to render incremental images - Fix off-by-one in palette striping - When fitting palettes, first cluster a 16-colour palette for the entire image and use this to initialize the centroids for individual palettes. This improves quality when fitting images with large blocks of colour, since they will otherwise be fit separately and may have slight differences. With a global initializer these will tend to be the same. This also improves performance.
This commit is contained in:
parent
b363d60754
commit
10c829906b
256
convert.py
256
convert.py
|
@ -6,6 +6,7 @@ import os.path
|
|||
import time
|
||||
import collections
|
||||
import random
|
||||
import pygame
|
||||
|
||||
import colour
|
||||
from PIL import Image
|
||||
|
@ -14,6 +15,7 @@ from pyclustering.cluster.kmedians import kmedians
|
|||
from pyclustering.cluster.kmeans import kmeans
|
||||
from pyclustering.utils.metric import distance_metric, type_metric
|
||||
from pyclustering.cluster.center_initializer import kmeans_plusplus_initializer
|
||||
from sklearn import cluster
|
||||
|
||||
import dither as dither_pyx
|
||||
import dither_pattern
|
||||
|
@ -26,117 +28,103 @@ import screen as screen_py
|
|||
# - support LR/DLR
|
||||
# - support HGR
|
||||
|
||||
def cluster_palette(image: Image):
|
||||
# line_to_palette = {}
|
||||
|
||||
# shuffle_lines = liprint(st(range(200))
|
||||
# random.shuffle(shuffle_lines)
|
||||
# for idx, line in enumerate(shuffle_lines):
|
||||
# line_to_palette[line] = idx % 16
|
||||
|
||||
# for line in range(200):
|
||||
# if line % 3 == 0:
|
||||
# line_to_palette[line] = int(line / (200 / 16))
|
||||
# elif line % 3 == 1:
|
||||
# line_to_palette[line] = np.clip(int(line / (200 / 16)) + 1, 0, 15)
|
||||
# else:
|
||||
# line_to_palette[line] = np.clip(int(line / (200 / 16)) + 2, 0, 15)
|
||||
|
||||
# for line in range(200):
|
||||
# if line % 3 == 0:
|
||||
# line_to_palette[line] = int(line / (200 / 16))
|
||||
# elif line % 3 == 1:
|
||||
# line_to_palette[line] = np.clip(int(line / (200 / 16)) + 1, 0, 15)
|
||||
# else:
|
||||
# line_to_palette[line] = np.clip(int(line / (200 / 16)) + 2, 0, 15)
|
||||
|
||||
colours_rgb = np.asarray(image).reshape((-1, 3))
|
||||
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
|
||||
colours_cam = colour.convert(colours_rgb, "RGB",
|
||||
"CAM16UCS").astype(np.float32)
|
||||
palettes_rgb = np.empty((16, 16, 3), dtype=np.float32)
|
||||
palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
|
||||
for palette_idx in range(16):
|
||||
print("Fitting palette %d" % palette_idx)
|
||||
p_lower = max(palette_idx - 2, 0)
|
||||
p_upper = min(palette_idx + 2, 16)
|
||||
palette_pixels = colours_cam[
|
||||
int(p_lower * (200 / 16)) * 320:int(p_upper * (
|
||||
200 / 16)) * 320, :]
|
||||
|
||||
# kmeans = KMeans(n_clusters=16, max_iter=10000)
|
||||
# kmeans.fit_predict(palette_pixels)
|
||||
# palettes_cam[palette_idx] = kmeans.cluster_centers_
|
||||
|
||||
# fixed_centroids = None
|
||||
# print(np.array(line_colours), fixed_centroids)
|
||||
# palettes_cam[palette_idx] = dither_pyx.k_means_with_fixed_centroids(
|
||||
# 16, palette_pixels, fixed_centroids=fixed_centroids,
|
||||
# tolerance=1e-6)
|
||||
|
||||
best_wce = 1e9
|
||||
best_medians = None
|
||||
for i in range(500):
|
||||
# print(i)
|
||||
initial_centers = kmeans_plusplus_initializer(
|
||||
palette_pixels, 16).initialize()
|
||||
kmedians_instance = kmedians(
|
||||
palette_pixels, initial_centers, tolerance=0.1, itermax=100,
|
||||
metric=distance_metric(type_metric.MANHATTAN))
|
||||
kmedians_instance.process()
|
||||
if kmedians_instance.get_total_wce() < best_wce:
|
||||
best_wce = kmedians_instance.get_total_wce()
|
||||
print(i, best_wce)
|
||||
best_medians = kmedians_instance
|
||||
print("Best %f" % best_wce)
|
||||
palettes_cam[palette_idx, :, :] = np.array(
|
||||
best_medians.get_medians()).astype(np.float32)
|
||||
|
||||
# palette_colours = collections.defaultdict(list)
|
||||
# for line in range(200):
|
||||
# palette = line_to_palette[line]
|
||||
# palette_colours[palette].extend(
|
||||
# colours_cam[line * 320:(line + 1) * 320])
|
||||
|
||||
# For each line grouping, find big palette entries with minimal total
|
||||
# distance
|
||||
|
||||
# palette_cam = None
|
||||
# for palette_idx in range(16):
|
||||
# line_colours = palette_colours[palette_idx]
|
||||
# #if palette_idx < 15:
|
||||
# # line_colours += palette_colours[palette_idx + 1]
|
||||
# # if palette_idx < 14:
|
||||
# # line_colours += palette_colours[palette_idx + 2]
|
||||
# # if palette_idx > 0:
|
||||
# # fixed_centroids = palette_cam[:8, :]
|
||||
# # else:
|
||||
# fixed_centroids = None
|
||||
# # print(np.array(line_colours), fixed_centroids)
|
||||
# palette_cam = dither_pyx.k_means_with_fixed_centroids(16, np.array(
|
||||
# line_colours), fixed_centroids=fixed_centroids, tolerance=1e-6)
|
||||
|
||||
# kmeans = KMeans(n_clusters=16, max_iter=10000)
|
||||
# kmeans.fit_predict(line_colours)
|
||||
# palette_cam = kmeans.cluster_centers_
|
||||
class ClusterPalette:
|
||||
def __init__(self, image: Image):
|
||||
self._colours_cam = self._image_colours_cam(image)
|
||||
self._best_palette_distances = {i: (1e9, None) for i in range(16)}
|
||||
self._iterations = 0
|
||||
self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
|
||||
self._palettes_rgb = np.empty((16, 16, 3), dtype=np.float32)
|
||||
self._global_palette = self._fit_global_palette()
|
||||
|
||||
def _image_colours_cam(self, image: Image):
|
||||
colours_rgb = np.asarray(image).reshape((-1, 3))
|
||||
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
|
||||
palette_rgb = colour.convert(palettes_cam[palette_idx], "CAM16UCS",
|
||||
"RGB")
|
||||
# SHR colour palette only uses 4-bit values
|
||||
palette_rgb = np.round(palette_rgb * 15) / 15
|
||||
palettes_rgb[palette_idx, :, :] = palette_rgb.astype(np.float32)
|
||||
# print(palettes_rgb)
|
||||
colours_cam = colour.convert(colours_rgb, "RGB",
|
||||
"CAM16UCS").astype(np.float32)
|
||||
return colours_cam
|
||||
|
||||
# For each line, pick the palette with lowest total distance
|
||||
# best_palette = 15
|
||||
# for line in range(200):
|
||||
# line_pixels = colours_cam[line*320:(line+1)*320]
|
||||
# best_palette = dither_pyx.best_palette_for_line(
|
||||
# line_pixels, palettes_cam, best_palette)
|
||||
# line_to_palette[line] = best_palette
|
||||
# print(line, line_to_palette[line])
|
||||
return palettes_cam, palettes_rgb
|
||||
def _fit_global_palette(self):
|
||||
"""Compute a 16-colour palette for the entire image to use as
|
||||
starting point for the sub-palettes. This should help when the image
|
||||
has large blocks of colour since the sub-palettes will tend to pick the same coloursx."""
|
||||
clusters = cluster.MiniBatchKMeans(n_clusters=16, max_iter=10000)
|
||||
# tol=0.0000000001, algorithm="elkan")
|
||||
clusters.fit_predict(self._colours_cam)
|
||||
return clusters.cluster_centers_
|
||||
|
||||
def iterate(self):
|
||||
self._iterations += 1
|
||||
print("Iteration %d" % self._iterations)
|
||||
for palette_idx in range(16):
|
||||
|
||||
|
||||
# i=5: 3 * (200/16) : 7 * (200/16)
|
||||
# print("Fitting palette %d" % palette_idx)
|
||||
p_lower2 = max(palette_idx - 1.5, 0)
|
||||
p_lower1 = max(palette_idx - 1, 0)
|
||||
p_lower0 = palette_idx
|
||||
p_upper0 = max(palette_idx + 1, 16)
|
||||
p_upper1 = max(palette_idx + 2, 16)
|
||||
p_upper2 = min(palette_idx + 2.5, 16)
|
||||
# TODO: weight +/-1 and 0 bands higher
|
||||
# TODO: dynamically tune palette cuts
|
||||
palette_pixels = np.concatenate(
|
||||
[
|
||||
self._colours_cam[
|
||||
int(p_lower2 * (200 / 16)) * 320:int(p_upper2 * (
|
||||
200 / 16)) * 320, :],
|
||||
# self._colours_cam[
|
||||
# int(p_lower1 * (200 / 16)) * 320:int(p_upper1 * (
|
||||
# 200 / 16)) * 320, :],
|
||||
# self._colours_cam[
|
||||
# int(p_lower0 * (200 / 16)) * 320:int(p_upper0 * (
|
||||
# 200 / 16)) * 320, :],
|
||||
], axis=0)
|
||||
|
||||
best_wce, best_medians = self._best_palette_distances[palette_idx]
|
||||
# if palette_idx == 0:
|
||||
# initial_centers = kmeans_plusplus_initializer(
|
||||
# palette_pixels, 16).initialize()
|
||||
# else:
|
||||
# initial_centers = kmedians_instance.get_medians()
|
||||
|
||||
# kmedians_instance = kmeans(
|
||||
# palette_pixels, initial_centers, tolerance=0.0000000001,
|
||||
# itermax=100,
|
||||
# metric=distance_metric(type_metric.EUCLIDEAN_SQUARE))
|
||||
# kmedians_instance.process()
|
||||
# TODO: tolerance
|
||||
clusters = cluster.MiniBatchKMeans(
|
||||
n_clusters=16, max_iter=10000, init=self._global_palette,
|
||||
n_init=1)
|
||||
# tol=0.0000000001, algorithm="elkan")
|
||||
clusters.fit_predict(palette_pixels)
|
||||
# if kmedians_instance.get_total_wce() < best_wce:
|
||||
# best_wce = kmedians_instance.get_total_wce()
|
||||
# best_medians = kmedians_instance
|
||||
if clusters.inertia_ < (best_wce * 0.99):
|
||||
best_wce = clusters.inertia_
|
||||
print("Improved palette %d: %f" % (palette_idx, best_wce))
|
||||
|
||||
# self._palettes_cam[palette_idx, :, :] = np.array(
|
||||
# best_medians.get_centers()).astype(np.float32)
|
||||
|
||||
self._palettes_cam[palette_idx, :, :] = np.array(
|
||||
clusters.cluster_centers_).astype(np.float32)
|
||||
self._best_palette_distances[palette_idx] = (
|
||||
best_wce, best_medians)
|
||||
|
||||
with colour.utilities.suppress_warnings(
|
||||
colour_usage_warnings=True):
|
||||
palette_rgb = colour.convert(
|
||||
self._palettes_cam[palette_idx], "CAM16UCS", "RGB")
|
||||
# SHR colour palette only uses 4-bit values
|
||||
palette_rgb = np.round(palette_rgb * 15) / 15
|
||||
self._palettes_rgb[palette_idx, :, :] = palette_rgb.astype(
|
||||
np.float32)
|
||||
|
||||
return self._palettes_cam, self._palettes_rgb
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -199,15 +187,24 @@ def main():
|
|||
gamma=args.gamma_correct, srgb_output=True)).astype(
|
||||
np.float32) / 255
|
||||
|
||||
palettes_cam, palettes_rgb = cluster_palette(rgb)
|
||||
# print(palette_rgb)
|
||||
# screen.set_palette(0, (image_py.linear_to_srgb_array(palette_rgb) *
|
||||
# 15).astype(np.uint8))
|
||||
for i in range(16):
|
||||
screen.set_palette(i, (np.round(palettes_rgb[i, :, :] * 15)).astype(
|
||||
np.uint8))
|
||||
penalty = 10 # 1e9
|
||||
iterations = 50
|
||||
|
||||
pygame.init()
|
||||
canvas = pygame.display.set_mode((640, 400))
|
||||
canvas = pygame.display.set_mode((640, 400))
|
||||
canvas.fill((0, 0, 0))
|
||||
pygame.display.flip()
|
||||
# print("Foo")
|
||||
|
||||
cluster_palette = ClusterPalette(rgb)
|
||||
for iteration in range(iterations):
|
||||
palettes_cam, palettes_rgb = cluster_palette.iterate()
|
||||
# print((palettes_rgb*255).astype(np.uint8))
|
||||
for i in range(16):
|
||||
screen.set_palette(i, (np.round(palettes_rgb[i, :, :] * 15)).astype(
|
||||
np.uint8))
|
||||
|
||||
for penalty in [1,2,3,4,5,6,7,8,9,10,1e9]:
|
||||
output_4bit, line_to_palette = dither_pyx.dither_shr(
|
||||
rgb, palettes_cam, palettes_rgb, rgb_to_cam16, float(penalty))
|
||||
screen.set_pixels(output_4bit)
|
||||
|
@ -233,21 +230,26 @@ def main():
|
|||
# output_srgb = image_py.linear_to_srgb(
|
||||
# output_screen.bitmap_to_image_rgb(bitmap)).astype(np.uint8)
|
||||
out_image = image_py.resize(
|
||||
Image.fromarray(output_srgb), screen.X_RES, screen.Y_RES,
|
||||
Image.fromarray(output_srgb), screen.X_RES * 2, screen.Y_RES * 2,
|
||||
srgb_output=False) # XXX true
|
||||
|
||||
if args.show_output:
|
||||
out_image.show()
|
||||
surface = pygame.surfarray.make_surface(np.asarray(
|
||||
out_image).transpose((1, 0, 2)))
|
||||
canvas.blit(surface, (0, 0))
|
||||
pygame.display.flip()
|
||||
|
||||
# Save Double hi-res image
|
||||
outfile = os.path.join(os.path.splitext(args.output)[0] + "-preview.png")
|
||||
out_image.save(outfile, "PNG")
|
||||
screen.pack()
|
||||
# with open(args.output, "wb") as f:
|
||||
# f.write(bytes(screen.aux))
|
||||
# f.write(bytes(screen.main))
|
||||
with open(args.output, "wb") as f:
|
||||
f.write(bytes(screen.memory))
|
||||
# Save Double hi-res image
|
||||
outfile = os.path.join(os.path.splitext(args.output)[0] +
|
||||
"-%d-preview.png" % cluster_palette._iterations)
|
||||
out_image.save(outfile, "PNG")
|
||||
screen.pack()
|
||||
# with open(args.output, "wb") as f:
|
||||
# f.write(bytes(screen.aux))
|
||||
# f.write(bytes(screen.main))
|
||||
with open("%s-%s" % (args.output, cluster_palette._iterations),
|
||||
"wb") as f:
|
||||
f.write(bytes(screen.memory))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue