Added Dirent abstract base class and ability for find_dirent to take filename or dirent

This commit is contained in:
Rob McMullen 2017-02-26 14:06:29 -08:00
parent 4d1f17677d
commit f4057f6ad5
4 changed files with 54 additions and 9 deletions

View File

@ -62,7 +62,7 @@ class AtariDosDirectory(Directory):
num += 1
class AtariDosDirent(object):
class AtariDosDirent(Dirent):
# ATR Dirent structure described at http://atari.kensclassics.org/dos.htm
format = np.dtype([
('FLAG', 'u1'),
@ -73,7 +73,7 @@ class AtariDosDirent(object):
])
def __init__(self, image, file_num=0, bytes=None):
self.file_num = file_num
Dirent.__init__(self, file_num)
self.flag = 0
self.opened_output = False
self.dos_2 = False
@ -95,6 +95,9 @@ class AtariDosDirent(object):
def __str__(self):
return "File #%-2d (%s) %03d %-8s%-3s %03d" % (self.file_num, self.summary, self.starting_sector, self.basename, self.ext, self.num_sectors)
def __eq__(self, other):
return self.__class__ == other.__class__ and self.filename == other.filename and self.starting_sector == other.starting_sector and self.num_sectors == other.num_sectors
@property
def filename(self):
ext = ("." + self.ext) if self.ext else ""

View File

@ -282,8 +282,11 @@ class DiskImageBase(object):
return []
def find_dirent(self, filename):
# check if we've been passed a dirent instead of a filename
if hasattr(filename, "filename"):
return filename
for dirent in self.files:
if filename == dirent.filename:
if filename_or_dirent == dirent.filename:
return dirent
raise FileNotFound("%s not found on disk" % filename)

View File

@ -1,7 +1,8 @@
import numpy as np
from errors import *
from diskimages import BaseHeader, DiskImageBase, Directory, VTOC, WriteableSector, BaseSectorList
from diskimages import BaseHeader, DiskImageBase
from utils import Directory, VTOC, WriteableSector, BaseSectorList, Dirent
from segments import DefaultSegment, EmptySegment, ObjSegment, RawTrackSectorSegment, SegmentSaver
import logging
@ -143,7 +144,7 @@ class Dos33Directory(Directory):
current_sector = next_sector
class Dos33Dirent(object):
class Dos33Dirent(Dirent):
format = np.dtype([
('track', 'u1'),
('sector', 'u1'),
@ -153,7 +154,7 @@ class Dos33Dirent(object):
])
def __init__(self, image, file_num=0, bytes=None):
self.file_num = file_num
Dirent.__init__(self, file_num)
self._file_type = 0
self.locked = False
self.deleted = False
@ -170,6 +171,9 @@ class Dos33Dirent(object):
def __str__(self):
return "File #%-2d (%s) %03d %-30s %03d %03d" % (self.file_num, self.summary, self.num_sectors, self.filename, self.track, self.sector)
def __eq__(self, other):
return self.__class__ == other.__class__ and self.filename == other.filename and self.track == other.track and self.sector == other.sector and self.num_sectors == other.num_sectors
type_map = {
0x0: "T", # text

View File

@ -133,6 +133,35 @@ class BaseSectorList(object):
self.sectors.extend(sectors)
class Dirent(object):
"""Abstract base class for a directory entry
"""
def __init__(self, file_num=0):
self.file_num = file_num
def __eq__(self, other):
raise NotImplementedError
def mark_deleted(self):
raise NotImplementedError
def parse_raw_dirent(self, image, bytes):
raise NotImplementedError
def encode_dirent(self):
raise NotImplementedError
def get_sectors_in_vtoc(self, image):
raise NotImplementedError
def start_read(self, image):
raise NotImplementedError
def read_sector(self, image):
raise NotImplementedError
class Directory(BaseSectorList):
def __init__(self, header, num_dirents=-1, sector_class=WriteableSector):
BaseSectorList.__init__(self, header)
@ -167,9 +196,15 @@ class Directory(BaseSectorList):
return dirent
def find_dirent(self, filename):
for dirent in self.dirents.values():
if filename == dirent.filename:
return dirent
if hasattr(filename, "filename"):
# we've been passed a dirent instead of a filename
for dirent in self.dirents.values():
if dirent == filename:
return dirent
else:
for dirent in self.dirents.values():
if filename == dirent.filename:
return dirent
raise FileNotFound("%s not found on disk" % filename)
def save_dirent(self, image, dirent, vtoc, sector_list):