When initializing centroids for fitting the SHR palettes, only use the
reserved colours from the global palette, and pick unique random points from the samples for the rest. This encourages a larger range of colours in the resulting images and may improve quality. Iterate a max number of times without improvement in the outer loop as well. Save intermediate preview outputs.
This commit is contained in:
parent
3b8767782b
commit
c36de2b76b
124
convert.py
124
convert.py
|
@ -65,7 +65,7 @@ class ClusterPalette:
|
|||
|
||||
# List of line ranges used to train the 16 SHR palettes
|
||||
# [(lower_0, upper_0), ...]
|
||||
self._palette_splits = self._equal_palette_splits()
|
||||
# self._palette_splits = self._equal_palette_splits()
|
||||
self._init_palette_lines()
|
||||
|
||||
# Whether the previous iteration of proposed palettes was accepted
|
||||
|
@ -78,10 +78,19 @@ class ClusterPalette:
|
|||
self._palette_mutate_delta = (0, 0)
|
||||
|
||||
def _init_palette_lines(self):
|
||||
for i, lh in enumerate(self._palette_splits):
|
||||
palette_splits = self._equal_palette_splits()
|
||||
for i, lh in enumerate(palette_splits):
|
||||
l, h = lh
|
||||
self._palette_lines[i].extend(list(range(l, h)))
|
||||
|
||||
# lines = list(range(200))
|
||||
# random.shuffle(lines)
|
||||
# idx = 0
|
||||
# while lines:
|
||||
# self._palette_lines[idx].append(lines.pop())
|
||||
# idx += 1
|
||||
|
||||
|
||||
def _image_colours_cam(self, image: Image):
|
||||
colours_rgb = np.asarray(image) # .reshape((-1, 3))
|
||||
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
|
||||
|
@ -134,18 +143,21 @@ class ClusterPalette:
|
|||
return (output_4bit, line_to_palette, palettes_linear_rgb,
|
||||
total_image_error)
|
||||
|
||||
def iterate(self, penalty: float, max_iterations: int):
|
||||
def iterate(self, penalty: float, max_inner_iterations: int,
|
||||
max_outer_iterations: int):
|
||||
total_image_error = 1e9
|
||||
last_good_splits = self._palette_splits
|
||||
while True:
|
||||
# last_good_splits = self._palette_splits
|
||||
|
||||
outer_iterations_since_improvement = 0
|
||||
while outer_iterations_since_improvement < max_outer_iterations:
|
||||
print("New iteration")
|
||||
iterations_since_improvement = 0
|
||||
self._palette_splits = self._equal_palette_splits()
|
||||
inner_iterations_since_improvement = 0
|
||||
# self._palette_splits = self._equal_palette_splits()
|
||||
self._init_palette_lines()
|
||||
|
||||
self._fit_global_palette()
|
||||
while iterations_since_improvement < max_iterations:
|
||||
# print("Iterations %d" % iterations_since_improvement)
|
||||
while inner_iterations_since_improvement < max_inner_iterations:
|
||||
# print("Iterations %d" % inner_iterations_since_improvement)
|
||||
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
|
||||
self._propose_palettes())
|
||||
|
||||
|
@ -155,17 +167,19 @@ class ClusterPalette:
|
|||
new_total_image_error) = self._dither_image(
|
||||
new_palettes_cam, penalty)
|
||||
|
||||
self._reassign_unused_palettes(line_to_palette,
|
||||
last_good_splits)
|
||||
# TODO: check for duplicate palettes and unused colours
|
||||
# within a palette
|
||||
self._reassign_unused_palettes(line_to_palette)
|
||||
|
||||
# print(total_image_error, new_total_image_error)
|
||||
if new_total_image_error >= total_image_error:
|
||||
iterations_since_improvement += 1
|
||||
inner_iterations_since_improvement += 1
|
||||
continue
|
||||
|
||||
# We found a globally better set of palettes
|
||||
iterations_since_improvement = 0
|
||||
last_good_splits = self._palette_splits
|
||||
inner_iterations_since_improvement = 0
|
||||
outer_iterations_since_improvement = -1
|
||||
# last_good_splits = self._palette_splits
|
||||
total_image_error = new_total_image_error
|
||||
|
||||
self._palettes_cam = new_palettes_cam
|
||||
|
@ -175,6 +189,7 @@ class ClusterPalette:
|
|||
|
||||
yield (new_total_image_error, output_4bit, line_to_palette,
|
||||
new_palettes_rgb12_iigs, palettes_linear_rgb)
|
||||
outer_iterations_since_improvement += 1
|
||||
|
||||
def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
|
||||
"""Attempt to find new palettes that locally improve image quality.
|
||||
|
@ -202,7 +217,7 @@ class ClusterPalette:
|
|||
# individual palettes
|
||||
self._fit_global_palette()
|
||||
|
||||
self._mutate_palette_splits()
|
||||
# self._mutate_palette_splits()
|
||||
for palette_idx in range(16):
|
||||
# print(palette_idx, self._palette_lines[palette_idx])
|
||||
# palette_lower, palette_upper = self._palette_splits[palette_idx]
|
||||
|
@ -210,11 +225,28 @@ class ClusterPalette:
|
|||
self._colours_cam[
|
||||
self._palette_lines[palette_idx], :, :].reshape(-1, 3))
|
||||
|
||||
initial_centroids = self._global_palette
|
||||
pixels_rgb_iigs = dither_pyx.convert_cam16ucs_to_rgb12_iigs(
|
||||
palette_pixels)
|
||||
seen_colours = set()
|
||||
for i in range(self._reserved_colours):
|
||||
seen_colours.add(tuple(initial_centroids[i, :]))
|
||||
for i in range(self._reserved_colours, 16):
|
||||
choice = np.random.randint(0, pixels_rgb_iigs.shape[
|
||||
0])
|
||||
new_colour = pixels_rgb_iigs[choice, :]
|
||||
if tuple(new_colour) in seen_colours:
|
||||
# print("Skipping")
|
||||
continue
|
||||
seen_colours.add(tuple(new_colour))
|
||||
# print(i, choice)
|
||||
initial_centroids[i, :] = new_colour
|
||||
|
||||
palettes_rgb12_iigs, palette_error = \
|
||||
dither_pyx.k_means_with_fixed_centroids(
|
||||
n_clusters=16, n_fixed=self._reserved_colours,
|
||||
samples=palette_pixels,
|
||||
initial_centroids=self._global_palette,
|
||||
initial_centroids=initial_centroids,
|
||||
max_iterations=1000, tolerance=0.05,
|
||||
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs
|
||||
)
|
||||
|
@ -260,50 +292,7 @@ class ClusterPalette:
|
|||
clusters.cluster_centers_[frequency_order].astype(
|
||||
np.float32)))
|
||||
|
||||
def _mutate_palette_splits(self):
|
||||
if self._palettes_accepted:
|
||||
# Last time was good, keep going
|
||||
self._apply_palette_delta(self._palette_mutate_idx,
|
||||
self._palette_mutate_delta[0],
|
||||
self._palette_mutate_delta[1])
|
||||
else:
|
||||
# undo last mutation
|
||||
self._apply_palette_delta(self._palette_mutate_idx,
|
||||
-self._palette_mutate_delta[0],
|
||||
-self._palette_mutate_delta[1])
|
||||
|
||||
# Pick a palette endpoint to move up or down
|
||||
palette_to_mutate = np.random.randint(0, 16)
|
||||
while True:
|
||||
if palette_to_mutate > 0:
|
||||
palette_lower_delta = np.random.randint(-20, 21)
|
||||
else:
|
||||
palette_lower_delta = 0
|
||||
if palette_to_mutate < 15:
|
||||
palette_upper_delta = np.random.randint(-20, 21)
|
||||
else:
|
||||
palette_upper_delta = 0
|
||||
if palette_lower_delta != 0 or palette_upper_delta != 0:
|
||||
break
|
||||
|
||||
self._apply_palette_delta(palette_to_mutate, palette_lower_delta,
|
||||
palette_upper_delta)
|
||||
|
||||
def _apply_palette_delta(
|
||||
self, palette_to_mutate, palette_lower_delta, palette_upper_delta):
|
||||
old_lower, old_upper = self._palette_splits[palette_to_mutate]
|
||||
new_lower = old_lower + palette_lower_delta
|
||||
new_upper = old_upper + palette_upper_delta
|
||||
|
||||
new_lower = np.clip(new_lower, 0, np.clip(new_upper, 1, 200) - 1)
|
||||
new_upper = np.clip(new_upper, new_lower + 1, 200)
|
||||
assert new_lower >= 0, new_upper - 1
|
||||
|
||||
self._palette_splits[palette_to_mutate] = (new_lower, new_upper)
|
||||
self._palette_mutate_idx = palette_to_mutate
|
||||
self._palette_mutate_delta = (palette_lower_delta, palette_upper_delta)
|
||||
|
||||
def _reassign_unused_palettes(self, new_line_to_palette, last_good_splits):
|
||||
def _reassign_unused_palettes(self, new_line_to_palette):
|
||||
palettes_used = [False] * 16
|
||||
for palette in new_line_to_palette:
|
||||
palettes_used[palette] = True
|
||||
|
@ -421,7 +410,8 @@ def main():
|
|||
|
||||
# TODO: flags
|
||||
penalty = 1 # 1e18 # TODO: is this needed any more?
|
||||
iterations = 10 # 20
|
||||
inner_iterations = 10 # 20
|
||||
outer_iterations = 20
|
||||
|
||||
pygame.init()
|
||||
# TODO: for some reason I need to execute this twice - the first time
|
||||
|
@ -441,7 +431,7 @@ def main():
|
|||
seq = 0
|
||||
for (new_total_image_error, output_4bit, line_to_palette,
|
||||
palettes_rgb12_iigs, palettes_linear_rgb) in cluster_palette.iterate(
|
||||
penalty, iterations):
|
||||
penalty, inner_iterations, outer_iterations):
|
||||
|
||||
if total_image_error is not None:
|
||||
print("Improved quality +%f%% (%f)" % (
|
||||
|
@ -486,6 +476,11 @@ def main():
|
|||
canvas.blit(surface, (0, 0))
|
||||
pygame.display.flip()
|
||||
|
||||
# print((palettes_rgb * 255).astype(np.uint8))
|
||||
unique_colours = np.unique(
|
||||
palettes_rgb12_iigs.reshape(-1, 3), axis=0).shape[0]
|
||||
print("%d unique colours" % unique_colours)
|
||||
|
||||
seq += 1
|
||||
# Save Double hi-res image
|
||||
outfile = os.path.join(
|
||||
|
@ -498,11 +493,6 @@ def main():
|
|||
with open(args.output, "wb") as f:
|
||||
f.write(bytes(screen.memory))
|
||||
|
||||
# print((palettes_rgb * 255).astype(np.uint8))
|
||||
unique_colours = np.unique(
|
||||
palettes_rgb12_iigs.reshape(-1, 3), axis=0).shape[0]
|
||||
print("%d unique colours" % unique_colours)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue