diff --git a/lib/supermon/supermon.go b/lib/supermon/supermon.go index ceadf39..4d07c37 100644 --- a/lib/supermon/supermon.go +++ b/lib/supermon/supermon.go @@ -111,6 +111,22 @@ func (sm SectorMap) FileForSector(track, sector byte) byte { return sm[int(track)*16+int(sector)] } +// SetFileForSector sets the file that owns the given track/sector, or +// returns an error if the track or sector is too high. +func (sm SectorMap) SetFileForSector(track, sector, file byte) error { + if track >= 35 { + return fmt.Errorf("track %d >34", track) + } + if sector >= 16 { + return fmt.Errorf("sector %d >15", sector) + } + if file == FileIllegal || file == FileFree || file == FileReserved { + return fmt.Errorf("illegal file number: 0x%0X", file) + } + sm[int(track)*16+int(sector)] = file + return nil +} + // SectorsForFile returns the list of sectors that belong to the given // file. func (sm SectorMap) SectorsForFile(file byte) []disk.TrackSector { @@ -180,10 +196,13 @@ func (sm SectorMap) WriteFile(sd disk.SectorDisk, file byte, contents []byte, ov OUTER: for track := byte(0); track < sd.Tracks(); track++ { for sector := byte(0); sector < sd.Sectors(); sector++ { - if sm.FileForSector(track, sector) == file { + if sm.FileForSector(track, sector) == FileFree { if err := sd.WritePhysicalSector(track, sector, cts[i*256:(i+1)*256]); err != nil { return err } + if err := sm.SetFileForSector(track, sector, file); err != nil { + return err + } i++ if i == sectorsNeeded { break OUTER diff --git a/lib/supermon/supermon_test.go b/lib/supermon/supermon_test.go index f1f45d3..1d1007d 100644 --- a/lib/supermon/supermon_test.go +++ b/lib/supermon/supermon_test.go @@ -3,9 +3,11 @@ package supermon import ( + "reflect" "strings" "testing" + "github.com/kr/pretty" "github.com/zellyn/diskii/lib/disk" ) @@ -179,3 +181,27 @@ func TestEncodeDecode(t *testing.T) { } } } + +// TestReadWriteSymbolTable tests reading, writing, and re-reading of +// the symbol table, ensuring that no details are lost along the way. +func TestReadWriteSymbolTable(t *testing.T) { + sm, sd, err := loadSectorMap(testDisk) + if err != nil { + t.Fatal(err) + } + st1, err := sm.ReadSymbolTable(sd) + if err != nil { + t.Fatal(err) + } + if err := sm.WriteSymbolTable(sd, st1); err != nil { + t.Fatal(err) + } + st2, err := sm.ReadSymbolTable(sd) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(st1, st2) { + pretty.Ldiff(t, st1, st2) + t.Fatal("Saved and reloaded symbol table differs from original symbol table") + } +}