diff --git a/atrcopy/segments.py b/atrcopy/segments.py index 48d341f..0dcfce4 100644 --- a/atrcopy/segments.py +++ b/atrcopy/segments.py @@ -64,6 +64,8 @@ class UserExtraData(object): def __init__(self): self.comments = dict() self.user_data = dict() + for i in range(1, user_bit_mask): + self.user_data[i] = dict() class SegmentData(object): @@ -256,7 +258,48 @@ class DefaultSegment(object): if other.rawdata.is_indexed: r = r.get_indexed[other.order] return r - + + def serialize_extra_to_dict(self, mdict): + """Save extra metadata to a dict so that it can be serialized + + This is not saved by __getstate__ because child segments will point to + the same data and this allows it to only be saved for the base segment. + As well as allowing it to be pulled out of the main json so that it can + be more easily edited by hand if desired. + """ + mdict["comment ranges"] = [list(a) for a in self.get_style_ranges(comment=True)] + mdict["data ranges"] = [list(a) for a in self.get_style_ranges(data=True)] + for i in range(1, user_bit_mask): + r = self.get_sorted_user_data(i) + if r: + slot = "user ranges %d" % i + mdict[slot] = r + + # json serialization doesn't allow int keys, so convert to list of + # pairs + mdict["comments"] = self.get_sorted_comments() + + def restore_extra_from_dict(self, e): + if 'comments' in e: + for k, v in e['comments']: + self.rawdata.extra.comments[k] = v + if 'comment ranges' in e: + self.set_style_ranges(e['comment ranges'], comment=True) + if 'data ranges' in e: + self.set_style_ranges(e['data ranges'], data=True) + if 'display list ranges' in e: + # DEPRECATED, but supported on read. Converts display list to + # disassembly type 0 for user index 1 + self.set_style_ranges(e['display list ranges'], data=True, user=1) + self.set_user_data(e['display list ranges'], 1, 0) + for i in range(1, user_bit_mask): + slot = "user ranges %d" % i + if slot in e: + for r, val in e[slot]: + self.set_style_ranges([r], user=i) + self.set_user_data([r], i, val) + + def __str__(self): s = "%s ($%x bytes)" % (self.name, len(self)) if self.error: @@ -485,15 +528,34 @@ class DefaultSegment(object): # FIXME: this is slow for i in range(start, end): rawindex = self.get_raw_index(i) - self.rawdata.extra.user_data[rawindex] = user_data + self.rawdata.extra.user_data[user_index][rawindex] = user_data def get_user_data(self, index, user_index): rawindex = self.get_raw_index(index) try: - return self.rawdata.extra.user_data[rawindex] + return self.rawdata.extra.user_data[user_index][rawindex] except KeyError: return 0 + def get_sorted_user_data(self, user_index): + d = self.rawdata.extra.user_data[user_index] + indexes = sorted(d.keys()) + ranges = [] + start, end, current = None, None, None + for i in indexes: + if start is None: + start = i + current = d[i] + else: + if d[i] != current or i != end: + ranges.append([[start, end], current]) + start = i + current = d[i] + end = i + 1 + if start is not None: + ranges.append([[start, end], current]) + return ranges + def set_comment(self, ranges, text): self.set_style_ranges(ranges, comment=True) for start, end in ranges: diff --git a/test/test_serialize.py b/test/test_serialize.py new file mode 100644 index 0000000..e7120ff --- /dev/null +++ b/test/test_serialize.py @@ -0,0 +1,51 @@ +import os + +import numpy as np +import pytest + +from atrcopy import DefaultSegment, SegmentData, get_xex, interleave_segments + + +class TestSegment(object): + def setup(self): + data = np.ones([4000], dtype=np.uint8) + r = SegmentData(data) + self.segment = DefaultSegment(r, 0) + + def test_s1(self): + s = self.segment + s.set_comment([[4,5]], "test1") + s.set_comment([[40,50]], "test2") + s.set_style_ranges([[2,100]], comment=True) + s.set_style_ranges([[200, 299]], data=True) + for i in range(1,4): + for j in range(1, 4): + # create some with overlapping regions, some without + r = [500*j, 500*j + 200*i + 200] + s.set_style_ranges([r], user=i) + s.set_user_data([r], i, i*10 + j) + r = [100, 200] + s.set_style_ranges([r], user=4) + s.set_user_data([r], 4, 99) + r = [3100, 3200] + s.set_style_ranges([r], user=4) + s.set_user_data([r], 4, 99) + + out = dict() + s.serialize_extra_to_dict(out) + print "saved", out + + data = np.ones([4000], dtype=np.uint8) + r = SegmentData(data) + s2 = DefaultSegment(r, 0) + s2.restore_extra_from_dict(out) + out2 = dict() + s2.serialize_extra_to_dict(out2) + print "loaded", out2 + assert out == out2 + + +if __name__ == "__main__": + t = TestSegment() + t.setup() + t.test_s1()