Skip to content

Commit e334922

Browse files
lucylqfacebook-github-bot
authored andcommitted
FlatTensor alignment tests (#16317)
Summary: 1. Pass segment_alignment from ExecutorchBackendConfig to flat tensor serializer 2. Set segment_alignment=128 as default to match ExecutorchBackendConfig 3. Remove tensor_alignment from the config (do not have multiple tensors per segment) 4. Additional tests for varying segment alignment. Differential Revision: D89422691
1 parent 813e26a commit e334922

File tree

3 files changed

+37
-27
lines changed

3 files changed

+37
-27
lines changed

exir/program/_program.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,9 @@ def __init__(
596596
FlatTensorSerializer,
597597
)
598598

599-
self._data_serializer: DataSerializer = FlatTensorSerializer(FlatTensorConfig(tensor_alignment=4096, segment_alignment=4096))
599+
self._data_serializer: DataSerializer = FlatTensorSerializer(
600+
FlatTensorConfig(self._segment_alignment)
601+
)
600602

601603
def _get_emitter_output(self) -> EmitterOutput:
602604
if self._emitter_output is None:
@@ -1855,7 +1857,9 @@ def __init__(
18551857
FlatTensorSerializer,
18561858
)
18571859

1858-
self._data_serializer = FlatTensorSerializer()
1860+
self._data_serializer = FlatTensorSerializer(
1861+
FlatTensorConfig(segment_alignment=backend_config.segment_alignment)
1862+
)
18591863
self._pte_data, self._tensor_data = serialize_for_executorch(
18601864
self._emitter_output,
18611865
backend_config,

extension/flat_tensor/serialize/serialize.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
# endian.
4040
_HEADER_BYTEORDER: Literal["little"] = "little"
4141

42+
# Alignment of the flatbuffer (after the header).
43+
_FLATBUFFER_ALIGNMENT: int = 16
44+
4245
# Current version. Keep in sync with c++ version number in serialize.
4346
_FLAT_TENSOR_VERSION: int = 0
4447

@@ -95,8 +98,7 @@ def _deserialize_to_flat_tensor(flatbuffer: bytes) -> FlatTensor:
9598

9699
@dataclass
97100
class FlatTensorConfig:
98-
tensor_alignment: int = 16
99-
segment_alignment: int = 16
101+
segment_alignment: int = 128
100102

101103

102104
@dataclass
@@ -334,18 +336,13 @@ def serialize(
334336
)
335337

336338
flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)
337-
padded_flatbuffer_length: int = aligned_size(
338-
input_size=len(flatbuffer_payload),
339-
alignment=self.config.tensor_alignment,
340-
)
341-
342339
padded_header_length: int = aligned_size(
343340
input_size=FlatTensorHeader.EXPECTED_LENGTH,
344-
alignment=self.config.tensor_alignment,
341+
alignment=_FLATBUFFER_ALIGNMENT,
345342
)
346343

347344
segment_base_offset = aligned_size(
348-
padded_flatbuffer_length + padded_header_length,
345+
len(flatbuffer_payload) + padded_header_length,
349346
self.config.segment_alignment,
350347
)
351348

@@ -360,19 +357,16 @@ def serialize(
360357

361358
# Pad header and payload to segment alignment.
362359
header_data = pad_to(header_data, padded_header_length)
363-
original_flatbuffer_payload_size = len(flatbuffer_payload)
364-
flatbuffer_payload.append(
365-
b"\x00" * (padded_flatbuffer_length - len(flatbuffer_payload))
366-
)
367360
injected_flatbuffer_data: bytes = _insert_flatbuffer_header(
368361
flatbuffer_data=flatbuffer_payload.__bytes__(),
369362
magic_regex=r"FT[0-9a-zA-Z][0-9a-zA-Z]",
370363
header_data=header_data,
371364
)
365+
injected_flatbuffer_data = pad_to(injected_flatbuffer_data, segment_base_offset)
372366

373367
eh = _get_extended_header(injected_flatbuffer_data)
374368
assert eh is not None
375-
assert eh.flatbuffer_size == original_flatbuffer_payload_size
369+
assert eh.flatbuffer_size == len(flatbuffer_payload)
376370
assert eh.segment_base_offset == segment_base_offset
377371
assert eh.flatbuffer_offset == padded_header_length
378372
assert eh.segment_data_size == len(aggregated_segment_data)

extension/flat_tensor/test/test_serialize.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from executorch.extension.flat_tensor.serialize.serialize import (
3131
_deserialize_to_flat_tensor,
32+
_FLATBUFFER_ALIGNMENT,
3233
FlatTensorConfig,
3334
FlatTensorHeader,
3435
FlatTensorSerializer,
@@ -109,8 +110,7 @@ def _check_named_data_entries(
109110
f"Named data record {key}.{field.name} does not match.",
110111
)
111112

112-
def test_serialize(self) -> None:
113-
config = FlatTensorConfig()
113+
def _serialize_with_alignment(self, config: FlatTensorConfig) -> None:
114114
serializer: DataSerializer = FlatTensorSerializer(config)
115115
serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))
116116

@@ -120,15 +120,15 @@ def test_serialize(self) -> None:
120120
)
121121
self.assertTrue(header.is_valid())
122122

