mirror of
https://github.com/robmcmullen/atrcopy.git
synced 2024-05-31 18:41:29 +00:00
Added segment interleaving support
This commit is contained in:
parent
0f209cf50a
commit
ea9466cda7
|
@ -11,7 +11,7 @@ from errors import *
|
||||||
from ataridos import AtariDosDiskImage, AtariDosFile, get_xex
|
from ataridos import AtariDosDiskImage, AtariDosFile, get_xex
|
||||||
from diskimages import AtrHeader, BootDiskImage, add_atr_header
|
from diskimages import AtrHeader, BootDiskImage, add_atr_header
|
||||||
from kboot import KBootImage, add_xexboot_header
|
from kboot import KBootImage, add_xexboot_header
|
||||||
from segments import SegmentData, SegmentSaver, DefaultSegment, EmptySegment, ObjSegment, RawSectorsSegment, user_bit_mask, match_bit_mask, comment_bit_mask, data_bit_mask, selected_bit_mask, diff_bit_mask, not_user_bit_mask
|
from segments import SegmentData, SegmentSaver, DefaultSegment, EmptySegment, ObjSegment, RawSectorsSegment, user_bit_mask, match_bit_mask, comment_bit_mask, data_bit_mask, selected_bit_mask, diff_bit_mask, not_user_bit_mask, interleave_segments
|
||||||
from spartados import SpartaDosDiskImage
|
from spartados import SpartaDosDiskImage
|
||||||
from utils import to_numpy
|
from utils import to_numpy
|
||||||
|
|
||||||
|
|
|
@ -134,6 +134,18 @@ class SegmentData(object):
|
||||||
base_start, base_end = np.byte_bounds(self.data.base)
|
base_start, base_end = np.byte_bounds(self.data.base)
|
||||||
return int(data_start - base_start + i)
|
return int(data_start - base_start + i)
|
||||||
|
|
||||||
|
def get_indexes_from_base(self):
|
||||||
|
"""Get array of indexes from the base array, as if this raw data were
|
||||||
|
indexed.
|
||||||
|
"""
|
||||||
|
if self.is_indexed:
|
||||||
|
return np.copy(self.order[i])
|
||||||
|
if self.data.base is None:
|
||||||
|
i = 0
|
||||||
|
else:
|
||||||
|
i = self.get_raw_index(0)
|
||||||
|
return np.arange(i, i + len(self), dtype=np.uint32)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
if self.is_indexed:
|
if self.is_indexed:
|
||||||
order = self.data.sub_index(index)
|
order = self.data.sub_index(index)
|
||||||
|
@ -146,6 +158,15 @@ class SegmentData(object):
|
||||||
c = self.comments
|
c = self.comments
|
||||||
return SegmentData(d, s, c, order=order)
|
return SegmentData(d, s, c, order=order)
|
||||||
|
|
||||||
|
def get_bases(self):
|
||||||
|
if self.data.base is None:
|
||||||
|
data_base = self.data
|
||||||
|
style_base = self.style
|
||||||
|
else:
|
||||||
|
data_base = self.data.base
|
||||||
|
style_base = self.style.base
|
||||||
|
return data_base, style_base
|
||||||
|
|
||||||
def get_indexed(self, index):
|
def get_indexed(self, index):
|
||||||
index = to_numpy_list(index)
|
index = to_numpy_list(index)
|
||||||
if self.is_indexed:
|
if self.is_indexed:
|
||||||
|
@ -156,12 +177,7 @@ class SegmentData(object):
|
||||||
|
|
||||||
# index needs to be relative to the base array
|
# index needs to be relative to the base array
|
||||||
base_index = index + self.get_raw_index(0)
|
base_index = index + self.get_raw_index(0)
|
||||||
if self.data.base is None:
|
data_base, style_base = self.get_bases()
|
||||||
data_base = self.data
|
|
||||||
style_base = self.style
|
|
||||||
else:
|
|
||||||
data_base = self.data.base
|
|
||||||
style_base = self.style.base
|
|
||||||
return SegmentData(data_base, style_base, self.comments, order=base_index)
|
return SegmentData(data_base, style_base, self.comments, order=base_index)
|
||||||
|
|
||||||
def get_reverse_index(self, base_index):
|
def get_reverse_index(self, base_index):
|
||||||
|
@ -587,3 +603,33 @@ class RawSectorsSegment(DefaultSegment):
|
||||||
if lower_case:
|
if lower_case:
|
||||||
return "s%03d:%02x" % (sector + self.first_sector, byte)
|
return "s%03d:%02x" % (sector + self.first_sector, byte)
|
||||||
return "s%03d:%02X" % (sector + self.first_sector, byte)
|
return "s%03d:%02X" % (sector + self.first_sector, byte)
|
||||||
|
|
||||||
|
def interleave_indexes(segments, num_bytes):
|
||||||
|
num_segments = len(segments)
|
||||||
|
size = len(segments[0])
|
||||||
|
for s in segments[1:]:
|
||||||
|
if size != len(s):
|
||||||
|
raise ValueError("All segments to interleave must be the same size")
|
||||||
|
_, rem = divmod(size, num_bytes)
|
||||||
|
if rem != 0:
|
||||||
|
raise ValueError("Segment size must be a multiple of the byte interleave")
|
||||||
|
interleave = np.empty(size * num_segments, dtype=np.uint32)
|
||||||
|
factor = num_bytes * num_segments
|
||||||
|
start = 0
|
||||||
|
for s in segments:
|
||||||
|
order = s.rawdata.get_indexes_from_base()
|
||||||
|
for i in range(num_bytes):
|
||||||
|
interleave[start::factor] = order[i::num_bytes]
|
||||||
|
start += 1
|
||||||
|
return interleave
|
||||||
|
|
||||||
|
def interleave_segments(segments, num_bytes):
|
||||||
|
new_index = interleave_indexes(segments, num_bytes)
|
||||||
|
data_base, style_base = segments[0].rawdata.get_bases()
|
||||||
|
for s in segments[1:]:
|
||||||
|
d, s = s.rawdata.get_bases()
|
||||||
|
if id(d) != id(data_base) or id(s) != id(style_base):
|
||||||
|
raise ValueError("Can't interleave segments with different base arrays")
|
||||||
|
raw = SegmentData(data_base, style_base, segments[0].rawdata.comments, order=new_index)
|
||||||
|
segment = DefaultSegment(raw, 0)
|
||||||
|
return segment
|
||||||
|
|
|
@ -3,9 +3,15 @@ import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from atrcopy import DefaultSegment, SegmentData, get_xex
|
from atrcopy import DefaultSegment, SegmentData, get_xex, interleave_segments
|
||||||
|
|
||||||
|
|
||||||
|
def get_indexed(segment, num, scale):
|
||||||
|
indexes = np.arange(num) * scale
|
||||||
|
raw = segment.rawdata.get_indexed(indexes)
|
||||||
|
s = DefaultSegment(raw, segment.start_addr + indexes[0])
|
||||||
|
return s, indexes
|
||||||
|
|
||||||
class TestSegment1(object):
|
class TestSegment1(object):
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.segments = []
|
self.segments = []
|
||||||
|
@ -37,21 +43,15 @@ class TestIndexed(object):
|
||||||
r = SegmentData(data)
|
r = SegmentData(data)
|
||||||
self.segment = DefaultSegment(r, 0)
|
self.segment = DefaultSegment(r, 0)
|
||||||
|
|
||||||
def get_indexed(self, segment, num, scale):
|
|
||||||
indexes = np.arange(num) * scale
|
|
||||||
raw = segment.rawdata.get_indexed(indexes)
|
|
||||||
s = DefaultSegment(raw, segment.start_addr + indexes[0])
|
|
||||||
return s, indexes
|
|
||||||
|
|
||||||
def test_indexed(self):
|
def test_indexed(self):
|
||||||
assert not self.segment.rawdata.is_indexed
|
assert not self.segment.rawdata.is_indexed
|
||||||
s, indexes = self.get_indexed(self.segment, 1024, 3)
|
s, indexes = get_indexed(self.segment, 1024, 3)
|
||||||
assert s.rawdata.is_indexed
|
assert s.rawdata.is_indexed
|
||||||
for i in range(len(indexes)):
|
for i in range(len(indexes)):
|
||||||
assert s.get_raw_index(i) == indexes[i]
|
assert s.get_raw_index(i) == indexes[i]
|
||||||
|
|
||||||
# get indexed into indexed, will result in every 9th byte
|
# get indexed into indexed, will result in every 9th byte
|
||||||
s2, indexes2 = self.get_indexed(s, 256, 3)
|
s2, indexes2 = get_indexed(s, 256, 3)
|
||||||
assert s2.rawdata.is_indexed
|
assert s2.rawdata.is_indexed
|
||||||
for i in range(len(indexes2)):
|
for i in range(len(indexes2)):
|
||||||
assert s2.get_raw_index(i) == indexes2[i] * 3
|
assert s2.get_raw_index(i) == indexes2[i] * 3
|
||||||
|
@ -73,10 +73,10 @@ class TestIndexed(object):
|
||||||
|
|
||||||
with pytest.raises(IndexError) as e:
|
with pytest.raises(IndexError) as e:
|
||||||
# attempt to get indexes to 1024 * 3... Index to big => fail!
|
# attempt to get indexes to 1024 * 3... Index to big => fail!
|
||||||
s, indexes = self.get_indexed(sub, 1024, 3)
|
s, indexes = get_indexed(sub, 1024, 3)
|
||||||
|
|
||||||
# try with elements up to 256 * 3
|
# try with elements up to 256 * 3
|
||||||
s, indexes = self.get_indexed(sub, 256, 3)
|
s, indexes = get_indexed(sub, 256, 3)
|
||||||
print sub.data
|
print sub.data
|
||||||
print indexes
|
print indexes
|
||||||
print s.data[:]
|
print s.data[:]
|
||||||
|
@ -91,7 +91,7 @@ class TestIndexed(object):
|
||||||
assert end == len(base)
|
assert end == len(base)
|
||||||
|
|
||||||
# get indexed into indexed, will result in every 9th byte
|
# get indexed into indexed, will result in every 9th byte
|
||||||
s2, indexes2 = self.get_indexed(s, 64, 3)
|
s2, indexes2 = get_indexed(s, 64, 3)
|
||||||
assert s2.rawdata.is_indexed
|
assert s2.rawdata.is_indexed
|
||||||
for i in range(len(indexes2)):
|
for i in range(len(indexes2)):
|
||||||
assert s2.get_raw_index(i) == sub.start_addr + indexes2[i] * 3
|
assert s2.get_raw_index(i) == sub.start_addr + indexes2[i] * 3
|
||||||
|
@ -99,9 +99,65 @@ class TestIndexed(object):
|
||||||
assert start == 0
|
assert start == 0
|
||||||
assert end == len(base)
|
assert end == len(base)
|
||||||
|
|
||||||
|
def test_interleave(self):
|
||||||
|
base = self.segment
|
||||||
|
r1 = base.rawdata[512:1024] # 512 byte segment
|
||||||
|
s1 = DefaultSegment(r1, 512)
|
||||||
|
r2 = base.rawdata[1024:1536] # 512 byte segment
|
||||||
|
s2 = DefaultSegment(r2, 1024)
|
||||||
|
|
||||||
|
indexes1 = r1.get_indexes_from_base()
|
||||||
|
verify1 = np.arange(512, 1024, dtype=np.uint32)
|
||||||
|
assert np.array_equal(indexes1, verify1)
|
||||||
|
|
||||||
|
indexes2 = r2.get_indexes_from_base()
|
||||||
|
verify2 = np.arange(1024, 1536, dtype=np.uint32)
|
||||||
|
assert np.array_equal(indexes2, verify2)
|
||||||
|
|
||||||
|
s = interleave_segments([s1, s2], 2)
|
||||||
|
a = np.empty(len(s1) + len(s2), dtype=np.uint8)
|
||||||
|
a[0::4] = s1[0::2]
|
||||||
|
a[1::4] = s1[1::2]
|
||||||
|
a[2::4] = s2[0::2]
|
||||||
|
a[3::4] = s2[1::2]
|
||||||
|
print list(s[:])
|
||||||
|
print list(a[:])
|
||||||
|
print s.rawdata.order
|
||||||
|
assert np.array_equal(s[:], a)
|
||||||
|
|
||||||
|
s = interleave_segments([s1, s2], 4)
|
||||||
|
a = np.empty(len(s1) + len(s2), dtype=np.uint8)
|
||||||
|
a[0::8] = s1[0::4]
|
||||||
|
a[1::8] = s1[1::4]
|
||||||
|
a[2::8] = s1[2::4]
|
||||||
|
a[3::8] = s1[3::4]
|
||||||
|
a[4::8] = s2[0::4]
|
||||||
|
a[5::8] = s2[1::4]
|
||||||
|
a[6::8] = s2[2::4]
|
||||||
|
a[7::8] = s2[3::4]
|
||||||
|
assert np.array_equal(s[:], a)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as e:
|
||||||
|
s = interleave_segments([s1, s2], 3)
|
||||||
|
|
||||||
|
r1 = base.rawdata[512:1025] # 513 byte segment
|
||||||
|
s1 = DefaultSegment(r1, 512)
|
||||||
|
r2 = base.rawdata[1024:1537] # 513 byte segment
|
||||||
|
s2 = DefaultSegment(r2, 1024)
|
||||||
|
s = interleave_segments([s1, s2], 3)
|
||||||
|
a = np.empty(len(s1) + len(s2), dtype=np.uint8)
|
||||||
|
a[0::6] = s1[0::3]
|
||||||
|
a[1::6] = s1[1::3]
|
||||||
|
a[2::6] = s1[2::3]
|
||||||
|
a[3::6] = s2[0::3]
|
||||||
|
a[4::6] = s2[1::3]
|
||||||
|
a[5::6] = s2[2::3]
|
||||||
|
assert np.array_equal(s[:], a)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
t = TestIndexed()
|
t = TestIndexed()
|
||||||
t.setup()
|
t.setup()
|
||||||
t.test_indexed()
|
t.test_indexed()
|
||||||
t.test_indexed_sub()
|
t.test_indexed_sub()
|
||||||
|
t.test_interleave()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user