Added user data serialization and expanded user data to actually use the user_index field so multiple types of user data are actually supported

This commit is contained in:
Rob McMullen 2016-06-06 14:31:51 -07:00
parent 66bb7e63ea
commit f64ffb777b
2 changed files with 116 additions and 3 deletions

View File

@ -64,6 +64,8 @@ class UserExtraData(object):
def __init__(self):
self.comments = dict()
self.user_data = dict()
for i in range(1, user_bit_mask):
self.user_data[i] = dict()
class SegmentData(object):
@ -256,7 +258,48 @@ class DefaultSegment(object):
if other.rawdata.is_indexed:
r = r.get_indexed[other.order]
return r
def serialize_extra_to_dict(self, mdict):
"""Save extra metadata to a dict so that it can be serialized
This is not saved by __getstate__ because child segments will point to
the same data and this allows it to only be saved for the base segment.
As well as allowing it to be pulled out of the main json so that it can
be more easily edited by hand if desired.
"""
mdict["comment ranges"] = [list(a) for a in self.get_style_ranges(comment=True)]
mdict["data ranges"] = [list(a) for a in self.get_style_ranges(data=True)]
for i in range(1, user_bit_mask):
r = self.get_sorted_user_data(i)
if r:
slot = "user ranges %d" % i
mdict[slot] = r
# json serialization doesn't allow int keys, so convert to list of
# pairs
mdict["comments"] = self.get_sorted_comments()
def restore_extra_from_dict(self, e):
if 'comments' in e:
for k, v in e['comments']:
self.rawdata.extra.comments[k] = v
if 'comment ranges' in e:
self.set_style_ranges(e['comment ranges'], comment=True)
if 'data ranges' in e:
self.set_style_ranges(e['data ranges'], data=True)
if 'display list ranges' in e:
# DEPRECATED, but supported on read. Converts display list to
# disassembly type 0 for user index 1
self.set_style_ranges(e['display list ranges'], data=True, user=1)
self.set_user_data(e['display list ranges'], 1, 0)
for i in range(1, user_bit_mask):
slot = "user ranges %d" % i
if slot in e:
for r, val in e[slot]:
self.set_style_ranges([r], user=i)
self.set_user_data([r], i, val)
def __str__(self):
s = "%s ($%x bytes)" % (self.name, len(self))
if self.error:
@ -485,15 +528,34 @@ class DefaultSegment(object):
# FIXME: this is slow
for i in range(start, end):
rawindex = self.get_raw_index(i)
self.rawdata.extra.user_data[rawindex] = user_data
self.rawdata.extra.user_data[user_index][rawindex] = user_data
def get_user_data(self, index, user_index):
rawindex = self.get_raw_index(index)
try:
return self.rawdata.extra.user_data[rawindex]
return self.rawdata.extra.user_data[user_index][rawindex]
except KeyError:
return 0
def get_sorted_user_data(self, user_index):
d = self.rawdata.extra.user_data[user_index]
indexes = sorted(d.keys())
ranges = []
start, end, current = None, None, None
for i in indexes:
if start is None:
start = i
current = d[i]
else:
if d[i] != current or i != end:
ranges.append([[start, end], current])
start = i
current = d[i]
end = i + 1
if start is not None:
ranges.append([[start, end], current])
return ranges
def set_comment(self, ranges, text):
self.set_style_ranges(ranges, comment=True)
for start, end in ranges:

51
test/test_serialize.py Normal file
View File

@ -0,0 +1,51 @@
import os
import numpy as np
import pytest
from atrcopy import DefaultSegment, SegmentData, get_xex, interleave_segments
class TestSegment(object):
def setup(self):
data = np.ones([4000], dtype=np.uint8)
r = SegmentData(data)
self.segment = DefaultSegment(r, 0)
def test_s1(self):
s = self.segment
s.set_comment([[4,5]], "test1")
s.set_comment([[40,50]], "test2")
s.set_style_ranges([[2,100]], comment=True)
s.set_style_ranges([[200, 299]], data=True)
for i in range(1,4):
for j in range(1, 4):
# create some with overlapping regions, some without
r = [500*j, 500*j + 200*i + 200]
s.set_style_ranges([r], user=i)
s.set_user_data([r], i, i*10 + j)
r = [100, 200]
s.set_style_ranges([r], user=4)
s.set_user_data([r], 4, 99)
r = [3100, 3200]
s.set_style_ranges([r], user=4)
s.set_user_data([r], 4, 99)
out = dict()
s.serialize_extra_to_dict(out)
print "saved", out
data = np.ones([4000], dtype=np.uint8)
r = SegmentData(data)
s2 = DefaultSegment(r, 0)
s2.restore_extra_from_dict(out)
out2 = dict()
s2.serialize_extra_to_dict(out2)
print "loaded", out2
assert out == out2
if __name__ == "__main__":
t = TestSegment()
t.setup()
t.test_s1()