Skip to content

Commit d2a2a8d

Browse files
committed
Support zstd compression in python 3.14+
1 parent bf21956 commit d2a2a8d

3 files changed

Lines changed: 65 additions & 6 deletions

File tree

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
runs-on: ubuntu-latest
1919
strategy:
2020
matrix:
21-
python: [ "3.10", "3.11", "3.12", "3.13" ]
21+
python: [ "3.10", "3.11", "3.12", "3.13", "3.14" ]
2222
env:
2323
UV_PYTHON: ${{ matrix.python }}
2424
steps:

python/restate/aws_lambda.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def request_to_receive(req: RestateLambdaRequest) -> Receive:
6262
assert req["isBase64Encoded"]
6363
body = base64.b64decode(req["body"])
6464

65+
# Decompress zstd-encoded request body
66+
headers = {k.lower(): v for k, v in req.get("headers", {}).items()}
67+
if "zstd" in headers.get("content-encoding", ""):
68+
body = zstd_decompress(body)
69+
6570
events = cast(
6671
list[HTTPRequestEvent],
6772
[
@@ -79,16 +84,18 @@ async def recv() -> HTTPRequestEvent:
7984

8085
return recv
8186

87+
RESPONSE_COMPRESSION_THRESHOLD = 3 * 1024 * 1024
8288

8389
class ResponseCollector:
8490
"""
8591
Response collector from ASGI Send to Lambda
8692
"""
8793

88-
def __init__(self):
94+
def __init__(self, accept_encoding: str = ""):
8995
self.body = bytearray()
90-
self.headers = {}
96+
self.headers: dict[str, str] = {}
9197
self.status_code = 500
98+
self.accept_encoding = accept_encoding
9299

93100
async def __call__(self, message: Union[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> None:
94101
"""
@@ -105,11 +112,20 @@ def to_lambda_response(self) -> RestateLambdaResponse:
105112
"""
106113
Convert collected values to lambda response
107114
"""
115+
body_bytes = bytes(self.body)
116+
117+
# Compress response if it exceeds threshold and client accepts zstd
118+
if (len(body_bytes) > RESPONSE_COMPRESSION_THRESHOLD
119+
and "zstd" in self.accept_encoding
120+
and zstd_available()):
121+
body_bytes = zstd_compress(body_bytes)
122+
self.headers["content-encoding"] = "zstd"
123+
108124
return {
109125
"statusCode": self.status_code,
110126
"headers": self.headers,
111127
"isBase64Encoded": True,
112-
"body": base64.b64encode(self.body).decode(),
128+
"body": base64.b64encode(body_bytes).decode(),
113129
}
114130

115131

@@ -120,7 +136,6 @@ def is_running_on_lambda() -> bool:
120136
# https://docs.aws.amazon.com/lambda/latest/dg/configuration-envvars.html
121137
return "AWS_LAMBDA_FUNCTION_NAME" in os.environ
122138

123-
124139
def wrap_asgi_as_lambda_handler(asgi_app: ASGIApp) -> RestateLambdaHandler:
125140
"""
126141
Wrap the given asgi_app in a Lambda handler
@@ -134,7 +149,8 @@ def lambda_handler(event: RestateLambdaRequest, _context: Any) -> RestateLambdaR
134149

135150
scope = create_scope(event)
136151
recv = request_to_receive(event)
137-
send = ResponseCollector()
152+
req_headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
153+
send = ResponseCollector(accept_encoding=req_headers.get("accept-encoding", ""))
138154

139155
asgi_instance = asgi_app(scope, recv, send)
140156
asgi_task = loop.create_task(asgi_instance) # type: ignore[var-annotated, arg-type]
@@ -143,3 +159,41 @@ def lambda_handler(event: RestateLambdaRequest, _context: Any) -> RestateLambdaR
143159
return send.to_lambda_response()
144160

145161
return lambda_handler
162+
163+
def get_lambda_compression():
164+
"""Return 'zstd' if running on Lambda and compression.zstd is available (Python 3.14+), else None."""
165+
if is_running_on_lambda() and zstd_available():
166+
return "zstd"
167+
return None
168+
169+
def zstd_available() -> bool:
170+
"""Return True if zstd compression is available (Python 3.14+)."""
171+
try:
172+
import compression.zstd # type: ignore[import-not-found]
173+
return compression.zstd is not None
174+
except ImportError:
175+
return False
176+
177+
178+
def zstd_compress(data: bytes) -> bytes:
179+
"""Compress data using zstd."""
180+
try:
181+
import compression.zstd # type: ignore[import-not-found]
182+
except ImportError as e:
183+
raise RuntimeError(
184+
"zstd compression requested but compression.zstd is not available. "
185+
"Python 3.14+ is required for zstd compression support."
186+
) from e
187+
return compression.zstd.compress(data)
188+
189+
190+
def zstd_decompress(data: bytes) -> bytes:
191+
"""Decompress zstd-compressed data."""
192+
try:
193+
import compression.zstd # type: ignore[import-not-found]
194+
except ImportError as e:
195+
raise RuntimeError(
196+
"Received zstd-compressed request but compression.zstd is not available. "
197+
"Python 3.14+ is required for zstd compression support."
198+
) from e
199+
return compression.zstd.decompress(data)

python/restate/discovery.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from restate.handler import TypeHint
3535
from restate.object import VirtualObject
3636
from restate.workflow import Workflow
37+
from restate.aws_lambda import get_lambda_compression
3738

3839

3940
class ProtocolMode(Enum):
@@ -159,6 +160,7 @@ def __init__(
159160
self.minProtocolVersion = minProtocolVersion
160161
self.maxProtocolVersion = maxProtocolVersion
161162
self.services = services
163+
self.lambdaCompression = get_lambda_compression()
162164

163165

164166
PROTOCOL_MODES = {"bidi": ProtocolMode.BIDI_STREAM, "request_response": ProtocolMode.REQUEST_RESPONSE}
@@ -235,6 +237,9 @@ def compute_discovery_json(
235237

236238
# Validate that new discovery fields aren't used with older protocol versions
237239
if version <= 3:
240+
# Strip lambdaCompression for older discovery versions
241+
ep.lambdaCompression = None
242+
238243
for service in ep.services:
239244
if service.retryPolicyInitialInterval is not None:
240245
raise ValueError("retryPolicyInitialInterval is only supported in discovery protocol version 4")

0 commit comments

Comments
 (0)