Whoops, fix a major bug with the iterated image fitting: we don't want

to mutate our source image!

Fix another bug introduced in the previous commit: convert from linear
rgb before quantizing //gs RGB palette since //gs RGB values are in
Rec.601 colour space.

Switch to double for colour_squared_distance and related variables,
not sure if it matters though.

When iterating palette clustering, reject the new palettes if they
would increase the total image error.  This prevents accepting changes
that are local improvements to one palette but which would introduce
more net errors elsewhere when this palette is reused.

This now seems to give monotonic improvements in image quality so no need
to write out intermediate images any more.
This commit is contained in:
kris 2021-11-16 15:44:04 +00:00
parent 8694ab364e
commit 83b047b73f
2 changed files with 57 additions and 39 deletions

View File

@ -24,10 +24,8 @@ class ClusterPalette:
def __init__(self, image: Image):
self._colours_cam = self._image_colours_cam(image)
self._best_palette_distances = [1e9] * 16
self._iterations = 0
self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
self._palettes_rgb = np.empty((16, 16, 3), dtype=np.float32)
self._global_palette = self._fit_global_palette()
def _image_colours_cam(self, image: Image):
colours_rgb = np.asarray(image).reshape((-1, 3))
@ -46,16 +44,15 @@ class ClusterPalette:
return clusters.cluster_centers_
def iterate(self):
self._iterations += 1
print("Iteration %d" % self._iterations)
self._global_palette = self._fit_global_palette()
for palette_idx in range(16):
p_lower = max(palette_idx - 1.5, 0)
p_upper = min(palette_idx + 2.5, 16)
# TODO: dynamically tune palette cuts
palette_pixels = self._colours_cam[
int(p_lower * (200 / 16)) * 320:int(p_upper * (
200 / 16)) * 320, :]
int(p_lower * (200 / 16)) * 320:int(p_upper * (
200 / 16)) * 320, :]
best_wce = self._best_palette_distances[palette_idx]
# TODO: tolerance
@ -63,12 +60,7 @@ class ClusterPalette:
n_clusters=16, max_iter=10000, init=self._global_palette,
n_init=1)
clusters.fit_predict(palette_pixels)
if clusters.inertia_ < (best_wce * 0.99):
# TODO: sentinel
if best_wce < 1e9:
print("Improved palette %d (+%f%%)" % (
palette_idx, best_wce / clusters.inertia_))
if clusters.inertia_ < best_wce:
self._palettes_cam[palette_idx, :, :] = np.array(
clusters.cluster_centers_).astype(np.float32)
best_wce = clusters.inertia_
@ -147,26 +139,51 @@ def main():
gamma=args.gamma_correct)).astype(np.float32) / 255
# TODO: flags
penalty = 10 # 1e9
iterations = 50
penalty = 1e9 # 0 # 1e9
iterations = 50 # 0
pygame.init()
# TODO: for some reason I need to execute this twice - the first time
# the window is created and immediately destroyed
pygame.display.set_mode((640, 400))
_ = pygame.display.set_mode((640, 400))
canvas = pygame.display.set_mode((640, 400))
canvas.fill((0, 0, 0))
pygame.display.flip()
total_image_error = 1e9
cluster_palette = ClusterPalette(rgb)
image_generation = 0
for iteration in range(iterations):
palettes_cam, palettes_rgb = cluster_palette.iterate()
for i in range(16):
screen.set_palette(i, (np.round(palettes_rgb[i, :, :] * 15)).astype(
np.uint8))
old_best_palette_distances = cluster_palette._best_palette_distances
old_palettes_cam = cluster_palette._palettes_cam
old_palettes_rgb = cluster_palette._palettes_rgb
output_4bit, line_to_palette = dither_pyx.dither_shr(
rgb, palettes_cam, palettes_rgb, rgb_to_cam16, float(penalty))
new_palettes_cam, new_palettes_rgb = cluster_palette.iterate()
output_4bit, line_to_palette, new_total_image_error = \
dither_pyx.dither_shr(
rgb, new_palettes_cam, new_palettes_rgb, rgb_to_cam16,
float(penalty)
)
if new_total_image_error < total_image_error:
if total_image_error < 1e9:
print("Improved quality +%f%% (%f)" % (
(1 - new_total_image_error / total_image_error) * 100,
new_total_image_error))
total_image_error = new_total_image_error
palettes_rgb = new_palettes_rgb
else:
cluster_palette._palettes_cam = old_palettes_cam
cluster_palette._palettes_rgb = old_palettes_rgb
cluster_palette._best_palette_distances = old_best_palette_distances
continue
image_generation += 1
for i in range(16):
screen.set_palette(i, (
np.round(image_py.linear_to_srgb(palettes_rgb[i, :,
:] * 255) / 255 * 15)).astype(
np.uint8))
screen.set_pixels(output_4bit)
output_rgb = np.empty((200, 320, 3), dtype=np.uint8)
for i in range(200):
@ -199,17 +216,15 @@ def main():
canvas.blit(surface, (0, 0))
pygame.display.flip()
# Save Double hi-res image
outfile = os.path.join(os.path.splitext(args.output)[0] +
"-%d-preview.png" % cluster_palette._iterations)
out_image.save(outfile, "PNG")
screen.pack()
# with open(args.output, "wb") as f:
# f.write(bytes(screen.aux))
# f.write(bytes(screen.main))
with open("%s-%s" % (args.output, cluster_palette._iterations),
"wb") as f:
f.write(bytes(screen.memory))
# Save Double hi-res image
outfile = os.path.join(os.path.splitext(args.output)[0] + "-preview.png")
out_image.save(outfile, "PNG")
screen.pack()
# with open(args.output, "wb") as f:
# f.write(bytes(screen.aux))
# f.write(bytes(screen.main))
with open(args.output, "wb") as f:
f.write(bytes(screen.memory))
if __name__ == "__main__":

