Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,13 @@ async def create_system_response_token_message(
self,
message_type: Literal[WebSocketMessageType.RESPONSE_MESSAGE,
WebSocketMessageType.ERROR_MESSAGE] = WebSocketMessageType.RESPONSE_MESSAGE,
message_id: str | None = str(uuid.uuid4()),
message_id: str | None = None,
thread_id: str = "default",
parent_id: str = "default",
conversation_id: str | None = None,
content: SystemResponseContent | Error = SystemResponseContent(),
content: SystemResponseContent | Error | None = None,
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
timestamp: str = str(datetime.datetime.now(datetime.UTC))
timestamp: str | None = None
) -> WebSocketSystemResponseTokenMessage | None:
"""
Creates a system response token message with default values.
Expand All @@ -283,6 +283,13 @@ async def create_system_response_token_message(
:return: A WebSocketSystemResponseTokenMessage instance.
"""
try:
if message_id is None:
message_id = str(uuid.uuid4())
if content is None:
content = SystemResponseContent()
if timestamp is None:
timestamp = str(datetime.datetime.now(datetime.UTC))

return WebSocketSystemResponseTokenMessage(type=message_type,
id=message_id,
thread_id=thread_id,
Expand All @@ -300,13 +307,13 @@ async def create_system_intermediate_step_message(
self,
message_type: Literal[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE] = (
WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE),
message_id: str = str(uuid.uuid4()),
message_id: str | None = None,
thread_id: str = "default",
parent_id: str = "default",
conversation_id: str | None = None,
content: SystemIntermediateStepContent = SystemIntermediateStepContent(name="default", payload="default"),
content: SystemIntermediateStepContent | None = None,
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
timestamp: str = str(datetime.datetime.now(datetime.UTC))
timestamp: str | None = None
) -> WebSocketSystemIntermediateStepMessage | None:
"""
Creates a system intermediate step message with default values.
Expand All @@ -322,6 +329,13 @@ async def create_system_intermediate_step_message(
:return: A WebSocketSystemIntermediateStepMessage instance.
"""
try:
if message_id is None:
message_id = str(uuid.uuid4())
if content is None:
content = SystemIntermediateStepContent(name="default", payload="default")
if timestamp is None:
timestamp = str(datetime.datetime.now(datetime.UTC))

return WebSocketSystemIntermediateStepMessage(type=message_type,
id=message_id,
thread_id=thread_id,
Expand All @@ -340,13 +354,13 @@ async def create_system_interaction_message(
*,
message_type: Literal[WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = (
WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE),
message_id: str | None = str(uuid.uuid4()),
message_id: str | None = None,
thread_id: str = "default",
parent_id: str = "default",
conversation_id: str | None = None,
content: HumanPrompt,
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
timestamp: str = str(datetime.datetime.now(datetime.UTC))
timestamp: str | None = None
) -> WebSocketSystemInteractionMessage | None:
"""
Creates a system interaction message with default values.
Expand All @@ -362,6 +376,11 @@ async def create_system_interaction_message(
:return: A WebSocketSystemInteractionMessage instance.
"""
try:
if message_id is None:
message_id = str(uuid.uuid4())
if timestamp is None:
timestamp = str(datetime.datetime.now(datetime.UTC))

return WebSocketSystemInteractionMessage(type=message_type,
id=message_id,
thread_id=thread_id,
Expand All @@ -378,11 +397,11 @@ async def create_system_interaction_message(
async def create_observability_trace_message(
self,
*,
message_id: str | None = str(uuid.uuid4()),
message_id: str | None = None,
parent_id: str = "default",
conversation_id: str | None = None,
content: ObservabilityTraceContent,
timestamp: str = str(datetime.datetime.now(datetime.UTC))
timestamp: str | None = None
) -> WebSocketObservabilityTraceMessage | None:
"""
Creates an observability trace message.
Expand All @@ -395,6 +414,11 @@ async def create_observability_trace_message(
:return: A WebSocketObservabilityTraceMessage instance.
"""
try:
if message_id is None:
message_id = str(uuid.uuid4())
if timestamp is None:
timestamp = str(datetime.datetime.now(datetime.UTC))

return WebSocketObservabilityTraceMessage(id=message_id,
parent_id=parent_id,
conversation_id=conversation_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@
# limitations under the License.
"""Tests for MessageValidator handling of auth_message type."""

import datetime
import uuid

import pytest

from nat.data_models.api_server import ObservabilityTraceContent
from nat.data_models.api_server import WebSocketAuthMessage
from nat.data_models.api_server import WebSocketAuthResponseMessage
from nat.data_models.api_server import WebSocketMessageType
from nat.data_models.api_server import WebSocketSystemResponseTokenMessage
from nat.data_models.interactive import HumanPromptText
from nat.front_ends.fastapi import message_validator as message_validator_module
from nat.front_ends.fastapi.message_validator import MessageValidator


Expand Down Expand Up @@ -91,6 +97,70 @@ async def test_malformed_payload_returns_error(self, validator: MessageValidator
assert isinstance(result, WebSocketSystemResponseTokenMessage)
assert result.type == WebSocketMessageType.ERROR_MESSAGE


class TestMessageFactoryDefaults:
async def test_factories_generate_fresh_ids_and_timestamps_if_defaults_are_omitted(
self, validator: MessageValidator, monkeypatch: pytest.MonkeyPatch
):
ids = iter(
[
uuid.UUID("00000000-0000-0000-0000-000000000001"),
uuid.UUID("00000000-0000-0000-0000-000000000002"),
uuid.UUID("00000000-0000-0000-0000-000000000003"),
uuid.UUID("00000000-0000-0000-0000-000000000004"),
uuid.UUID("00000000-0000-0000-0000-000000000005"),
uuid.UUID("00000000-0000-0000-0000-000000000006"),
uuid.UUID("00000000-0000-0000-0000-000000000007"),
uuid.UUID("00000000-0000-0000-0000-000000000008"),
]
)
timestamps = iter([datetime.datetime(2026, 1, 1, 0, 0, index, tzinfo=datetime.UTC) for index in range(8)])

class FreshDateTime(datetime.datetime):
@classmethod
def now(cls, tz=None):
assert tz is datetime.UTC
return next(timestamps)

monkeypatch.setattr(message_validator_module.uuid, "uuid4", lambda: next(ids))
monkeypatch.setattr(message_validator_module.datetime, "datetime", FreshDateTime)

messages = [
await validator.create_system_response_token_message(),
await validator.create_system_response_token_message(),
await validator.create_system_intermediate_step_message(),
await validator.create_system_intermediate_step_message(),
await validator.create_system_interaction_message(content=HumanPromptText(text="Continue?")),
await validator.create_system_interaction_message(content=HumanPromptText(text="Continue?")),
await validator.create_observability_trace_message(
content=ObservabilityTraceContent(observability_trace_id="trace-1")
),
await validator.create_observability_trace_message(
content=ObservabilityTraceContent(observability_trace_id="trace-1")
),
]

assert [message.id for message in messages] == [
"00000000-0000-0000-0000-000000000001",
"00000000-0000-0000-0000-000000000002",
"00000000-0000-0000-0000-000000000003",
"00000000-0000-0000-0000-000000000004",
"00000000-0000-0000-0000-000000000005",
"00000000-0000-0000-0000-000000000006",
"00000000-0000-0000-0000-000000000007",
"00000000-0000-0000-0000-000000000008",
]
assert [message.timestamp for message in messages] == [
"2026-01-01 00:00:00+00:00",
"2026-01-01 00:00:01+00:00",
"2026-01-01 00:00:02+00:00",
"2026-01-01 00:00:03+00:00",
"2026-01-01 00:00:04+00:00",
"2026-01-01 00:00:05+00:00",
"2026-01-01 00:00:06+00:00",
"2026-01-01 00:00:07+00:00",
]

async def test_missing_payload_returns_error(self, validator: MessageValidator):
raw: dict = {"type": "auth_message"}
result = await validator.validate_message(raw)
Expand Down