diff --git a/AGENTS.md b/AGENTS.md index a281fff511..c653689142 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -65,6 +65,7 @@ uv run tox -e typecheck - The monorepo uses `uv` workspaces. - `tox.ini` defines the test matrix - check it for available test environments. - Do not add `type: ignore` comments. If a type error arises, solve it properly or write a follow-up plan to address it in another PR. +- Annotate function signatures (parameters and return types) and class attributes. Prefer `from __future__ import annotations` over runtime-quoted strings. - When a file uses `from __future__ import annotations`, do not quote type annotations just to avoid forward references. Keep quotes only for expressions still evaluated at runtime, such as `typing.cast(...)`, unless the referenced type is imported at runtime. diff --git a/CLAUDE.md b/CLAUDE.md index 43c994c2d3..ce60f10b9b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1 +1,2 @@ @AGENTS.md + diff --git a/instrumentation-genai/AGENTS.md b/instrumentation-genai/AGENTS.md index 786b973854..bf65fe025d 100644 --- a/instrumentation-genai/AGENTS.md +++ b/instrumentation-genai/AGENTS.md @@ -19,6 +19,13 @@ This layer is responsible only for: Everything else (span creation, metric recording, event emission, context propagation) belongs in `util/opentelemetry-util-genai`. +For GenAI streaming wrappers, prefer the shared `SyncStreamWrapper` and `AsyncStreamWrapper` +helpers from `opentelemetry.util.genai.stream` instead of reimplementing iteration, +close/context-manager, and finalization behavior in provider packages. + +Put provider-specific chunk parsing and telemetry finalization in private hook methods or a +narrow mixin. Do not make async stream wrappers inherit from sync stream wrappers. + ## 2. TelemetryHandler Initialization Construct `TelemetryHandler` once inside `_instrument()`, passing all OTel providers and the diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/CHANGELOG.md b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/CHANGELOG.md index d480c2e880..bbfb3d275c 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/CHANGELOG.md +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Refactor chat completion stream wrappers to use shared GenAI stream lifecycle helpers. 
+ ([#4500](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4500)) - Pass tool definitions from `tools` kwarg to `InferenceInvocation.tool_definitions` so `gen_ai.tool.definitions` span attribute is populated on chat completion spans ([#4554](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4554)) diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_buffers.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_buffers.py new file mode 100644 index 0000000000..ad1535c933 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_buffers.py @@ -0,0 +1,52 @@ +# Copyright The OpenTelemetry Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall + + +class ToolCallBuffer: + def __init__( + self, + index: int, + tool_call_id: str | None, + function_name: str | None, + ) -> None: + self.index: int = index + self.function_name: str | None = function_name + self.tool_call_id: str | None = tool_call_id + self.arguments: list[str] = [] + + def append_arguments(self, arguments: str | None) -> None: + if arguments is not None: + self.arguments.append(arguments) + + +class ChoiceBuffer: + def __init__(self, index: int) -> None: + self.index: int = index + self.finish_reason: str | None = None + self.text_content: list[str] = [] + self.tool_calls_buffers: list[ToolCallBuffer | None] = [] + + def append_text_content(self, content: str) -> None: + self.text_content.append(content) + + def append_tool_call(self, tool_call: ChoiceDeltaToolCall) -> None: + idx = tool_call.index + for _ in range(len(self.tool_calls_buffers), idx + 1): + self.tool_calls_buffers.append(None) + + function = tool_call.function + buffer = self.tool_calls_buffers[idx] + if buffer is None: + buffer = ToolCallBuffer( + idx, + tool_call.id, + function.name if function else None, + ) + self.tool_calls_buffers[idx] = buffer + + if function: + buffer.append_arguments(function.arguments) diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_wrappers.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_wrappers.py new file mode 100644 index 0000000000..4c5dfa3297 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_wrappers.py @@ -0,0 +1,220 @@ +# Copyright The OpenTelemetry Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +from typing import Optional + +from openai import AsyncStream, Stream +from openai.types.chat import ChatCompletionChunk + +from opentelemetry.semconv._incubating.attributes import ( + openai_attributes as OpenAIAttributes, +) +from opentelemetry.util.genai.invocation import InferenceInvocation +from opentelemetry.util.genai.stream import ( + AsyncStreamWrapper, + SyncStreamWrapper, +) +from opentelemetry.util.genai.types import ( + OutputMessage, + Text, + ToolCallRequest, +) + +from .chat_buffers import ChoiceBuffer + + +class _ChatStreamMixin: + """Chat-specific hooks shared by sync and async stream wrappers.""" + + _self_invocation: InferenceInvocation + _self_capture_content: bool + _self_choice_buffers: 
list[ChoiceBuffer] + _self_response_id: Optional[str] + _self_response_model: Optional[str] + _self_service_tier: Optional[str] + _self_prompt_tokens: Optional[int] + _self_completion_tokens: Optional[int] + + def _set_response_model(self, chunk: ChatCompletionChunk) -> None: + if self._self_response_model: + return + + if chunk.model: + self._self_response_model = chunk.model + + def _set_response_id(self, chunk: ChatCompletionChunk) -> None: + if self._self_response_id: + return + + if chunk.id: + self._self_response_id = chunk.id + + def _set_response_service_tier(self, chunk: ChatCompletionChunk) -> None: + if self._self_service_tier: + return + + service_tier = getattr(chunk, "service_tier", None) + if service_tier: + self._self_service_tier = service_tier + + def _build_streaming_response(self, chunk: ChatCompletionChunk) -> None: + if chunk.choices is None: + return + + for choice in chunk.choices: + if not choice.delta: + continue + + for idx in range(len(self._self_choice_buffers), choice.index + 1): + self._self_choice_buffers.append(ChoiceBuffer(idx)) + + if choice.finish_reason: + self._self_choice_buffers[ + choice.index + ].finish_reason = choice.finish_reason + + if choice.delta.content is not None: + self._self_choice_buffers[choice.index].append_text_content( + choice.delta.content + ) + + if choice.delta.tool_calls is not None: + for tool_call in choice.delta.tool_calls: + self._self_choice_buffers[choice.index].append_tool_call( + tool_call + ) + + def _set_usage(self, chunk: ChatCompletionChunk) -> None: + usage = getattr(chunk, "usage", None) + if usage: + self._self_completion_tokens = usage.completion_tokens + self._self_prompt_tokens = usage.prompt_tokens + + def _process_chunk(self, chunk: ChatCompletionChunk) -> None: + self._set_response_id(chunk) + self._set_response_model(chunk) + self._set_response_service_tier(chunk) + self._build_streaming_response(chunk) + self._set_usage(chunk) + + def _set_output_messages(self) -> None: + if not self._self_capture_content: # optimization + return + output_messages = [] + for choice in self._self_choice_buffers: + message = OutputMessage( + role="assistant", + finish_reason=choice.finish_reason or "error", + parts=[], + ) + if choice.text_content: + message.parts.append( + Text(content="".join(choice.text_content)) + ) + if choice.tool_calls_buffers: + tool_calls = [] + for tool_call in filter(None, choice.tool_calls_buffers): + arguments = None + arguments_str = "".join(tool_call.arguments) + if arguments_str: + try: + arguments = json.loads(arguments_str) + except json.JSONDecodeError: + arguments = arguments_str + tool_call_part = ToolCallRequest( + name=tool_call.function_name, + id=tool_call.tool_call_id, + arguments=arguments, + ) + tool_calls.append(tool_call_part) + message.parts.extend(tool_calls) + output_messages.append(message) + + self._self_invocation.output_messages = output_messages + + def _on_stream_end(self) -> None: + self._cleanup() + + def _on_stream_error(self, error: BaseException) -> None: + self._cleanup(error) + + def parse(self) -> _ChatStreamMixin: + """Called when using with_raw_response with stream=True.""" + return self + + def _cleanup(self, error: Optional[BaseException] = None) -> None: + self._self_invocation.response_model_name = self._self_response_model + self._self_invocation.response_id = self._self_response_id + self._self_invocation.input_tokens = self._self_prompt_tokens + self._self_invocation.output_tokens = self._self_completion_tokens + finish_reasons = [ + 
choice.finish_reason + for choice in self._self_choice_buffers + if choice.finish_reason + ] + if finish_reasons: + self._self_invocation.finish_reasons = finish_reasons + if self._self_service_tier: + self._self_invocation.attributes.update( + { + OpenAIAttributes.OPENAI_RESPONSE_SERVICE_TIER: self._self_service_tier + }, + ) + + self._set_output_messages() + + if error: + self._self_invocation.fail(error) + else: + self._self_invocation.stop() + + +class ChatStreamWrapper( + _ChatStreamMixin, + SyncStreamWrapper[ChatCompletionChunk], +): + def __init__( + self, + stream: Stream[ChatCompletionChunk], + invocation: InferenceInvocation, + capture_content: bool, + ) -> None: + super().__init__(stream) + self._self_invocation = invocation + self._self_choice_buffers = [] + self._self_capture_content = capture_content + self._self_response_id = None + self._self_response_model = None + self._self_service_tier = None + self._self_prompt_tokens = None + self._self_completion_tokens = None + + +class AsyncChatStreamWrapper( + _ChatStreamMixin, + AsyncStreamWrapper[ChatCompletionChunk], +): + def __init__( + self, + stream: AsyncStream[ChatCompletionChunk], + invocation: InferenceInvocation, + capture_content: bool, + ) -> None: + super().__init__(stream) + self._self_invocation = invocation + self._self_choice_buffers = [] + self._self_capture_content = capture_content + self._self_response_id = None + self._self_response_model = None + self._self_service_tier = None + self._self_prompt_tokens = None + self._self_completion_tokens = None + + +__all__ = [ + "AsyncChatStreamWrapper", + "ChatStreamWrapper", +] diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py index 833c803503..6554ca870e 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -import json from timeit import default_timer from typing import Any, Optional @@ -25,11 +24,10 @@ from opentelemetry.util.genai.invocation import InferenceInvocation from opentelemetry.util.genai.types import ( Error, - OutputMessage, - Text, - ToolCallRequest, ) +from .chat_buffers import ChoiceBuffer +from .chat_wrappers import AsyncChatStreamWrapper, ChatStreamWrapper from .instruments import Instruments from .utils import ( _prepare_output_messages, @@ -229,7 +227,7 @@ async def traced_method(wrapped, instance, args, kwargs): else: parsed_result = result if is_streaming(kwargs): - return ChatStreamWrapper( + return AsyncChatStreamWrapper( parsed_result, chat_invocation, capture_content ) @@ -557,46 +555,6 @@ def _set_embeddings_response_attributes( # Don't set output tokens for embeddings as all tokens are input tokens -class ToolCallBuffer: - def __init__(self, index, tool_call_id, function_name): - self.index = index - self.function_name = function_name - self.tool_call_id = tool_call_id - self.arguments = [] - - def append_arguments(self, arguments): - if arguments is not None: - self.arguments.append(arguments) - - -class ChoiceBuffer: - def __init__(self, index): - self.index = index - self.finish_reason = None - self.text_content = [] - self.tool_calls_buffers = [] - - def append_text_content(self, content): - 
self.text_content.append(content) - - def append_tool_call(self, tool_call): - idx = tool_call.index - # make sure we have enough tool call buffers - for _ in range(len(self.tool_calls_buffers), idx + 1): - self.tool_calls_buffers.append(None) - - function = tool_call.function - if not self.tool_calls_buffers[idx]: - self.tool_calls_buffers[idx] = ToolCallBuffer( - idx, - tool_call.id, - function.name if function else None, - ) - - if function: - self.tool_calls_buffers[idx].append_arguments(function.arguments) - - class BaseStreamWrapper: response_id: Optional[str] = None response_model: Optional[str] = None @@ -812,7 +770,7 @@ def cleanup(self, error: Optional[BaseException] = None): message["content"] = "".join(choice.text_content) if choice.tool_calls_buffers: tool_calls = [] - for tool_call in choice.tool_calls_buffers: + for tool_call in filter(None, choice.tool_calls_buffers): function = {"name": tool_call.function_name} if self.capture_content: function["arguments"] = "".join(tool_call.arguments) @@ -848,83 +806,3 @@ def cleanup(self, error: Optional[BaseException] = None): else: self.span.end() self._started = False - - -class ChatStreamWrapper(BaseStreamWrapper): - invocation: InferenceInvocation - response_id: Optional[str] = None - response_model: Optional[str] = None - service_tier: Optional[str] = None - finish_reasons: list = [] - prompt_tokens: Optional[int] = None - completion_tokens: Optional[int] = None - - def __init__( - self, - stream: Stream, - invocation: InferenceInvocation, - capture_content: bool, - ): - super().__init__(stream, capture_content=capture_content) - self.stream = stream - self.invocation = invocation - self.choice_buffers = [] - - def _set_output_messages(self): - if not self.capture_content: # optimization - return - output_messages = [] - for choice in self.choice_buffers: - message = OutputMessage( - role="assistant", - finish_reason=choice.finish_reason or "error", - parts=[], - ) - if choice.text_content: - message.parts.append( - Text(content="".join(choice.text_content)) - ) - if choice.tool_calls_buffers: - tool_calls = [] - for tool_call in choice.tool_calls_buffers: - arguments = None - arguments_str = "".join(tool_call.arguments) - if arguments_str: - try: - arguments = json.loads(arguments_str) - except json.JSONDecodeError: - arguments = arguments_str - tool_call_part = ToolCallRequest( - name=tool_call.function_name, - id=tool_call.tool_call_id, - arguments=arguments, - ) - tool_calls.append(tool_call_part) - message.parts.extend(tool_calls) - output_messages.append(message) - - self.invocation.output_messages = output_messages - - def cleanup(self, error: Optional[BaseException] = None): - if not self._started: - return - - self.invocation.response_model_name = self.response_model - self.invocation.response_id = self.response_id - self.invocation.input_tokens = self.prompt_tokens - self.invocation.output_tokens = self.completion_tokens - self.invocation.finish_reasons = self.finish_reasons - if self.service_tier: - self.invocation.attributes.update( - { - OpenAIAttributes.OPENAI_RESPONSE_SERVICE_TIER: self.service_tier - }, - ) - - self._set_output_messages() - - if error: - self.invocation.fail(Error(type=type(error), message=str(error))) - else: - self.invocation.stop() - self._started = False diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/utils.py 
b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/utils.py index 81dd700ebb..c691e30412 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/utils.py +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/utils.py @@ -390,7 +390,7 @@ def create_chat_invocation( extra_body = get_value(kwargs.get("extra_body")) if isinstance(extra_body, Mapping): service_tier = get_value(extra_body.get("service_tier")) - if service_tier is not None: + if service_tier is not None and service_tier != "auto": invocation.attributes[OpenAIAttributes.OPENAI_REQUEST_SERVICE_TIER] = ( service_tier ) diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/requirements.oldest.txt b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/requirements.oldest.txt index 1ebf378cf9..32114829c3 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/requirements.oldest.txt +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/requirements.oldest.txt @@ -32,6 +32,6 @@ opentelemetry-exporter-otlp-proto-http~=1.30 opentelemetry-api==1.40 # when updating, also update in pyproject.toml opentelemetry-sdk==1.40 # when updating, also update in pyproject.toml opentelemetry-semantic-conventions==0.61b0 # when updating, also update in pyproject.toml -opentelemetry-util-genai==0.4b0 # when updating, also update in pyproject.toml +-e util/opentelemetry-util-genai -e instrumentation-genai/opentelemetry-instrumentation-openai-v2 diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_async_chat_completions.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_async_chat_completions.py index 788cd05c48..e04f964409 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_async_chat_completions.py +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_async_chat_completions.py @@ -187,6 +187,18 @@ async def test_async_chat_completion_404( assert "NotFoundError" == spans[0].attributes[ErrorAttributes.ERROR_TYPE] +@pytest.mark.asyncio() +async def test_async_chat_completion_api_exception_propagates( + async_openai_client, instrument_no_content, vcr +): + with vcr.use_cassette("test_async_chat_completion_404.yaml"): + with pytest.raises(NotFoundError): + await async_openai_client.chat.completions.create( + messages=USER_ONLY_PROMPT, + model="this-model-does-not-exist", + ) + + @pytest.mark.asyncio() async def test_async_chat_completion_extra_params( span_exporter, async_openai_client, instrument_no_content, vcr @@ -879,6 +891,44 @@ async def test_async_chat_completion_streaming( ) +@pytest.mark.asyncio() +async def test_async_chat_completion_streaming_user_exception_propagates( + span_exporter, + async_openai_client, + instrument_with_content, + vcr, +): + latest_experimental_enabled = is_experimental_mode() + llm_model_value = "gpt-4" + kwargs = { + "model": llm_model_value, + "messages": USER_ONLY_PROMPT, + "stream": True, + "stream_options": {"include_usage": True}, + } + response_stream_model = None + response_stream_id = None + + with vcr.use_cassette("test_async_chat_completion_streaming.yaml"): + response = await async_openai_client.chat.completions.create(**kwargs) + with pytest.raises(RuntimeError, match="user failure"): + async with response: + async for chunk in 
response: + response_stream_model = chunk.model + response_stream_id = chunk.id + raise RuntimeError("user failure") + + spans = span_exporter.get_finished_spans() + assert_all_attributes( + spans[0], + llm_model_value, + latest_experimental_enabled, + response_stream_id, + response_stream_model, + ) + assert "RuntimeError" == spans[0].attributes[ErrorAttributes.ERROR_TYPE] + + @pytest.mark.asyncio() async def test_async_chat_completion_streaming_not_complete( span_exporter, @@ -917,7 +967,10 @@ async def test_async_chat_completion_streaming_not_complete( response_stream_id = chunk.id idx += 1 - response.close() + if latest_experimental_enabled: + await response.close() + else: + response.close() spans = span_exporter.get_finished_spans() assert_all_attributes( spans[0], diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_chat_completions.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_chat_completions.py index fca28e606c..07e86a6870 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_chat_completions.py +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_chat_completions.py @@ -277,6 +277,17 @@ def test_chat_completion_404( ) +def test_chat_completion_api_exception_propagates( + openai_client, instrument_no_content, vcr +): + with vcr.use_cassette("test_chat_completion_404.yaml"): + with pytest.raises(NotFoundError): + openai_client.chat.completions.create( + messages=USER_ONLY_PROMPT, + model="this-model-does-not-exist", + ) + + def test_chat_completion_extra_params( span_exporter, openai_client, instrument_no_content, vcr ): @@ -993,6 +1004,123 @@ def test_chat_completion_streaming( ) +def test_chat_completion_streaming_user_exception_propagates( + span_exporter, openai_client, instrument_with_content, vcr +): + latest_experimental_enabled = is_experimental_mode() + kwargs = { + "model": DEFAULT_MODEL, + "messages": USER_ONLY_PROMPT, + "stream": True, + "stream_options": {"include_usage": True}, + } + response_stream_model = None + response_stream_id = None + + with vcr.use_cassette("test_chat_completion_streaming.yaml"): + response = openai_client.chat.completions.create(**kwargs) + with pytest.raises(RuntimeError, match="user failure"): + with response: + for chunk in response: + response_stream_model = chunk.model + response_stream_id = chunk.id + raise RuntimeError("user failure") + + spans = span_exporter.get_finished_spans() + assert_all_attributes( + spans[0], + DEFAULT_MODEL, + latest_experimental_enabled, + response_stream_id, + response_stream_model, + ) + assert "RuntimeError" == spans[0].attributes[ErrorAttributes.ERROR_TYPE] + + +def test_chat_completion_streaming_user_exception_wins_over_close_exception( + span_exporter, openai_client, instrument_with_content, vcr, monkeypatch +): + if not is_experimental_mode(): + pytest.skip("new stream wrapper only") + + kwargs = { + "model": DEFAULT_MODEL, + "messages": USER_ONLY_PROMPT, + "stream": True, + "stream_options": {"include_usage": True}, + } + + with vcr.use_cassette("test_chat_completion_streaming.yaml"): + response = openai_client.chat.completions.create(**kwargs) + original_close = response.__wrapped__.close + + def close_raises(): + original_close() + raise RuntimeError("close failure") + + monkeypatch.setattr(response.__wrapped__, "close", close_raises) + with pytest.raises(RuntimeError, match="user failure"): + with response: + raise RuntimeError("user failure") + + spans = 
span_exporter.get_finished_spans() + assert "RuntimeError" == spans[0].attributes[ErrorAttributes.ERROR_TYPE] + + +def test_chat_completion_streaming_close_exception_propagates_when_first( + span_exporter, openai_client, instrument_with_content, vcr, monkeypatch +): + if not is_experimental_mode(): + pytest.skip("new stream wrapper only") + + kwargs = { + "model": DEFAULT_MODEL, + "messages": USER_ONLY_PROMPT, + "stream": True, + "stream_options": {"include_usage": True}, + } + + with vcr.use_cassette("test_chat_completion_streaming.yaml"): + response = openai_client.chat.completions.create(**kwargs) + original_close = response.__wrapped__.close + + def close_raises(): + original_close() + raise RuntimeError("close failure") + + monkeypatch.setattr(response.__wrapped__, "close", close_raises) + with pytest.raises(RuntimeError, match="close failure"): + response.close() + + spans = span_exporter.get_finished_spans() + assert "RuntimeError" == spans[0].attributes[ErrorAttributes.ERROR_TYPE] + + +def test_chat_completion_streaming_instrumentation_finalize_errors_swallowed( + span_exporter, openai_client, instrument_with_content, vcr, monkeypatch +): + if not is_experimental_mode(): + pytest.skip("new stream wrapper only") + + kwargs = { + "model": DEFAULT_MODEL, + "messages": USER_ONLY_PROMPT, + "stream": True, + "stream_options": {"include_usage": True}, + } + + with vcr.use_cassette("test_chat_completion_streaming.yaml"): + response = openai_client.chat.completions.create(**kwargs) + + def stop_raises(): + raise RuntimeError("instrumentation failure") + + monkeypatch.setattr(response, "_on_stream_end", stop_raises) + response.close() + + assert span_exporter.get_finished_spans() == () + + def test_chat_completion_streaming_not_complete( span_exporter, log_exporter, openai_client, instrument_with_content, vcr ): diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_choice_buffer.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_choice_buffer.py index 0d9a9bc07d..c489e7f322 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_choice_buffer.py +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_choice_buffer.py @@ -8,7 +8,7 @@ ChoiceDeltaToolCallFunction, ) -from opentelemetry.instrumentation.openai_v2.patch import ( +from opentelemetry.instrumentation.openai_v2.chat_buffers import ( ChoiceBuffer, ToolCallBuffer, ) diff --git a/util/opentelemetry-util-genai/CHANGELOG.md b/util/opentelemetry-util-genai/CHANGELOG.md index 97dcb93090..6c042402e1 100644 --- a/util/opentelemetry-util-genai/CHANGELOG.md +++ b/util/opentelemetry-util-genai/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Add shared sync and async stream wrapper base classes for GenAI instrumentations. + ([#4500](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4500)) - Change `InferenceInvocation` init params to only accept base params - Pass in `attributes` on invocation `_start` so samplers have access to attributes. 
([#4538](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4538)) diff --git a/util/opentelemetry-util-genai/pyproject.toml b/util/opentelemetry-util-genai/pyproject.toml index d701e071e6..168e63d893 100644 --- a/util/opentelemetry-util-genai/pyproject.toml +++ b/util/opentelemetry-util-genai/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "opentelemetry-instrumentation ~= 0.61b0", "opentelemetry-semantic-conventions ~= 0.61b0", "opentelemetry-api ~= 1.40", + "wrapt >= 1.0.0, < 3.0.0", ] [project.entry-points.opentelemetry_genai_completion_hook] @@ -49,4 +50,3 @@ include = ["/src", "/tests"] [tool.hatch.build.targets.wheel] packages = ["src/opentelemetry"] - diff --git a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/stream.py b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/stream.py new file mode 100644 index 0000000000..881811461f --- /dev/null +++ b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/stream.py @@ -0,0 +1,307 @@ +# Copyright The OpenTelemetry Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +from abc import ABCMeta, abstractmethod +from types import TracebackType +from typing import ( + TYPE_CHECKING, + AsyncIterable, + Generic, + Iterable, + Literal, + Protocol, + TypeVar, +) + +if TYPE_CHECKING: + + class _ObjectProxy: + def __init__(self, wrapped: object) -> None: ... + +else: + from wrapt import ObjectProxy as _ObjectProxy + + +ChunkT = TypeVar("ChunkT") +_ChunkT_co = TypeVar("_ChunkT_co", covariant=True) +_logger = logging.getLogger(__name__) + + +class _StreamWrapperMeta(ABCMeta, type(_ObjectProxy)): + """Metaclass compatible with wrapt's proxy type and ABC hooks.""" + + +class _SyncStream(Iterable[_ChunkT_co], Protocol[_ChunkT_co]): + """Structural type for streams accepted by ``SyncStreamWrapper``.""" + + def close(self) -> None: ... + + +class _AsyncStream(AsyncIterable[_ChunkT_co], Protocol[_ChunkT_co]): + """Structural type for streams accepted by ``AsyncStreamWrapper``.""" + + async def close(self) -> None: ... + + +class SyncStreamWrapper( + _ObjectProxy, + Generic[ChunkT], + metaclass=_StreamWrapperMeta, +): + """Base class for synchronous instrumented stream wrappers. + + Subclass this when wrapping a provider SDK stream that is consumed with + normal iteration. The subclass should pass the SDK stream to + ``super().__init__(stream)`` and implement the three telemetry hooks: + ``_process_chunk`` for per-chunk state, ``_on_stream_end`` for successful + finalization, and ``_on_stream_error`` for failure finalization. + + Users should consume subclasses as normal streams, for example with + ``for chunk in wrapper`` or ``with wrapper``. The hook methods are called + internally by the wrapper lifecycle and are not part of the public API. 
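+
+    A minimal subclass sketch (the counting logic below is illustrative
+    only)::
+
+        class _CountingStreamWrapper(SyncStreamWrapper[str]):
+            def __init__(self, stream):
+                super().__init__(stream)
+                # ``_self_`` attributes live on the proxy itself instead
+                # of being forwarded to the wrapped stream.
+                self._self_chunk_count = 0
+
+            def _process_chunk(self, chunk: str) -> None:
+                self._self_chunk_count += 1
+
+            def _on_stream_end(self) -> None:
+                _logger.debug(
+                    "stream ended after %d chunks", self._self_chunk_count
+                )
+
+            def _on_stream_error(self, error: BaseException) -> None:
+                _logger.debug("stream failed: %r", error)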
+ """ + + def __init__(self, stream: _SyncStream[ChunkT]): + super().__init__(stream) + self._self_stream = stream + self._self_iterator = iter(stream) + self._self_finalized = False + + def __enter__(self): + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> Literal[False]: + if exc_val is not None: + self._safe_finalize_failure(exc_val) + try: + self._self_stream.close() + except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "GenAI stream close error after user exception", + exc_info=True, + ) + return False + + self.close() + return False + + def close(self) -> None: + try: + self._self_stream.close() + except Exception as error: + self._safe_finalize_failure(error) + raise + self._safe_finalize_success() + + def __iter__(self): + # Override ``ObjectProxy.__iter__`` so iteration drives ``__next__`` + # below and runs ``_process_chunk`` per chunk; otherwise iteration + # would be forwarded to the wrapped stream and bypass instrumentation. + return self + + def __next__(self) -> ChunkT: + try: + chunk = next(self._self_iterator) + except StopIteration: + self._safe_finalize_success() + raise + except Exception as error: + self._safe_finalize_failure(error) + raise + try: + self._process_chunk(chunk) + except Exception as error: # pylint: disable=broad-exception-caught + self._handle_process_chunk_error(error) + return chunk + + def _finalize_success(self) -> None: + if self._self_finalized: + return + self._self_finalized = True + self._on_stream_end() + + def _finalize_failure(self, error: BaseException) -> None: + if self._self_finalized: + return + self._self_finalized = True + self._on_stream_error(error) + + def _safe_finalize_success(self) -> None: + try: + self._finalize_success() + except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "GenAI stream instrumentation error during finalization", + exc_info=True, + ) + + def _safe_finalize_failure(self, error: BaseException) -> None: + try: + self._finalize_failure(error) + except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "GenAI stream instrumentation error during failure finalization", + exc_info=True, + ) + + @abstractmethod + def _process_chunk(self, chunk: ChunkT) -> None: + """Process one stream chunk for telemetry.""" + + @abstractmethod + def _on_stream_end(self) -> None: + """Finalize the stream successfully.""" + + @abstractmethod + def _on_stream_error(self, error: BaseException) -> None: + """Finalize the stream with failure.""" + + @staticmethod + def _handle_process_chunk_error(_error: Exception) -> None: + _logger.debug( + "GenAI stream instrumentation error during chunk processing", + exc_info=True, + ) + + +class AsyncStreamWrapper( + _ObjectProxy, + Generic[ChunkT], + metaclass=_StreamWrapperMeta, +): + """Base class for asynchronous instrumented stream wrappers. + + Subclass this when wrapping a provider SDK stream that is consumed with + async iteration. The subclass should pass the SDK stream to + ``super().__init__(stream)`` and implement the three telemetry hooks: + ``_process_chunk`` for per-chunk state, ``_on_stream_end`` for successful + finalization, and ``_on_stream_error`` for failure finalization. + + Users should consume subclasses as normal async streams, for example with + ``async for chunk in wrapper`` or ``async with wrapper``. 
The hook methods + remain synchronous telemetry hooks; async stream reads and close handling + are owned by this base class. + """ + + def __init__(self, stream: _AsyncStream[ChunkT]): + super().__init__(stream) + self._self_stream = stream + self._self_aiter = aiter(stream) + self._self_finalized = False + + async def __aenter__(self): + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> Literal[False]: + if exc_val is not None: + self._safe_finalize_failure(exc_val) + try: + await self._self_stream.close() + except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "GenAI stream close error after user exception", + exc_info=True, + ) + return False + + await self.close() + return False + + async def close(self) -> None: + # Named ``close`` (not ``aclose``) to match OpenAI's ``AsyncStream``. + # Revisit when migrating SDKs that expose ``aclose`` instead. + try: + await self._self_stream.close() + except Exception as error: + self._safe_finalize_failure(error) + raise + self._safe_finalize_success() + + def __aiter__(self): + # Override ``ObjectProxy.__aiter__`` so iteration drives ``__anext__`` + # below and runs ``_process_chunk`` per chunk; otherwise iteration + # would be forwarded to the wrapped stream and bypass instrumentation. + return self + + async def __anext__(self) -> ChunkT: + try: + chunk = await anext(self._self_aiter) + except StopAsyncIteration: + self._safe_finalize_success() + raise + except Exception as error: + self._safe_finalize_failure(error) + raise + try: + self._process_chunk(chunk) + except Exception as error: # pylint: disable=broad-exception-caught + self._handle_process_chunk_error(error) + return chunk + + def _finalize_success(self) -> None: + if self._self_finalized: + return + self._self_finalized = True + self._on_stream_end() + + def _finalize_failure(self, error: BaseException) -> None: + if self._self_finalized: + return + self._self_finalized = True + self._on_stream_error(error) + + def _safe_finalize_success(self) -> None: + try: + self._finalize_success() + except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "GenAI stream instrumentation error during finalization", + exc_info=True, + ) + + def _safe_finalize_failure(self, error: BaseException) -> None: + try: + self._finalize_failure(error) + except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "GenAI stream instrumentation error during failure finalization", + exc_info=True, + ) + + @abstractmethod + def _process_chunk(self, chunk: ChunkT) -> None: + """Process one stream chunk for telemetry.""" + + @abstractmethod + def _on_stream_end(self) -> None: + """Finalize the stream successfully.""" + + @abstractmethod + def _on_stream_error(self, error: BaseException) -> None: + """Finalize the stream with failure.""" + + @staticmethod + def _handle_process_chunk_error(_error: Exception) -> None: + _logger.debug( + "GenAI stream instrumentation error during chunk processing", + exc_info=True, + ) + + +__all__ = [ + "AsyncStreamWrapper", + "SyncStreamWrapper", +] diff --git a/util/opentelemetry-util-genai/tests/test_stream.py b/util/opentelemetry-util-genai/tests/test_stream.py new file mode 100644 index 0000000000..dfbb47958d --- /dev/null +++ b/util/opentelemetry-util-genai/tests/test_stream.py @@ -0,0 +1,542 @@ +# Copyright The OpenTelemetry Authors +# SPDX-License-Identifier: Apache-2.0 + +# pylint: 
disable=abstract-class-instantiated + +import asyncio +import inspect + +import pytest + +from opentelemetry.util.genai.stream import ( + AsyncStreamWrapper, + SyncStreamWrapper, +) + + +def test_stream_wrapper_abstract_method_signatures_match(): + method_names = ( + "_process_chunk", + "_on_stream_end", + "_on_stream_error", + "_handle_process_chunk_error", + ) + + for method_name in method_names: + assert inspect.signature( + getattr(SyncStreamWrapper, method_name) + ) == inspect.signature(getattr(AsyncStreamWrapper, method_name)) + + +class _FakeSyncStream: + def __init__(self, chunks=None, error=None, close_error=None): + self._chunks = list(chunks or []) + self._error = error + self._close_error = close_error + self.close_count = 0 + self.extra_attribute = "passthrough" + + def __iter__(self): + return self + + def __next__(self): + if self._chunks: + return self._chunks.pop(0) + if self._error: + raise self._error + raise StopIteration + + def close(self): + self.close_count += 1 + if self._close_error: + raise self._close_error + + def __len__(self): + return 42 + + +class _FakeSyncIterable: + def __init__(self, chunks=None): + self.iterator = iter(chunks or []) + self.close_count = 0 + + def __iter__(self): + return self.iterator + + def close(self): + self.close_count += 1 + + +class _TestSyncStreamWrapper(SyncStreamWrapper): + def __init__(self, stream): + super().__init__(stream) + self._self_processed = [] + self._self_stop_count = 0 + self._self_failures = [] + + def _process_chunk(self, chunk): + self._self_processed.append(chunk) + + def _on_stream_end(self): + self._self_stop_count += 1 + + def _on_stream_error(self, error): + self._self_failures.append(error) + + +class _FailingSyncProcessStreamWrapper(_TestSyncStreamWrapper): + def _process_chunk(self, chunk): + raise ValueError("instrumentation failed") + + +class _FailingSyncStopStreamWrapper(_TestSyncStreamWrapper): + def _on_stream_end(self): + self._self_stop_count += 1 + raise ValueError("instrumentation failed") + + +class _FailingSyncFailStreamWrapper(_TestSyncStreamWrapper): + def _on_stream_error(self, error): + self._self_failures.append(error) + raise ValueError("instrumentation failed") + + +def test_sync_stream_wrapper_processes_chunks_and_stops(): + stream = _FakeSyncStream(chunks=["chunk"]) + wrapper = _TestSyncStreamWrapper(stream) + + assert next(wrapper) == "chunk" + assert wrapper._self_processed == ["chunk"] + + try: + next(wrapper) + except StopIteration: + pass + + assert wrapper._self_stop_count == 1 + + +def test_sync_stream_wrapper_processes_iterables(): + stream = _FakeSyncIterable(chunks=["chunk"]) + wrapper = _TestSyncStreamWrapper(stream) + + assert next(wrapper) == "chunk" + assert wrapper._self_processed == ["chunk"] + + with pytest.raises(StopIteration): + next(wrapper) + + assert wrapper._self_stop_count == 1 + + +def test_sync_stream_wrapper_fails_stream_errors(): + error = ValueError("boom") + wrapper = _TestSyncStreamWrapper(_FakeSyncStream(error=error)) + + try: + next(wrapper) + except ValueError: + pass + + assert wrapper._self_failures == [error] + + +def test_sync_stream_wrapper_close_stops_once(): + stream = _FakeSyncStream(chunks=["chunk"]) + wrapper = _TestSyncStreamWrapper(stream) + + wrapper.close() + wrapper.close() + + assert stream.close_count == 2 + assert wrapper._self_stop_count == 1 + assert not wrapper._self_failures + + +def test_sync_stream_wrapper_close_fails_with_close_error(): + error = RuntimeError("close failure") + wrapper = _TestSyncStreamWrapper( + 
_FakeSyncStream(chunks=["chunk"], close_error=error) + ) + + with pytest.raises(RuntimeError, match="close failure"): + wrapper.close() + + assert wrapper._self_failures == [error] + assert wrapper._self_stop_count == 0 + + +def test_sync_stream_wrapper_exit_closes_and_propagates_user_errors(): + stream = _FakeSyncStream(chunks=["chunk"]) + wrapper = _TestSyncStreamWrapper(stream) + error = RuntimeError("user failure") + + assert wrapper.__exit__(RuntimeError, error, None) is False + + assert stream.close_count == 1 + assert wrapper._self_stop_count == 0 + assert wrapper._self_failures == [error] + + +def test_sync_stream_wrapper_exit_keeps_user_error_when_close_fails(): + close_error = RuntimeError("close failure") + stream = _FakeSyncStream(chunks=["chunk"], close_error=close_error) + wrapper = _TestSyncStreamWrapper(stream) + error = RuntimeError("user failure") + + assert wrapper.__exit__(RuntimeError, error, None) is False + + assert stream.close_count == 1 + assert wrapper._self_failures == [error] + assert wrapper._self_stop_count == 0 + + +def test_sync_stream_wrapper_swallows_finalize_errors(): + wrapper = _FailingSyncStopStreamWrapper(_FakeSyncStream()) + + wrapper.close() + wrapper.close() + + assert wrapper._self_stop_count == 1 + + +def test_sync_stream_wrapper_swallows_failure_finalize_errors(): + close_error = RuntimeError("close failure") + stream = _FakeSyncStream(close_error=close_error) + wrapper = _FailingSyncFailStreamWrapper(stream) + + with pytest.raises(RuntimeError, match="close failure"): + wrapper.close() + stream._close_error = None + wrapper.close() + + assert wrapper._self_failures == [close_error] + + +def test_sync_stream_wrapper_swallows_stop_iteration_finalize_errors(): + wrapper = _FailingSyncStopStreamWrapper(_FakeSyncStream()) + + with pytest.raises(StopIteration): + next(wrapper) + + +def test_sync_stream_wrapper_preserves_stream_error_when_finalize_fails(): + error = RuntimeError("stream failure") + wrapper = _FailingSyncFailStreamWrapper(_FakeSyncStream(error=error)) + + with pytest.raises(RuntimeError, match="stream failure"): + next(wrapper) + + +def test_sync_stream_wrapper_getattr_passthrough(): + wrapper = _TestSyncStreamWrapper(_FakeSyncStream()) + + assert wrapper.extra_attribute == "passthrough" + + +def test_sync_stream_wrapper_exposes_wrapped_stream(): + stream = _FakeSyncStream() + wrapper = _TestSyncStreamWrapper(stream) + + assert getattr(wrapper, "__wrapped__") is stream + + +def test_sync_stream_wrapper_magic_method_passthrough(): + wrapper = _TestSyncStreamWrapper(_FakeSyncStream()) + + assert len(wrapper) == 42 + + +def test_sync_stream_wrapper_stop_iteration_does_not_double_finalize(): + wrapper = _TestSyncStreamWrapper(_FakeSyncStream()) + + with pytest.raises(StopIteration): + next(wrapper) + wrapper.close() + + assert wrapper._self_stop_count == 1 + assert not wrapper._self_failures + + +def test_sync_stream_wrapper_swallows_process_chunk_errors(): + wrapper = _FailingSyncProcessStreamWrapper( + _FakeSyncStream(chunks=["chunk"]) + ) + + assert next(wrapper) == "chunk" + assert not wrapper._self_failures + + +class _FakeAsyncStream: + def __init__(self, chunks=None, error=None, close_error=None): + self._chunks = list(chunks or []) + self._error = error + self._close_error = close_error + self.close_count = 0 + self.extra_attribute = "passthrough" + + def __aiter__(self): + return self + + async def __anext__(self): + if self._chunks: + return self._chunks.pop(0) + if self._error: + raise self._error + raise StopAsyncIteration + + 
async def close(self): + self.close_count += 1 + if self._close_error: + raise self._close_error + + def __len__(self): + return 42 + + +class _FakeAsyncIterable: + def __init__(self, chunks=None): + self.iterator = _FakeAsyncStream(chunks=chunks) + self.close_count = 0 + + def __aiter__(self): + return self.iterator + + async def close(self): + self.close_count += 1 + + +class _TestAsyncStreamWrapper(AsyncStreamWrapper): + def __init__(self, stream): + super().__init__(stream) + self._self_processed = [] + self._self_stop_count = 0 + self._self_failures = [] + + def _process_chunk(self, chunk): + self._self_processed.append(chunk) + + def _on_stream_end(self): + self._self_stop_count += 1 + + def _on_stream_error(self, error): + self._self_failures.append(error) + + +class _FailingAsyncProcessStreamWrapper(_TestAsyncStreamWrapper): + def _process_chunk(self, chunk): + raise ValueError("instrumentation failed") + + +class _FailingAsyncStopStreamWrapper(_TestAsyncStreamWrapper): + def _on_stream_end(self): + self._self_stop_count += 1 + raise ValueError("instrumentation failed") + + +class _FailingAsyncFailStreamWrapper(_TestAsyncStreamWrapper): + def _on_stream_error(self, error): + self._self_failures.append(error) + raise ValueError("instrumentation failed") + + +def test_async_stream_wrapper_processes_chunks_and_stops(): + async def exercise(): + wrapper = _TestAsyncStreamWrapper(_FakeAsyncStream(chunks=["chunk"])) + + assert await anext(wrapper) == "chunk" + assert wrapper._self_processed == ["chunk"] + + try: + await anext(wrapper) + except StopAsyncIteration: + pass + + assert wrapper._self_stop_count == 1 + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_processes_async_iterables(): + async def exercise(): + stream = _FakeAsyncIterable(chunks=["chunk"]) + wrapper = _TestAsyncStreamWrapper(stream) + + assert await anext(wrapper) == "chunk" + assert wrapper._self_processed == ["chunk"] + + with pytest.raises(StopAsyncIteration): + await anext(wrapper) + + assert wrapper._self_stop_count == 1 + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_fails_stream_errors(): + async def exercise(): + error = ValueError("boom") + wrapper = _TestAsyncStreamWrapper(_FakeAsyncStream(error=error)) + + with pytest.raises(ValueError): + await anext(wrapper) + + assert wrapper._self_failures == [error] + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_close_stops_once(): + async def exercise(): + stream = _FakeAsyncStream(chunks=["chunk"]) + wrapper = _TestAsyncStreamWrapper(stream) + + await wrapper.close() + await wrapper.close() + + assert stream.close_count == 2 + assert wrapper._self_stop_count == 1 + assert not wrapper._self_failures + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_close_fails_with_close_error(): + async def exercise(): + error = RuntimeError("close failure") + wrapper = _TestAsyncStreamWrapper( + _FakeAsyncStream(chunks=["chunk"], close_error=error) + ) + + with pytest.raises(RuntimeError, match="close failure"): + await wrapper.close() + + assert wrapper._self_failures == [error] + assert wrapper._self_stop_count == 0 + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_exit_closes_and_propagates_user_errors(): + async def exercise(): + stream = _FakeAsyncStream(chunks=["chunk"]) + wrapper = _TestAsyncStreamWrapper(stream) + error = RuntimeError("user failure") + + assert await wrapper.__aexit__(RuntimeError, error, None) is False + + assert stream.close_count == 1 + assert wrapper._self_stop_count == 0 + assert 
wrapper._self_failures == [error] + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_exit_keeps_user_error_when_close_fails(): + async def exercise(): + close_error = RuntimeError("close failure") + stream = _FakeAsyncStream(chunks=["chunk"], close_error=close_error) + wrapper = _TestAsyncStreamWrapper(stream) + error = RuntimeError("user failure") + + assert await wrapper.__aexit__(RuntimeError, error, None) is False + + assert stream.close_count == 1 + assert wrapper._self_failures == [error] + assert wrapper._self_stop_count == 0 + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_swallows_finalize_errors(): + async def exercise(): + wrapper = _FailingAsyncStopStreamWrapper(_FakeAsyncStream()) + + await wrapper.close() + await wrapper.close() + + assert wrapper._self_stop_count == 1 + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_swallows_failure_finalize_errors(): + async def exercise(): + close_error = RuntimeError("close failure") + stream = _FakeAsyncStream(close_error=close_error) + wrapper = _FailingAsyncFailStreamWrapper(stream) + + with pytest.raises(RuntimeError, match="close failure"): + await wrapper.close() + stream._close_error = None + await wrapper.close() + + assert wrapper._self_failures == [close_error] + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_swallows_stop_iteration_finalize_errors(): + async def exercise(): + wrapper = _FailingAsyncStopStreamWrapper(_FakeAsyncStream()) + + with pytest.raises(StopAsyncIteration): + await anext(wrapper) + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_preserves_stream_error_when_finalize_fails(): + async def exercise(): + error = RuntimeError("stream failure") + wrapper = _FailingAsyncFailStreamWrapper(_FakeAsyncStream(error=error)) + + with pytest.raises(RuntimeError, match="stream failure"): + await anext(wrapper) + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_getattr_passthrough(): + wrapper = _TestAsyncStreamWrapper(_FakeAsyncStream()) + + assert wrapper.extra_attribute == "passthrough" + + +def test_async_stream_wrapper_exposes_wrapped_stream(): + stream = _FakeAsyncStream() + wrapper = _TestAsyncStreamWrapper(stream) + + assert getattr(wrapper, "__wrapped__") is stream + + +def test_async_stream_wrapper_magic_method_passthrough(): + wrapper = _TestAsyncStreamWrapper(_FakeAsyncStream()) + + assert len(wrapper) == 42 + + +def test_async_stream_wrapper_stop_iteration_does_not_double_finalize(): + async def exercise(): + wrapper = _TestAsyncStreamWrapper(_FakeAsyncStream()) + + with pytest.raises(StopAsyncIteration): + await anext(wrapper) + await wrapper.close() + + assert wrapper._self_stop_count == 1 + assert not wrapper._self_failures + + asyncio.run(exercise()) + + +def test_async_stream_wrapper_swallows_process_chunk_errors(): + async def exercise(): + wrapper = _FailingAsyncProcessStreamWrapper( + _FakeAsyncStream(chunks=["chunk"]) + ) + + assert await anext(wrapper) == "chunk" + assert not wrapper._self_failures + + asyncio.run(exercise()) diff --git a/uv.lock b/uv.lock index e015600132..5ddaaab095 100644 --- a/uv.lock +++ b/uv.lock @@ -4416,6 +4416,7 @@ dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-semantic-conventions" }, + { name = "wrapt" }, ] [package.optional-dependencies] @@ -4433,6 +4434,7 @@ requires-dist = [ { name = "opentelemetry-instrumentation", editable = "opentelemetry-instrumentation" }, { name = "opentelemetry-semantic-conventions", git = 
"https://github.com/open-telemetry/opentelemetry-python?subdirectory=opentelemetry-semantic-conventions&branch=main" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=7.0.0" }, + { name = "wrapt", specifier = ">=1.0.0,<3.0.0" }, ] provides-extras = ["test", "upload"]