123-
# Header is aligned to config.segment_alignment, which is where the flatbuffer starts.
123+
# Flatbuffer is non-empty.
124+
self.assertTrue(header.flatbuffer_size > 0)
125+
126+
# Align the flatbuffer to _FLATBUFFER_ALIGNMENT.
124127
self.assertEqual(
125128
header.flatbuffer_offset,
126-
aligned_size(FlatTensorHeader.EXPECTED_LENGTH, config.segment_alignment),
129+
aligned_size(FlatTensorHeader.EXPECTED_LENGTH, _FLATBUFFER_ALIGNMENT),
127130
)
128131

129-
# Flatbuffer is non-empty.
130-
self.assertTrue(header.flatbuffer_size > 0)
131-
132132
# Segment base offset is aligned to config.segment_alignment.
133133
expected_segment_base_offset = aligned_size(
134134
header.flatbuffer_offset + header.flatbuffer_size, config.segment_alignment
@@ -180,12 +180,12 @@ def test_serialize(self) -> None:
180180
segments = flat_tensor.segments
181181
self.assertEqual(len(segments), 3)
182182

183-
# Segment 0 contains fqn1, fqn2; 4 bytes, aligned to config.tensor_alignment.
183+
# Segment 0 contains fqn1, fqn2; 4 bytes, aligned to config.segment_alignment.
184184
self.assertEqual(segments[0].offset, 0)
185185
self.assertEqual(segments[0].size, len(TEST_BUFFER[0]))
186186

187-
# Segment 1 contains fqn3; 32 bytes, aligned to config.tensor_alignment.
188-
self.assertEqual(segments[1].offset, config.tensor_alignment)
187+
# Segment 1 contains fqn3; 32 bytes, aligned to config.segment_alignment.
188+
self.assertEqual(segments[1].offset, config.segment_alignment)
189189
self.assertEqual(segments[1].size, len(TEST_BUFFER[1]))
190190

191191
# Segment 2 contains key0; 17 bytes, aligned to 64.
@@ -194,7 +194,7 @@ def test_serialize(self) -> None:
194194
)
195195
self.assertEqual(
196196
segments[2].offset,
197-
aligned_size(config.tensor_alignment * 3, custom_alignment),
197+
aligned_size(config.segment_alignment * 2, custom_alignment),
198198
)
199199
self.assertEqual(segments[2].size, len(TEST_BUFFER[2]))
200200

@@ -245,6 +245,18 @@ def test_serialize(self) -> None:
245245

246246
self.assertEqual(segments[2].offset + segments[2].size, len(segment_data))
247247

248+
def test_serialize_default_alignment(self) -> None:
249+
config = FlatTensorConfig()
250+
self._serialize_with_alignment(config)
251+
252+
def test_serialize_align_4096(self) -> None:
253+
config = FlatTensorConfig(segment_alignment=4096)
254+
self._serialize_with_alignment(config)
255+
256+
def test_serialize_align_1024(self) -> None:
257+
config = FlatTensorConfig(segment_alignment=1024)
258+
self._serialize_with_alignment(config)
259+
248260
def test_round_trip(self) -> None:
249261
# Serialize and then deserialize the test payload. Make sure it's reconstructed
250262
# properly.

0 commit comments

Comments
 (0)