Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
EXIREdgeDialectVerifier,
get_aten_verifier,
)
from executorch.extension.flat_tensor.serialize.serialize import FlatTensorConfig
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch._export.utils import _detect_fake_mode_from_gm
from torch._export.verifier import Verifier
Expand Down Expand Up @@ -595,7 +596,9 @@ def __init__(
FlatTensorSerializer,
)

self._data_serializer: DataSerializer = FlatTensorSerializer()
self._data_serializer: DataSerializer = FlatTensorSerializer(
FlatTensorConfig(self._segment_alignment)
)

def _get_emitter_output(self) -> EmitterOutput:
if self._emitter_output is None:
Expand Down Expand Up @@ -1854,7 +1857,9 @@ def __init__(
FlatTensorSerializer,
)

self._data_serializer = FlatTensorSerializer()
self._data_serializer = FlatTensorSerializer(
FlatTensorConfig(segment_alignment=backend_config.segment_alignment)
)
self._pte_data, self._tensor_data = serialize_for_executorch(
self._emitter_output,
backend_config,
Expand Down
22 changes: 8 additions & 14 deletions extension/flat_tensor/serialize/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"

# Alignment of the flatbuffer (after the header).
_FLATBUFFER_ALIGNMENT: int = 16

# Current version. Keep in sync with c++ version number in serialize.
_FLAT_TENSOR_VERSION: int = 0

Expand Down Expand Up @@ -95,8 +98,7 @@ def _deserialize_to_flat_tensor(flatbuffer: bytes) -> FlatTensor:

@dataclass
class FlatTensorConfig:
tensor_alignment: int = 16
segment_alignment: int = 16
segment_alignment: int = 128


@dataclass
Expand Down Expand Up @@ -334,18 +336,13 @@ def serialize(
)

flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)
padded_flatbuffer_length: int = aligned_size(
input_size=len(flatbuffer_payload),
alignment=self.config.tensor_alignment,
)

padded_header_length: int = aligned_size(
input_size=FlatTensorHeader.EXPECTED_LENGTH,
alignment=self.config.tensor_alignment,
alignment=_FLATBUFFER_ALIGNMENT,
)

segment_base_offset = aligned_size(
padded_flatbuffer_length + padded_header_length,
len(flatbuffer_payload) + padded_header_length,
self.config.segment_alignment,
)

Expand All @@ -360,19 +357,16 @@ def serialize(

# Pad header and payload to segment alignment.
header_data = pad_to(header_data, padded_header_length)
original_flatbuffer_payload_size = len(flatbuffer_payload)
flatbuffer_payload.append(
b"\x00" * (padded_flatbuffer_length - len(flatbuffer_payload))
)
injected_flatbuffer_data: bytes = _insert_flatbuffer_header(
flatbuffer_data=flatbuffer_payload.__bytes__(),
magic_regex=r"FT[0-9a-zA-Z][0-9a-zA-Z]",
header_data=header_data,
)
injected_flatbuffer_data = pad_to(injected_flatbuffer_data, segment_base_offset)

eh = _get_extended_header(injected_flatbuffer_data)
assert eh is not None
assert eh.flatbuffer_size == original_flatbuffer_payload_size
assert eh.flatbuffer_size == len(flatbuffer_payload)
assert eh.segment_base_offset == segment_base_offset
assert eh.flatbuffer_offset == padded_header_length
assert eh.segment_data_size == len(aggregated_segment_data)
Expand Down
34 changes: 23 additions & 11 deletions extension/flat_tensor/test/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from executorch.extension.flat_tensor.serialize.serialize import (
_deserialize_to_flat_tensor,
_FLATBUFFER_ALIGNMENT,
FlatTensorConfig,
FlatTensorHeader,
FlatTensorSerializer,
Expand Down Expand Up @@ -109,8 +110,7 @@ def _check_named_data_entries(
f"Named data record {key}.{field.name} does not match.",
)

def test_serialize(self) -> None:
config = FlatTensorConfig()
def _serialize_with_alignment(self, config: FlatTensorConfig) -> None:
serializer: DataSerializer = FlatTensorSerializer(config)
serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))

Expand All @@ -120,15 +120,15 @@ def test_serialize(self) -> None:
)
self.assertTrue(header.is_valid())

# Header is aligned to config.segment_alignment, which is where the flatbuffer starts.
# Flatbuffer is non-empty.
self.assertTrue(header.flatbuffer_size > 0)

# Align the flatbuffer to _FLATBUFFER_ALIGNMENT.
self.assertEqual(
header.flatbuffer_offset,
aligned_size(FlatTensorHeader.EXPECTED_LENGTH, config.segment_alignment),
aligned_size(FlatTensorHeader.EXPECTED_LENGTH, _FLATBUFFER_ALIGNMENT),
)

# Flatbuffer is non-empty.
self.assertTrue(header.flatbuffer_size > 0)

# Segment base offset is aligned to config.segment_alignment.
expected_segment_base_offset = aligned_size(
header.flatbuffer_offset + header.flatbuffer_size, config.segment_alignment
Expand Down Expand Up @@ -180,12 +180,12 @@ def test_serialize(self) -> None:
segments = flat_tensor.segments
self.assertEqual(len(segments), 3)

# Segment 0 contains fqn1, fqn2; 4 bytes, aligned to config.tensor_alignment.
# Segment 0 contains fqn1, fqn2; 4 bytes, aligned to config.segment_alignment.
self.assertEqual(segments[0].offset, 0)
self.assertEqual(segments[0].size, len(TEST_BUFFER[0]))

# Segment 1 contains fqn3; 32 bytes, aligned to config.tensor_alignment.
self.assertEqual(segments[1].offset, config.tensor_alignment)
# Segment 1 contains fqn3; 32 bytes, aligned to config.segment_alignment.
self.assertEqual(segments[1].offset, config.segment_alignment)
self.assertEqual(segments[1].size, len(TEST_BUFFER[1]))

# Segment 2 contains key0; 17 bytes, aligned to 64.
Expand All @@ -194,7 +194,7 @@ def test_serialize(self) -> None:
)
self.assertEqual(
segments[2].offset,
aligned_size(config.tensor_alignment * 3, custom_alignment),
aligned_size(config.segment_alignment * 2, custom_alignment),
)
self.assertEqual(segments[2].size, len(TEST_BUFFER[2]))

Expand Down Expand Up @@ -245,6 +245,18 @@ def test_serialize(self) -> None:

self.assertEqual(segments[2].offset + segments[2].size, len(segment_data))

def test_serialize_default_alignment(self) -> None:
config = FlatTensorConfig()
self._serialize_with_alignment(config)

def test_serialize_align_4096(self) -> None:
config = FlatTensorConfig(segment_alignment=4096)
self._serialize_with_alignment(config)

def test_serialize_align_1024(self) -> None:
config = FlatTensorConfig(segment_alignment=1024)
self._serialize_with_alignment(config)

def test_round_trip(self) -> None:
# Serialize and then deserialize the test payload. Make sure it's reconstructed
# properly.
Expand Down
Loading