diff --git a/atrcopy/ataridos.py b/atrcopy/ataridos.py index ef30602..ae34895 100644 --- a/atrcopy/ataridos.py +++ b/atrcopy/ataridos.py @@ -62,7 +62,7 @@ class AtariDosDirectory(Directory): num += 1 -class AtariDosDirent(object): +class AtariDosDirent(Dirent): # ATR Dirent structure described at http://atari.kensclassics.org/dos.htm format = np.dtype([ ('FLAG', 'u1'), @@ -73,7 +73,7 @@ class AtariDosDirent(object): ]) def __init__(self, image, file_num=0, bytes=None): - self.file_num = file_num + Dirent.__init__(self, file_num) self.flag = 0 self.opened_output = False self.dos_2 = False @@ -95,6 +95,9 @@ class AtariDosDirent(object): def __str__(self): return "File #%-2d (%s) %03d %-8s%-3s %03d" % (self.file_num, self.summary, self.starting_sector, self.basename, self.ext, self.num_sectors) + def __eq__(self, other): + return self.__class__ == other.__class__ and self.filename == other.filename and self.starting_sector == other.starting_sector and self.num_sectors == other.num_sectors + @property def filename(self): ext = ("." + self.ext) if self.ext else "" diff --git a/atrcopy/diskimages.py b/atrcopy/diskimages.py index 58b195c..bcf16cb 100644 --- a/atrcopy/diskimages.py +++ b/atrcopy/diskimages.py @@ -282,8 +282,11 @@ class DiskImageBase(object): return [] def find_dirent(self, filename): + # check if we've been passed a dirent instead of a filename + if hasattr(filename, "filename"): + return filename for dirent in self.files: - if filename == dirent.filename: + if filename_or_dirent == dirent.filename: return dirent raise FileNotFound("%s not found on disk" % filename) diff --git a/atrcopy/dos33.py b/atrcopy/dos33.py index 488005d..287457a 100644 --- a/atrcopy/dos33.py +++ b/atrcopy/dos33.py @@ -1,7 +1,8 @@ import numpy as np from errors import * -from diskimages import BaseHeader, DiskImageBase, Directory, VTOC, WriteableSector, BaseSectorList +from diskimages import BaseHeader, DiskImageBase +from utils import Directory, VTOC, WriteableSector, BaseSectorList, Dirent from segments import DefaultSegment, EmptySegment, ObjSegment, RawTrackSectorSegment, SegmentSaver import logging @@ -143,7 +144,7 @@ class Dos33Directory(Directory): current_sector = next_sector -class Dos33Dirent(object): +class Dos33Dirent(Dirent): format = np.dtype([ ('track', 'u1'), ('sector', 'u1'), @@ -153,7 +154,7 @@ class Dos33Dirent(object): ]) def __init__(self, image, file_num=0, bytes=None): - self.file_num = file_num + Dirent.__init__(self, file_num) self._file_type = 0 self.locked = False self.deleted = False @@ -170,6 +171,9 @@ class Dos33Dirent(object): def __str__(self): return "File #%-2d (%s) %03d %-30s %03d %03d" % (self.file_num, self.summary, self.num_sectors, self.filename, self.track, self.sector) + + def __eq__(self, other): + return self.__class__ == other.__class__ and self.filename == other.filename and self.track == other.track and self.sector == other.sector and self.num_sectors == other.num_sectors type_map = { 0x0: "T", # text diff --git a/atrcopy/utils.py b/atrcopy/utils.py index 6495426..0071ad8 100644 --- a/atrcopy/utils.py +++ b/atrcopy/utils.py @@ -133,6 +133,35 @@ class BaseSectorList(object): self.sectors.extend(sectors) +class Dirent(object): + """Abstract base class for a directory entry + + """ + def __init__(self, file_num=0): + self.file_num = file_num + + def __eq__(self, other): + raise NotImplementedError + + def mark_deleted(self): + raise NotImplementedError + + def parse_raw_dirent(self, image, bytes): + raise NotImplementedError + + def encode_dirent(self): + raise NotImplementedError + + def get_sectors_in_vtoc(self, image): + raise NotImplementedError + + def start_read(self, image): + raise NotImplementedError + + def read_sector(self, image): + raise NotImplementedError + + class Directory(BaseSectorList): def __init__(self, header, num_dirents=-1, sector_class=WriteableSector): BaseSectorList.__init__(self, header) @@ -167,9 +196,15 @@ class Directory(BaseSectorList): return dirent def find_dirent(self, filename): - for dirent in self.dirents.values(): - if filename == dirent.filename: - return dirent + if hasattr(filename, "filename"): + # we've been passed a dirent instead of a filename + for dirent in self.dirents.values(): + if dirent == filename: + return dirent + else: + for dirent in self.dirents.values(): + if filename == dirent.filename: + return dirent raise FileNotFound("%s not found on disk" % filename) def save_dirent(self, image, dirent, vtoc, sector_list):