From 6f29e6053a7418eb94106277e5b1b677bb6e3849 Mon Sep 17 00:00:00 2001 From: Rob McMullen Date: Tue, 21 Mar 2017 16:25:59 -0700 Subject: [PATCH] Added resizable segments --- atrcopy/__init__.py | 4 +- atrcopy/ataridos.py | 6 ++- atrcopy/parsers.py | 6 ++- atrcopy/segments.py | 85 ++++++++++++++++++++++++++++++++++++++++++- test/test_ataridos.py | 18 +++++++-- test/test_segment.py | 65 +++++++++++++++++++++++++++++++++ 6 files changed, 174 insertions(+), 10 deletions(-) diff --git a/atrcopy/__init__.py b/atrcopy/__init__.py index d7327d0..5bda941 100644 --- a/atrcopy/__init__.py +++ b/atrcopy/__init__.py @@ -8,10 +8,10 @@ except ImportError: raise RuntimeError("atrcopy %s requires numpy" % __version__) from errors import * -from ataridos import AtrHeader, AtariDosDiskImage, BootDiskImage, AtariDosFile, get_xex, add_atr_header +from ataridos import AtrHeader, AtariDosDiskImage, BootDiskImage, AtariDosFile, XexContainerSegment, get_xex, add_atr_header from dos33 import Dos33DiskImage from kboot import KBootImage, add_xexboot_header -from segments import SegmentData, SegmentSaver, DefaultSegment, EmptySegment, ObjSegment, RawSectorsSegment, user_bit_mask, match_bit_mask, comment_bit_mask, data_style, selected_bit_mask, diff_bit_mask, not_user_bit_mask, interleave_segments, SegmentList, get_style_mask, get_style_bits +from segments import SegmentData, SegmentSaver, DefaultSegment, EmptySegment, ObjSegment, RawSectorsSegment, SegmentedFileSegment, user_bit_mask, match_bit_mask, comment_bit_mask, data_style, selected_bit_mask, diff_bit_mask, not_user_bit_mask, interleave_segments, SegmentList, get_style_mask, get_style_bits from spartados import SpartaDosDiskImage from cartridge import A8CartHeader, AtariCartImage from parsers import SegmentParser, DefaultSegmentParser, guess_parser_for_mime, guess_parser_for_system, iter_parsers, iter_known_segment_parsers, mime_parse_order diff --git a/atrcopy/ataridos.py b/atrcopy/ataridos.py index ec2db9a..9bda15b 100644 --- a/atrcopy/ataridos.py +++ b/atrcopy/ataridos.py @@ -2,7 +2,7 @@ import numpy as np from errors import * from diskimages import DiskImageBase, BaseHeader -from segments import SegmentData, EmptySegment, ObjSegment, RawSectorsSegment, DefaultSegment, SegmentSaver, get_style_bits +from segments import SegmentData, EmptySegment, ObjSegment, RawSectorsSegment, DefaultSegment, SegmentedFileSegment, SegmentSaver, get_style_bits from utils import * import logging @@ -242,6 +242,10 @@ class XexSegmentSaver(SegmentSaver): export_extensions = [".xex"] +class XexContainerSegment(DefaultSegment): + can_resize_default = True + + class XexSegment(ObjSegment): savers = [SegmentSaver, XexSegmentSaver] diff --git a/atrcopy/parsers.py b/atrcopy/parsers.py index a307fd6..96ab4c0 100644 --- a/atrcopy/parsers.py +++ b/atrcopy/parsers.py @@ -2,7 +2,7 @@ import numpy as np from segments import SegmentData, DefaultSegment from kboot import KBootImage -from ataridos import AtariDosDiskImage, BootDiskImage, AtariDosFile +from ataridos import AtariDosDiskImage, BootDiskImage, AtariDosFile, XexContainerSegment from spartados import SpartaDosDiskImage from cartridge import AtariCartImage, get_known_carts from mame import MameZipImage @@ -16,6 +16,7 @@ log = logging.getLogger(__name__) class SegmentParser(object): menu_name = "" image_type = None + container_segment = DefaultSegment def __init__(self, segment_data, strict=False): self.image = None @@ -26,7 +27,7 @@ class SegmentParser(object): def parse(self): r = self.segment_data - self.segments.append(DefaultSegment(r, 0, name=self.menu_name)) + self.segments.append(self.container_segment(r, 0, name=self.menu_name)) try: self.image = self.get_image(r) self.check_image() @@ -80,6 +81,7 @@ class AtariBootDiskSegmentParser(SegmentParser): class XexSegmentParser(SegmentParser): menu_name = "XEX (Atari 8-bit executable)" image_type = AtariDosFile + container_segment = XexContainerSegment class AtariCartSegmentParser(SegmentParser): diff --git a/atrcopy/segments.py b/atrcopy/segments.py index f5ef8f6..2363572 100644 --- a/atrcopy/segments.py +++ b/atrcopy/segments.py @@ -153,11 +153,45 @@ class SegmentData(object): def __len__(self): return len(self.data) + def resize(self, newsize): + if self.data.base is None: + try: + newdata = np.resize(self.data, (newsize,)) + newstyle = np.resize(self.style, (newsize,)) + except: + raise + else: + self.data = newdata + self.style = newstyle + + def replace_arrays(self, base_raw): + newsize = len(base_raw) + oldsize = len(self.data_base) + if newsize < oldsize: + raise NotImplementedError("Can't truncate yet") + if self.is_indexed: + self.data.np_data = base_raw.data + self.data.base = base_raw.data.base + self.style.np_data = base_raw.style + self.style.base = base_raw.style.base + elif self.data.base is not None: + # if there is no base array, we aren't looking at a slice so we + # must be copying the entire array. + start, end = self.byte_bounds_offset() + self.data = base_raw.data[start:end] + self.style = base_raw.style[start:end] + else: + raise ValueError("The base SegmentData object should use the resize method to replace arrays") + @property def stringio(self): buf = cStringIO.StringIO(self.data[:]) return buf + @property + def is_base(self): + return not self.is_indexed and self.data.base is None + @property def data_base(self): return self.data.np_data if self.is_indexed else self.data.base if self.data.base is not None else self.data @@ -301,6 +335,8 @@ class SegmentData(object): class DefaultSegment(object): savers = [SegmentSaver] + use_origin_default = False + can_resize_default = False def __init__(self, rawdata, start_addr=0, name="All", error=None, verbose_name=None): self.start_addr = int(start_addr) # force python int to decouple from possibly being a numpy datatype @@ -314,7 +350,11 @@ class DefaultSegment(object): # Some segments may not have a standard place in memory, so this flag # can be used to skip the memory map lookup when displaying disassembly - self.use_origin = False + self.use_origin = self.__class__.use_origin_default + + # Some segments may be resized to contain additional segments not + # present when the segment was created. + self.can_resize = self.__class__.can_resize_default def set_raw(self, rawdata): self.rawdata = rawdata @@ -324,10 +364,43 @@ class DefaultSegment(object): def get_raw(self): return self.rawdata + + def resize(self, newsize, zeros=True): + """ Resize the data arrays. + + This can only be performed on the container segment. Child segments + must adjust their rawdata to point to the correct place. + + Since segments don't keep references to other segments, it is the + user's responsibility to update any child segments that point to this + segment's data. + + Numpy can't do an in-place resize on an array that has a view, so the + data must be replaced and all segments that point to that raw data must + also be changed. This has to happen outside this method because it + doesn't know the segment list of segments using itself as a base. + """ + if not self.can_resize: + raise ValueError("Segment %s can't be resized" % str(self)) + # only makes sense for the container (outermost) object + if not self.rawdata.is_base: + raise ValueError("Only container segments can be resized") + origsize = len(self) + self.rawdata.resize(newsize) + self.set_raw(self.rawdata) # force attributes to be reset + newsize = len(self) + if zeros: + if newsize > origsize: + self.data[origsize:] = 0 + self.style[origsize:] = 0 + return origsize, newsize + + def replace_data(self, container): + self.rawdata.replace_arrays(container.rawdata) def __getstate__(self): state = dict() - for key in ['start_addr', 'error', 'name', 'verbose_name', 'page_size', 'map_width', 'uuid']: + for key in ['start_addr', 'error', 'name', 'verbose_name', 'page_size', 'map_width', 'uuid', 'use_origin', 'can_resize']: state[key] = getattr(self, key) r = self.rawdata state['_rawdata_bounds'] = list(r.byte_bounds_offset()) @@ -347,6 +420,10 @@ class DefaultSegment(object): """ if not hasattr(self, 'uuid'): self.uuid = str(uuid.uuid4()) + if not hasattr(self, 'use_origin'): + self.use_origin = self.__class__.use_origin_default + if not hasattr(self, 'can_resize'): + self.can_resize = self.__class__.can_resize_default def reconstruct_raw(self, rawdata): start, end = self._rawdata_bounds @@ -918,6 +995,10 @@ class ObjSegment(DefaultSegment): return s +class SegmentedFileSegment(ObjSegment): + can_resize_default = True + + class RawSectorsSegment(DefaultSegment): def __init__(self, rawdata, first_sector, num_sectors, count, boot_sector_size, num_boot_sectors, sector_size, **kwargs): DefaultSegment.__init__(self, rawdata, 0, **kwargs) diff --git a/test/test_ataridos.py b/test/test_ataridos.py index c5ea16c..a7775cf 100644 --- a/test/test_ataridos.py +++ b/test/test_ataridos.py @@ -1,6 +1,6 @@ from mock import * -from atrcopy import SegmentData, AtariDosFile, InvalidBinaryFile +from atrcopy import SegmentData, AtariDosFile, InvalidBinaryFile, DefaultSegment, XexContainerSegment class TestAtariDosFile(object): @@ -8,12 +8,24 @@ class TestAtariDosFile(object): pass def test_segment(self): - bytes = [0xff, 0xff, 0x00, 0x60, 0x01, 0x60, 1, 2] + bytes = np.asarray([0xff, 0xff, 0x00, 0x60, 0x01, 0x60, 1, 2], dtype=np.uint8) rawdata = SegmentData(bytes) - image = AtariDosFile(rawdata) + container = XexContainerSegment(rawdata, 0) + image = AtariDosFile(container.rawdata) image.parse_segments() + print image.segments assert len(image.segments) == 1 assert len(image.segments[0]) == 2 + assert np.all(image.segments[0] == bytes[6:8]) + container.resize(16) + for s in image.segments: + s.replace_data(container) + new_segment = DefaultSegment(rawdata[8:16]) + new_segment[:] = 99 + assert np.all(image.segments[0] == bytes[6:8]) + print new_segment[:] + assert np.all(new_segment[:] == 99) + def test_short_segment(self): bytes = [0xff, 0xff, 0x00, 0x60, 0xff, 0x60, 1, 2] diff --git a/test/test_segment.py b/test/test_segment.py index fe6c060..884aa3a 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -331,6 +331,71 @@ class TestComments(object): assert item1[3] == item2[3] +class TestResize(object): + def setup(self): + data = np.arange(4096, dtype=np.uint8) + data[1::2] = np.repeat(np.arange(16, dtype=np.uint8), 128) + r = SegmentData(data) + self.container = DefaultSegment(r, 0) + self.container.can_resize = True + + def test_subset(self): + c = self.container + assert not c.rawdata.is_indexed + offset = 1000 + s = DefaultSegment(c.rawdata[offset:offset + offset], 0) + assert not s.rawdata.is_indexed + for i in range(offset): + assert s[i] == c[i + offset] + requested = 8192 + oldraw = s.rawdata.copy() + oldid = id(s.rawdata) + oldsize, newsize = c.resize(requested) + assert newsize == requested + s.replace_data(c) + assert id(s.rawdata) == oldid + assert id(oldraw.order) == id(s.rawdata.order) + for i in range(offset): + assert s[i] == c[i + offset] + newbase = c.rawdata + newsub = s.rawdata + print c.rawdata.data + print s.rawdata.data[:] + s.rawdata.data[:] = 111 + print c.rawdata.data + print s.rawdata.data[:] + for i in range(offset): + assert s[i] == c[i + offset] + + def test_indexed(self): + c = self.container + assert not c.rawdata.is_indexed + s, indexes = get_indexed(self.container, 1024, 3) + assert s.rawdata.is_indexed + for i in range(len(indexes)): + assert s.get_raw_index(i) == indexes[i] + requested = 8192 + oldraw = s.rawdata.copy() + oldid = id(s.rawdata) + oldsize, newsize = c.resize(requested) + assert newsize == requested + s.replace_data(c) + assert id(s.rawdata) == oldid + assert id(oldraw.order) == id(s.rawdata.order) + for i in range(len(indexes)): + assert s.get_raw_index(i) == indexes[i] + newbase = c.rawdata + newsub = s.rawdata + print c.rawdata.data + print s.rawdata.data[:] + s.rawdata.data[:] = 111 + print c.rawdata.data + print s.rawdata.data[:] + for i in range(len(indexes)): + assert c.rawdata.data[indexes[i]] == s.rawdata.data[i] + + + if __name__ == "__main__": t = TestIndexed() t.setup()