diff --git a/atrcopy/segments.py b/atrcopy/segments.py index fdd7e5b..6a2f3bf 100644 --- a/atrcopy/segments.py +++ b/atrcopy/segments.py @@ -1134,21 +1134,28 @@ class RawTrackSectorSegment(RawSectorsSegment): def interleave_indexes(segments, num_bytes): num_segments = len(segments) + + # interleave size will be the smallest segment size = len(segments[0]) for s in segments[1:]: - if size != len(s): - raise ValueError("All segments to interleave must be the same size") + if len(s) < size: + size = len(s) + + # adjust if byte spacing is not an even divisor _, rem = divmod(size, num_bytes) - if rem != 0: - raise ValueError("Segment size must be a multiple of the byte interleave") + print("size: %d, rem=%d" % (size, rem)) + size -= rem + print("size: %d, rem=%d" % (size, rem)) + interleave = np.empty(size * num_segments, dtype=np.uint32) - factor = num_bytes * num_segments - start = 0 - for s in segments: - order = s.rawdata.get_indexes_from_base() - for i in range(num_bytes): - interleave[start::factor] = order[i::num_bytes] - start += 1 + if size > 0: + factor = num_bytes * num_segments + start = 0 + for s in segments: + order = s.rawdata.get_indexes_from_base() + for i in range(num_bytes): + interleave[start::factor] = order[i:size:num_bytes] + start += 1 return interleave diff --git a/test/test_segment.py b/test/test_segment.py index 109adeb..719314b 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -168,22 +168,66 @@ class TestIndexed(object): a[6::8] = s2[2::4] a[7::8] = s2[3::4] assert np.array_equal(s[:], a) - - with pytest.raises(ValueError) as e: - s = interleave_segments([s1, s2], 3) - r1 = base.rawdata[512:1025] # 513 byte segment + def test_interleave_not_multiple(self): + base = self.segment + r1 = base.rawdata[512:1024] # 512 byte segment s1 = DefaultSegment(r1, 512) - r2 = base.rawdata[1024:1537] # 513 byte segment + r2 = base.rawdata[1024:1536] # 512 byte segment s2 = DefaultSegment(r2, 1024) + + indexes1 = r1.get_indexes_from_base() + verify1 = np.arange(512, 1024, dtype=np.uint32) + assert np.array_equal(indexes1, verify1) + + indexes2 = r2.get_indexes_from_base() + verify2 = np.arange(1024, 1536, dtype=np.uint32) + assert np.array_equal(indexes2, verify2) + s = interleave_segments([s1, s2], 3) - a = np.empty(len(s1) + len(s2), dtype=np.uint8) - a[0::6] = s1[0::3] - a[1::6] = s1[1::3] - a[2::6] = s1[2::3] - a[3::6] = s2[0::3] - a[4::6] = s2[1::3] - a[5::6] = s2[2::3] + + # when interleave size isn't a multiple of the length, the final array + # will reduce the size of the input array to force it to be a multiple. + size = (len(s1) // 3) * 3 + assert len(s) == size * 2 + a = np.empty(len(s), dtype=np.uint8) + a[0::6] = s1[0:size:3] + a[1::6] = s1[1:size:3] + a[2::6] = s1[2:size:3] + a[3::6] = s2[0:size:3] + a[4::6] = s2[1:size:3] + a[5::6] = s2[2:size:3] + assert np.array_equal(s[:], a) + + def test_interleave_different_sizes(self): + base = self.segment + r1 = base.rawdata[512:768] # 256 byte segment + s1 = DefaultSegment(r1, 512) + r2 = base.rawdata[1024:1536] # 512 byte segment + s2 = DefaultSegment(r2, 1024) + + indexes1 = r1.get_indexes_from_base() + verify1 = np.arange(512, 768, dtype=np.uint32) + assert np.array_equal(indexes1, verify1) + + indexes2 = r2.get_indexes_from_base() + verify2 = np.arange(1024, 1536, dtype=np.uint32) + assert np.array_equal(indexes2, verify2) + + s = interleave_segments([s1, s2], 3) + + # when interleave size isn't a multiple of the length, the final array + # will reduce the size of the input array to force it to be a multiple. + size = (min(len(s1), len(s2)) // 3) * 3 + assert size == (256 // 3) * 3 + assert len(s) == size * 2 + a = np.empty(len(s), dtype=np.uint8) + a[0::6] = s1[0:size:3] + a[1::6] = s1[1:size:3] + a[2::6] = s1[2:size:3] + a[3::6] = s2[0:size:3] + a[4::6] = s2[1:size:3] + a[5::6] = s2[2:size:3] assert np.array_equal(s[:], a) def test_copy(self):