From ea9466cda7b2fd58b7c5dcf06d9ec5b7db8eaf6c Mon Sep 17 00:00:00 2001 From: Rob McMullen Date: Tue, 17 May 2016 14:53:52 -0700 Subject: [PATCH] Added segment interleaving support --- atrcopy/__init__.py | 2 +- atrcopy/segments.py | 58 ++++++++++++++++++++++++++++---- test/test_segment.py | 80 +++++++++++++++++++++++++++++++++++++------- 3 files changed, 121 insertions(+), 19 deletions(-) diff --git a/atrcopy/__init__.py b/atrcopy/__init__.py index c26e7f5..92536dd 100644 --- a/atrcopy/__init__.py +++ b/atrcopy/__init__.py @@ -11,7 +11,7 @@ from errors import * from ataridos import AtariDosDiskImage, AtariDosFile, get_xex from diskimages import AtrHeader, BootDiskImage, add_atr_header from kboot import KBootImage, add_xexboot_header -from segments import SegmentData, SegmentSaver, DefaultSegment, EmptySegment, ObjSegment, RawSectorsSegment, user_bit_mask, match_bit_mask, comment_bit_mask, data_bit_mask, selected_bit_mask, diff_bit_mask, not_user_bit_mask +from segments import SegmentData, SegmentSaver, DefaultSegment, EmptySegment, ObjSegment, RawSectorsSegment, user_bit_mask, match_bit_mask, comment_bit_mask, data_bit_mask, selected_bit_mask, diff_bit_mask, not_user_bit_mask, interleave_segments from spartados import SpartaDosDiskImage from utils import to_numpy diff --git a/atrcopy/segments.py b/atrcopy/segments.py index 160855e..7ba2cf8 100644 --- a/atrcopy/segments.py +++ b/atrcopy/segments.py @@ -134,6 +134,18 @@ class SegmentData(object): base_start, base_end = np.byte_bounds(self.data.base) return int(data_start - base_start + i) + def get_indexes_from_base(self): + """Get array of indexes from the base array, as if this raw data were + indexed. + """ + if self.is_indexed: + return np.copy(self.order[i]) + if self.data.base is None: + i = 0 + else: + i = self.get_raw_index(0) + return np.arange(i, i + len(self), dtype=np.uint32) + def __getitem__(self, index): if self.is_indexed: order = self.data.sub_index(index) @@ -146,6 +158,15 @@ class SegmentData(object): c = self.comments return SegmentData(d, s, c, order=order) + def get_bases(self): + if self.data.base is None: + data_base = self.data + style_base = self.style + else: + data_base = self.data.base + style_base = self.style.base + return data_base, style_base + def get_indexed(self, index): index = to_numpy_list(index) if self.is_indexed: @@ -156,12 +177,7 @@ class SegmentData(object): # index needs to be relative to the base array base_index = index + self.get_raw_index(0) - if self.data.base is None: - data_base = self.data - style_base = self.style - else: - data_base = self.data.base - style_base = self.style.base + data_base, style_base = self.get_bases() return SegmentData(data_base, style_base, self.comments, order=base_index) def get_reverse_index(self, base_index): @@ -587,3 +603,33 @@ class RawSectorsSegment(DefaultSegment): if lower_case: return "s%03d:%02x" % (sector + self.first_sector, byte) return "s%03d:%02X" % (sector + self.first_sector, byte) + +def interleave_indexes(segments, num_bytes): + num_segments = len(segments) + size = len(segments[0]) + for s in segments[1:]: + if size != len(s): + raise ValueError("All segments to interleave must be the same size") + _, rem = divmod(size, num_bytes) + if rem != 0: + raise ValueError("Segment size must be a multiple of the byte interleave") + interleave = np.empty(size * num_segments, dtype=np.uint32) + factor = num_bytes * num_segments + start = 0 + for s in segments: + order = s.rawdata.get_indexes_from_base() + for i in range(num_bytes): + interleave[start::factor] = order[i::num_bytes] + start += 1 + return interleave + +def interleave_segments(segments, num_bytes): + new_index = interleave_indexes(segments, num_bytes) + data_base, style_base = segments[0].rawdata.get_bases() + for s in segments[1:]: + d, s = s.rawdata.get_bases() + if id(d) != id(data_base) or id(s) != id(style_base): + raise ValueError("Can't interleave segments with different base arrays") + raw = SegmentData(data_base, style_base, segments[0].rawdata.comments, order=new_index) + segment = DefaultSegment(raw, 0) + return segment diff --git a/test/test_segment.py b/test/test_segment.py index fe33569..119f7f3 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -3,9 +3,15 @@ import os import numpy as np import pytest -from atrcopy import DefaultSegment, SegmentData, get_xex +from atrcopy import DefaultSegment, SegmentData, get_xex, interleave_segments +def get_indexed(segment, num, scale): + indexes = np.arange(num) * scale + raw = segment.rawdata.get_indexed(indexes) + s = DefaultSegment(raw, segment.start_addr + indexes[0]) + return s, indexes + class TestSegment1(object): def setup(self): self.segments = [] @@ -37,21 +43,15 @@ class TestIndexed(object): r = SegmentData(data) self.segment = DefaultSegment(r, 0) - def get_indexed(self, segment, num, scale): - indexes = np.arange(num) * scale - raw = segment.rawdata.get_indexed(indexes) - s = DefaultSegment(raw, segment.start_addr + indexes[0]) - return s, indexes - def test_indexed(self): assert not self.segment.rawdata.is_indexed - s, indexes = self.get_indexed(self.segment, 1024, 3) + s, indexes = get_indexed(self.segment, 1024, 3) assert s.rawdata.is_indexed for i in range(len(indexes)): assert s.get_raw_index(i) == indexes[i] # get indexed into indexed, will result in every 9th byte - s2, indexes2 = self.get_indexed(s, 256, 3) + s2, indexes2 = get_indexed(s, 256, 3) assert s2.rawdata.is_indexed for i in range(len(indexes2)): assert s2.get_raw_index(i) == indexes2[i] * 3 @@ -73,10 +73,10 @@ class TestIndexed(object): with pytest.raises(IndexError) as e: # attempt to get indexes to 1024 * 3... Index to big => fail! - s, indexes = self.get_indexed(sub, 1024, 3) + s, indexes = get_indexed(sub, 1024, 3) # try with elements up to 256 * 3 - s, indexes = self.get_indexed(sub, 256, 3) + s, indexes = get_indexed(sub, 256, 3) print sub.data print indexes print s.data[:] @@ -91,7 +91,7 @@ class TestIndexed(object): assert end == len(base) # get indexed into indexed, will result in every 9th byte - s2, indexes2 = self.get_indexed(s, 64, 3) + s2, indexes2 = get_indexed(s, 64, 3) assert s2.rawdata.is_indexed for i in range(len(indexes2)): assert s2.get_raw_index(i) == sub.start_addr + indexes2[i] * 3 @@ -99,9 +99,65 @@ class TestIndexed(object): assert start == 0 assert end == len(base) + def test_interleave(self): + base = self.segment + r1 = base.rawdata[512:1024] # 512 byte segment + s1 = DefaultSegment(r1, 512) + r2 = base.rawdata[1024:1536] # 512 byte segment + s2 = DefaultSegment(r2, 1024) + + indexes1 = r1.get_indexes_from_base() + verify1 = np.arange(512, 1024, dtype=np.uint32) + assert np.array_equal(indexes1, verify1) + + indexes2 = r2.get_indexes_from_base() + verify2 = np.arange(1024, 1536, dtype=np.uint32) + assert np.array_equal(indexes2, verify2) + + s = interleave_segments([s1, s2], 2) + a = np.empty(len(s1) + len(s2), dtype=np.uint8) + a[0::4] = s1[0::2] + a[1::4] = s1[1::2] + a[2::4] = s2[0::2] + a[3::4] = s2[1::2] + print list(s[:]) + print list(a[:]) + print s.rawdata.order + assert np.array_equal(s[:], a) + + s = interleave_segments([s1, s2], 4) + a = np.empty(len(s1) + len(s2), dtype=np.uint8) + a[0::8] = s1[0::4] + a[1::8] = s1[1::4] + a[2::8] = s1[2::4] + a[3::8] = s1[3::4] + a[4::8] = s2[0::4] + a[5::8] = s2[1::4] + a[6::8] = s2[2::4] + a[7::8] = s2[3::4] + assert np.array_equal(s[:], a) + + with pytest.raises(ValueError) as e: + s = interleave_segments([s1, s2], 3) + + r1 = base.rawdata[512:1025] # 513 byte segment + s1 = DefaultSegment(r1, 512) + r2 = base.rawdata[1024:1537] # 513 byte segment + s2 = DefaultSegment(r2, 1024) + s = interleave_segments([s1, s2], 3) + a = np.empty(len(s1) + len(s2), dtype=np.uint8) + a[0::6] = s1[0::3] + a[1::6] = s1[1::3] + a[2::6] = s1[2::3] + a[3::6] = s2[0::3] + a[4::6] = s2[1::3] + a[5::6] = s2[2::3] + assert np.array_equal(s[:], a) + if __name__ == "__main__": t = TestIndexed() t.setup() t.test_indexed() t.test_indexed_sub() + t.test_interleave()