WIP - interleave 3 successive palettes for each contiguous row range.
Avoids the banding but not clear if it's overall better Also implement my own k-means clustering which is able to keep some centroids fixed, e.g. to be able to retain some fixed palette entries while swapping out others. I was hoping this would improve colour blending across neighbouring palettes but it's also not clear if it does.
This commit is contained in:
parent
322123522c
commit
8c34d87216
51
convert.py
51
convert.py
|
@ -23,13 +23,29 @@ import screen as screen_py
|
|||
# - support LR/DLR
|
||||
# - support HGR
|
||||
|
||||
|
||||
def cluster_palette(image: Image):
|
||||
shuffle_lines = list(range(200))
|
||||
random.shuffle(shuffle_lines)
|
||||
line_to_palette = {}
|
||||
for idx, line in enumerate(shuffle_lines):
|
||||
line_to_palette[line] = idx % 16
|
||||
|
||||
#shuffle_lines = liprint(st(range(200))
|
||||
#random.shuffle(shuffle_lines)
|
||||
#for idx, line in enumerate(shuffle_lines):
|
||||
# line_to_palette[line] = idx % 16
|
||||
|
||||
# for line in range(200):
|
||||
# if line % 3 == 0:
|
||||
# line_to_palette[line] = int(line / (200 / 16))
|
||||
# elif line % 3 == 1:
|
||||
# line_to_palette[line] = np.clip(int(line / (200 / 16)) + 1, 0, 15)
|
||||
# else:
|
||||
# line_to_palette[line] = np.clip(int(line / (200 / 16)) + 2, 0, 15)
|
||||
|
||||
for line in range(200):
|
||||
if line % 3 == 0:
|
||||
line_to_palette[line] = int(line / (200 / 16))
|
||||
elif line % 3 == 1:
|
||||
line_to_palette[line] = np.clip(int(line / (200 / 16)) + 1, 0, 15)
|
||||
else:
|
||||
line_to_palette[line] = np.clip(int(line / (200 / 16)) + 2, 0, 15)
|
||||
|
||||
colours_rgb = np.asarray(image).reshape((-1, 3))
|
||||
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
|
||||
|
@ -43,16 +59,25 @@ def cluster_palette(image: Image):
|
|||
palette_colours[palette].extend(
|
||||
colours_cam[line * 320:(line + 1) * 320])
|
||||
|
||||
for palette in range(16):
|
||||
kmeans = KMeans(n_clusters=16, max_iter=10000)
|
||||
kmeans.fit_predict(palette_colours[palette])
|
||||
palette_cam = kmeans.cluster_centers_
|
||||
# For each line grouping, find big palette entries with minimal total
|
||||
# distance
|
||||
|
||||
palette_cam = None
|
||||
for palette_idx in range(16):
|
||||
line_colours = palette_colours[palette_idx]
|
||||
# if palette_idx > 0:
|
||||
# fixed_centroids = palette_cam[:8, :]
|
||||
# else:
|
||||
fixed_centroids = None
|
||||
# print(np.array(line_colours), fixed_centroids)
|
||||
palette_cam = dither_pyx.k_means_with_fixed_centroids(16, np.array(
|
||||
line_colours), fixed_centroids=fixed_centroids, tolerance=1e-6)
|
||||
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
|
||||
palette_rgb = colour.convert(palette_cam, "CAM16UCS", "RGB")
|
||||
# SHR colour palette only uses 4-bit values
|
||||
palette_rgb = np.round(palette_rgb * 15) / 15
|
||||
# palette_rgb = palette_rgb.astype(np.float32) / 255
|
||||
palettes_rgb[palette] = palette_rgb.astype(np.float32)
|
||||
palettes_rgb[palette_idx] = palette_rgb.astype(np.float32)
|
||||
# print(palettes_rgb)
|
||||
return palettes_rgb, line_to_palette
|
||||
|
||||
|
||||
|
@ -130,8 +155,8 @@ def main():
|
|||
for i in range(200):
|
||||
screen.line_palette[i] = line_to_palette[i]
|
||||
output_rgb[i, :, :] = (
|
||||
palettes_rgb[line_to_palette[i]][
|
||||
output_4bit[i, :]] * 255).astype(np.uint8)
|
||||
palettes_rgb[line_to_palette[i]][
|
||||
output_4bit[i, :]] * 255).astype(np.uint8)
|
||||
output_srgb = image_py.linear_to_srgb(output_rgb).astype(np.uint8)
|
||||
|
||||
# dither = dither_pattern.PATTERNS[args.dither]()
|
||||
|
|
57
dither.pyx
57
dither.pyx
|
@ -425,4 +425,59 @@ def dither_shr(float[:, :, ::1] working_image, object palettes_rgb, float[:,::1]
|
|||
# working_image[y + 2, x + 2, i] + quant_error[i] * (1 / 48),
|
||||
# 0, 1)
|
||||
|
||||
return np.array(output_4bit, dtype=np.uint8)
|
||||
return np.array(output_4bit, dtype=np.uint8)
|
||||
|
||||
import collections
|
||||
import random
|
||||
|
||||
def k_means_with_fixed_centroids(
|
||||
int n_clusters, float[:, ::1] data, float[:, ::1] fixed_centroids = None,
|
||||
int iterations = 10000, float tolerance = 1e-3):
|
||||
cdef int i, iteration, centroid_idx, num_fixed_centroids, num_random_centroids, best_centroid_idx
|
||||
cdef float[::1] point, centroid, new_centroid, old_centroid
|
||||
cdef float[:, ::1] centroids
|
||||
cdef float best_dist, centroid_movement
|
||||
|
||||
centroids = np.zeros((n_clusters, 3), dtype=np.float32)
|
||||
if fixed_centroids is not None:
|
||||
centroids[:fixed_centroids.shape[0], :] = fixed_centroids
|
||||
num_fixed_centroids = fixed_centroids.shape[0] if fixed_centroids is not None else 0
|
||||
num_random_centroids = n_clusters - num_fixed_centroids
|
||||
|
||||
# TODO: kmeans++ initialization
|
||||
for i in range(num_random_centroids):
|
||||
centroids[num_fixed_centroids + i, :] = data[
|
||||
random.randint(0, data.shape[0]), :]
|
||||
|
||||
cdef int[::1] centroid_weights = np.zeros(n_clusters, dtype=np.int32)
|
||||
for iteration in range(iterations):
|
||||
# print("centroids ", centroids)
|
||||
closest_points = collections.defaultdict(list)
|
||||
for point in data:
|
||||
best_dist = 1e9
|
||||
best_centroid_idx = 0
|
||||
for centroid_idx, centroid in enumerate(centroids):
|
||||
dist = colour_distance_squared(centroid, point)
|
||||
if dist < best_dist:
|
||||
best_dist = dist
|
||||
best_centroid_idx = centroid_idx
|
||||
closest_points[best_centroid_idx].append(point)
|
||||
|
||||
centroid_movement = 0
|
||||
|
||||
for centroid_idx, points in closest_points.items():
|
||||
centroid_weights[centroid_idx] = len(points)
|
||||
if centroid_idx < num_fixed_centroids:
|
||||
continue
|
||||
new_centroid = np.mean(np.array(points), axis=0)
|
||||
old_centroid = centroids[centroid_idx]
|
||||
centroid_movement += colour_distance_squared(old_centroid, new_centroid)
|
||||
centroids[centroid_idx, :] = new_centroid
|
||||
# print("iteration %d: movement %f" % (iteration, centroid_movement))
|
||||
if centroid_movement < tolerance:
|
||||
break
|
||||
|
||||
weighted_centroids = list(zip(centroid_weights, [tuple(c) for c in centroids]))
|
||||
print(weighted_centroids)
|
||||
return np.array([c for w, c in sorted(weighted_centroids, reverse=True)], dtype=np.float32)
|
||||
|
||||
|
|
Loading…
Reference in New Issue