View File

@ -171,7 +171,7 @@ cdef inline float fabs(float value) nogil:
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline float colour_distance_squared(float[::1] colour1, float[::1] colour2) nogil:
cdef inline double colour_distance_squared(float[::1] colour1, float[::1] colour2) nogil:
return (colour1[0] - colour2[0]) ** 2 + (colour1[1] - colour2[1]) ** 2 + (colour1[2] - colour2[2]) ** 2
@ -339,20 +339,21 @@ import colour
@cython.boundscheck(False)
@cython.wraparound(False)
def dither_shr(float[:, :, ::1] working_image, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb, float[:,::1] rgb_to_cam16ucs, float penalty):
def dither_shr(float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb, float[:,::1] rgb_to_cam16ucs, float penalty):
cdef int y, x, idx, best_colour_idx, best_palette
cdef float best_distance, distance
cdef double best_distance, distance, total_image_error
cdef float[::1] best_colour_rgb, pixel_cam, colour_rgb, colour_cam
cdef float quant_error
cdef float[:, ::1] palette_rgb
cdef (unsigned char)[:, ::1] output_4bit = np.zeros((200, 320), dtype=np.uint8)
# cdef (unsigned char)[:, :, ::1] output_rgb = np.zeros((200, 320, 3), dtype=np.uint8)
cdef float[:, :, ::1] working_image = np.copy(input_rgb)
cdef float[:, ::1] line_cam = np.zeros((320, 3), dtype=np.float32)
cdef int[::1] line_to_palette = np.zeros(200, dtype=np.int32)
best_palette = 15
total_image_error = 0.0
for y in range(200):
# print(y)
for x in range(320):
@ -380,6 +381,8 @@ def dither_shr(float[:, :, ::1] working_image, float[:, :, ::1] palettes_cam, fl
best_colour_idx = idx
best_colour_rgb = palette_rgb[best_colour_idx]
output_4bit[y, x] = best_colour_idx
total_image_error += best_distance
# print(y,x,best_distance,total_image_error)
for i in range(3):
quant_error = working_image[y, x, i] - best_colour_rgb[i]
@ -449,7 +452,7 @@ def dither_shr(float[:, :, ::1] working_image, float[:, :, ::1] palettes_cam, fl
# working_image[y + 2, x + 2, i] + quant_error * (1 / 48),
# 0, 1)
return np.array(output_4bit, dtype=np.uint8), line_to_palette
return np.array(output_4bit, dtype=np.uint8), line_to_palette, total_image_error
import collections
import random
@ -511,7 +514,7 @@ def k_means_with_fixed_centroids(
@cython.wraparound(False)
cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palettes_cam, int base_palette_idx, int last_palette_idx, float last_penalty) nogil:
cdef int palette_idx, best_palette_idx, palette_entry_idx, pixel_idx
cdef float best_total_dist, total_dist, best_pixel_dist, pixel_dist
cdef double best_total_dist, total_dist, best_pixel_dist, pixel_dist
cdef float[:, ::1] palette_cam
cdef float[::1] pixel_cam, palette_entry