mirror of
https://github.com/robmcmullen/atrcopy.git
synced 2024-11-25 16:32:07 +00:00
Fixed indexing of sub-arrays
This commit is contained in:
parent
4c4da4bf38
commit
0f209cf50a
@ -90,6 +90,9 @@ class SegmentData(object):
|
||||
self.comments = comments
|
||||
self.reverse_index_mapping = None
|
||||
|
||||
def __str__(self):
|
||||
return "SegmentData id=%x indexed=%s data=%s" % (id(self), self.is_indexed, type(self.data))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
@ -124,7 +127,7 @@ class SegmentData(object):
|
||||
segment
|
||||
"""
|
||||
if self.is_indexed:
|
||||
i = self.order[i]
|
||||
return int(self.order[i])
|
||||
if self.data.base is None:
|
||||
return int(i)
|
||||
data_start, data_end = np.byte_bounds(self.data)
|
||||
@ -147,7 +150,19 @@ class SegmentData(object):
|
||||
index = to_numpy_list(index)
|
||||
if self.is_indexed:
|
||||
return self[index]
|
||||
return SegmentData(self.data, self.style, self.comments, order=index)
|
||||
|
||||
# check to make sure all indexes are valid, raises IndexError if not
|
||||
check = self.data[index]
|
||||
|
||||
# index needs to be relative to the base array
|
||||
base_index = index + self.get_raw_index(0)
|
||||
if self.data.base is None:
|
||||
data_base = self.data
|
||||
style_base = self.style
|
||||
else:
|
||||
data_base = self.data.base
|
||||
style_base = self.style.base
|
||||
return SegmentData(data_base, style_base, self.comments, order=base_index)
|
||||
|
||||
def get_reverse_index(self, base_index):
|
||||
"""Get index into this segment's data given the index into the base data
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from atrcopy import DefaultSegment, SegmentData, get_xex
|
||||
|
||||
@ -28,7 +29,79 @@ class TestSegment1(object):
|
||||
size = reduce(lambda a, b:a + 4 + len(b), s, 0)
|
||||
assert len(bytes) == 2 + 6 + size
|
||||
|
||||
|
||||
class TestIndexed(object):
|
||||
def setup(self):
|
||||
data = np.arange(4096, dtype=np.uint8)
|
||||
data[1::2] = np.repeat(np.arange(16, dtype=np.uint8), 128)
|
||||
r = SegmentData(data)
|
||||
self.segment = DefaultSegment(r, 0)
|
||||
|
||||
def get_indexed(self, segment, num, scale):
|
||||
indexes = np.arange(num) * scale
|
||||
raw = segment.rawdata.get_indexed(indexes)
|
||||
s = DefaultSegment(raw, segment.start_addr + indexes[0])
|
||||
return s, indexes
|
||||
|
||||
def test_indexed(self):
|
||||
assert not self.segment.rawdata.is_indexed
|
||||
s, indexes = self.get_indexed(self.segment, 1024, 3)
|
||||
assert s.rawdata.is_indexed
|
||||
for i in range(len(indexes)):
|
||||
assert s.get_raw_index(i) == indexes[i]
|
||||
|
||||
# get indexed into indexed, will result in every 9th byte
|
||||
s2, indexes2 = self.get_indexed(s, 256, 3)
|
||||
assert s2.rawdata.is_indexed
|
||||
for i in range(len(indexes2)):
|
||||
assert s2.get_raw_index(i) == indexes2[i] * 3
|
||||
|
||||
def test_indexed_sub(self):
|
||||
base = self.segment
|
||||
assert not base.rawdata.is_indexed
|
||||
raw = base.rawdata[512:1536] # 1024 byte segment
|
||||
sub = DefaultSegment(raw, 512)
|
||||
|
||||
assert not sub.rawdata.is_indexed
|
||||
for i in range(len(sub)):
|
||||
ri = sub.get_raw_index(i)
|
||||
assert ri == sub.start_addr + i
|
||||
assert sub[i] == base[ri]
|
||||
start, end = sub.byte_bounds_offset()
|
||||
assert start == 512
|
||||
assert end == 1536
|
||||
|
||||
with pytest.raises(IndexError) as e:
|
||||
# attempt to get indexes to 1024 * 3... Index to big => fail!
|
||||
s, indexes = self.get_indexed(sub, 1024, 3)
|
||||
|
||||
# try with elements up to 256 * 3
|
||||
s, indexes = self.get_indexed(sub, 256, 3)
|
||||
print sub.data
|
||||
print indexes
|
||||
print s.data[:]
|
||||
assert s.rawdata.is_indexed
|
||||
for i in range(len(indexes)):
|
||||
ri = s.get_raw_index(i)
|
||||
print ri, "base[ri]=%d" % base[ri], i, indexes[i], "s[i]=%d" % s[i]
|
||||
assert ri == sub.start_addr + indexes[i]
|
||||
assert s[i] == base[ri]
|
||||
start, end = s.byte_bounds_offset()
|
||||
assert start == 0
|
||||
assert end == len(base)
|
||||
|
||||
# get indexed into indexed, will result in every 9th byte
|
||||
s2, indexes2 = self.get_indexed(s, 64, 3)
|
||||
assert s2.rawdata.is_indexed
|
||||
for i in range(len(indexes2)):
|
||||
assert s2.get_raw_index(i) == sub.start_addr + indexes2[i] * 3
|
||||
start, end = s.byte_bounds_offset()
|
||||
assert start == 0
|
||||
assert end == len(base)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
t = TestSegment1()
|
||||
t = TestIndexed()
|
||||
t.setup()
|
||||
t.test_xex()
|
||||
t.test_indexed()
|
||||
t.test_indexed_sub()
|
||||
|
Loading…
Reference in New Issue
Block a user