diff --git a/atrcopy/segments.py b/atrcopy/segments.py index ace40df..d767064 100644 --- a/atrcopy/segments.py +++ b/atrcopy/segments.py @@ -156,6 +156,14 @@ class SegmentData(object): def stringio(self): buf = cStringIO.StringIO(self.data[:]) return buf + + @property + def data_base(self): + return self.data.base if self.data.base is not None else self.data + + @property + def style_base(self): + return self.style.base if self.style.base is not None else self.style def get_data(self): return self.data @@ -227,10 +235,17 @@ class SegmentData(object): d = self.data.np_data.copy() s = self.style.np_data.copy() copy = SegmentData(d, s, order=self.order) - else: + elif self.data.base is None: + # if there is no base array, we aren't looking at a slice so we + # must be copying the entire array. d = self.data.copy() s = self.style.copy() + copy = SegmentData(d, s) + else: + d = self.data.base.copy() + s = self.style.base.copy() start, end = self.byte_bounds_offset() + print "copy: start, end =", start, end copy = SegmentData(d[start:end], s[start:end]) return copy @@ -480,6 +495,17 @@ class DefaultSegment(object): matches = (self.style & style_bits) == style_bits return self.bool_to_ranges(matches) + def get_comment_locations(self, **kwargs): + style_bits = self.get_style_bits(**kwargs) + r = self.rawdata.copy() + print len(r.style) + print len(r.style_base) + base = r.style_base & style_bits + comment_indexes = np.asarray(self.rawdata.extra.comments.keys(), dtype=np.uint32) + print comment_indexes + base[comment_indexes] |= comment_bit_mask + return r.style + def get_entire_style_ranges(self, split_comments=None, **kwargs): """Find sections of the segment that have the same style value. @@ -492,32 +518,41 @@ class DefaultSegment(object): tuple; and an integer with the style value. """ style_bits = self.get_style_bits(**kwargs) - matches = self.style & style_bits - if split_comments is None: - split_comments = set() - else: - split_comments = set(split_comments) + matches = self.get_comment_locations(**kwargs) groups = np.split(matches, np.where(np.diff(matches) != 0)[0] + 1) + # print groups # split into groups with the same numbers ranges = [] last_end = 0 if len(groups) == 1 and len(groups[0]) == 0: # check for degenerate case return + last_style = -1 for group in groups: - next_end = last_end + len(group) + # each group is guaranteed to have the same style + size = len(group) + next_end = last_end + size style = matches[last_end] - if style in split_comments: - comment_list = self.get_comments_in_range(last_end, next_end) - for index in sorted(comment_list.keys()): - if last_end == index: - # skip if the comment is at the start point because it - # will always be split at the start point - continue - ranges.append(((last_end, index), style)) - last_end = index - if last_end < next_end: - ranges.append(((last_end, next_end), style)) + masked_style = style & style_bits + # print last_end, next_end, style, masked_style, size, group + if style & comment_bit_mask: + if masked_style in split_comments: + # print "interesting comment", last_end, next_end + ranges.append(((last_end, next_end), masked_style)) + else: + # print "non-interesting comment", last_end, next_end + if last_style == masked_style: + ((prev_end, _), _) = ranges.pop() + ranges.append(((prev_end, next_end), masked_style)) + else: + ranges.append(((last_end, next_end), masked_style)) + else: + if last_style == masked_style: + ((prev_end, _), _) = ranges.pop() + ranges.append(((prev_end, next_end), masked_style)) + else: + ranges.append(((last_end, next_end), masked_style)) + last_style = masked_style last_end = next_end return ranges diff --git a/test/test_segment.py b/test/test_segment.py index fc71289..f173167 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -3,7 +3,7 @@ import os import numpy as np import pytest -from atrcopy import DefaultSegment, SegmentData, get_xex, interleave_segments +from atrcopy import DefaultSegment, SegmentData, get_xex, interleave_segments, user_bit_mask def get_indexed(segment, num, scale): @@ -181,6 +181,87 @@ class TestIndexed(object): assert not np.all((c.data[:] - s.data[:]) == 0) +class TestComments(object): + def setup(self): + data = np.ones([4000], dtype=np.uint8) + r = SegmentData(data) + self.segment = DefaultSegment(r, 0) + self.sub_segment = DefaultSegment(r[2:202], 2) + + def test_locations(self): + s = self.segment + s.set_comment([[4,5]], "test1") + s.set_comment([[40,50]], "test2") + s.set_style_ranges([[2,100]], comment=True) + s.set_style_ranges([[200, 299]], data=True) + for i in range(1,4): + for j in range(1, 4): + # create some with overlapping regions, some without + r = [500*j, 500*j + 200*i + 200] + s.set_style_ranges([r], user=i) + s.set_user_data([r], i, i*10 + j) + r = [100, 200] + s.set_style_ranges([r], user=4) + s.set_user_data([r], 4, 99) + r = [3100, 3200] + s.set_style_ranges([r], user=4) + s.set_user_data([r], 4, 99) + + s2 = self.sub_segment + print len(s2) + copy = s2.get_comment_locations() + print copy + # comments at 4 and 40 in the original means 2 and 38 in the copy + orig = s.get_comment_locations() + assert copy[2] == orig[4] + assert copy[28] == orig[38] + + def test_split_data_at_comment(self): + s = self.segment + s.set_style_ranges([[0,1000]], data=True) + for i in range(0, len(s), 25): + s.set_comment([[i,i+1]], "comment at %d" % i) + + s2 = self.sub_segment + print len(s2) + copy = s2.get_comment_locations() + print copy + # comments at 4 and 40 in the original means 2 and 38 in the copy + orig = s.get_comment_locations() + print orig[0:200] + assert copy[2] == orig[4] + assert copy[28] == orig[38] + + r = s2.get_entire_style_ranges([1], user=True) + print r + assert r == [((0, 23), 1), ((23, 48), 1), ((48, 73), 1), ((73, 98), 1), ((98, 123), 1), ((123, 148), 1), ((148, 173), 1), ((173, 198), 1), ((198, 200), 1)] + + def test_split_data_at_comment2(self): + s = self.segment + start = 0 + i = 0 + for end in range(40, 1000, 40): + s.set_style_ranges([[start, end]], user=i) + start = end + i = (i + 1) % 8 + for i in range(0, len(s), 25): + s.set_comment([[i,i+1]], "comment at %d" % i) + + s2 = self.sub_segment + print len(s2) + copy = s2.get_comment_locations() + print copy + # comments at 4 and 40 in the original means 2 and 38 in the copy + orig = s.get_comment_locations() + print orig[0:200] + assert copy[2] == orig[4] + assert copy[28] == orig[38] + + r = s2.get_entire_style_ranges([1], user=user_bit_mask) + print r + assert r == [((0, 38), 0), ((38, 48), 1), ((48, 73), 1), ((73, 78), 1), ((78, 118), 2), ((118, 158), 3), ((158, 198), 4), ((198, 200), 5)] + + if __name__ == "__main__": t = TestIndexed() t.setup() @@ -191,3 +272,6 @@ if __name__ == "__main__": t.setup() t.test_xex() t.test_copy() + t = TestComments() + t.setup() + t.test_split_data_at_comment()