From 0f209cf50a7458e3db36af3008fedd8a2a0b6410 Mon Sep 17 00:00:00 2001 From: Rob McMullen Date: Mon, 9 May 2016 23:45:10 -0700 Subject: [PATCH] Fixed indexing of sub-arrays --- atrcopy/segments.py | 19 +++++++++-- test/test_segment.py | 77 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 92 insertions(+), 4 deletions(-) diff --git a/atrcopy/segments.py b/atrcopy/segments.py index 7cbb35d..160855e 100644 --- a/atrcopy/segments.py +++ b/atrcopy/segments.py @@ -90,6 +90,9 @@ class SegmentData(object): self.comments = comments self.reverse_index_mapping = None + def __str__(self): + return "SegmentData id=%x indexed=%s data=%s" % (id(self), self.is_indexed, type(self.data)) + def __len__(self): return len(self.data) @@ -124,7 +127,7 @@ class SegmentData(object): segment """ if self.is_indexed: - i = self.order[i] + return int(self.order[i]) if self.data.base is None: return int(i) data_start, data_end = np.byte_bounds(self.data) @@ -147,7 +150,19 @@ class SegmentData(object): index = to_numpy_list(index) if self.is_indexed: return self[index] - return SegmentData(self.data, self.style, self.comments, order=index) + + # check to make sure all indexes are valid, raises IndexError if not + check = self.data[index] + + # index needs to be relative to the base array + base_index = index + self.get_raw_index(0) + if self.data.base is None: + data_base = self.data + style_base = self.style + else: + data_base = self.data.base + style_base = self.style.base + return SegmentData(data_base, style_base, self.comments, order=base_index) def get_reverse_index(self, base_index): """Get index into this segment's data given the index into the base data diff --git a/test/test_segment.py b/test/test_segment.py index 1439833..fe33569 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -1,6 +1,7 @@ import os import numpy as np +import pytest from atrcopy import DefaultSegment, SegmentData, get_xex @@ -28,7 +29,79 @@ class TestSegment1(object): size = reduce(lambda a, b:a + 4 + len(b), s, 0) assert len(bytes) == 2 + 6 + size + +class TestIndexed(object): + def setup(self): + data = np.arange(4096, dtype=np.uint8) + data[1::2] = np.repeat(np.arange(16, dtype=np.uint8), 128) + r = SegmentData(data) + self.segment = DefaultSegment(r, 0) + + def get_indexed(self, segment, num, scale): + indexes = np.arange(num) * scale + raw = segment.rawdata.get_indexed(indexes) + s = DefaultSegment(raw, segment.start_addr + indexes[0]) + return s, indexes + + def test_indexed(self): + assert not self.segment.rawdata.is_indexed + s, indexes = self.get_indexed(self.segment, 1024, 3) + assert s.rawdata.is_indexed + for i in range(len(indexes)): + assert s.get_raw_index(i) == indexes[i] + + # get indexed into indexed, will result in every 9th byte + s2, indexes2 = self.get_indexed(s, 256, 3) + assert s2.rawdata.is_indexed + for i in range(len(indexes2)): + assert s2.get_raw_index(i) == indexes2[i] * 3 + + def test_indexed_sub(self): + base = self.segment + assert not base.rawdata.is_indexed + raw = base.rawdata[512:1536] # 1024 byte segment + sub = DefaultSegment(raw, 512) + + assert not sub.rawdata.is_indexed + for i in range(len(sub)): + ri = sub.get_raw_index(i) + assert ri == sub.start_addr + i + assert sub[i] == base[ri] + start, end = sub.byte_bounds_offset() + assert start == 512 + assert end == 1536 + + with pytest.raises(IndexError) as e: + # attempt to get indexes to 1024 * 3... Index to big => fail! + s, indexes = self.get_indexed(sub, 1024, 3) + + # try with elements up to 256 * 3 + s, indexes = self.get_indexed(sub, 256, 3) + print sub.data + print indexes + print s.data[:] + assert s.rawdata.is_indexed + for i in range(len(indexes)): + ri = s.get_raw_index(i) + print ri, "base[ri]=%d" % base[ri], i, indexes[i], "s[i]=%d" % s[i] + assert ri == sub.start_addr + indexes[i] + assert s[i] == base[ri] + start, end = s.byte_bounds_offset() + assert start == 0 + assert end == len(base) + + # get indexed into indexed, will result in every 9th byte + s2, indexes2 = self.get_indexed(s, 64, 3) + assert s2.rawdata.is_indexed + for i in range(len(indexes2)): + assert s2.get_raw_index(i) == sub.start_addr + indexes2[i] * 3 + start, end = s.byte_bounds_offset() + assert start == 0 + assert end == len(base) + + if __name__ == "__main__": - t = TestSegment1() + t = TestIndexed() t.setup() - t.test_xex() + t.test_indexed() + t.test_indexed_sub()