From 0bd1efab10be29f7d5d5a3b1e4cd58746e0624be Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 30 Jun 2026 15:11:00 -0700 Subject: [PATCH] fix: Support ref and def for memory schemas PiperOrigin-RevId: 940680657 --- agentplatform/_genai/agent_engines.py | 27 +++++++ agentplatform/_genai/types/common.py | 7 ++ .../genai/replays/test_structured_memories.py | 79 +++++++++++++++---- 3 files changed, 98 insertions(+), 15 deletions(-) diff --git a/agentplatform/_genai/agent_engines.py b/agentplatform/_genai/agent_engines.py index f160d7327b..3bd6be2f10 100644 --- a/agentplatform/_genai/agent_engines.py +++ b/agentplatform/_genai/agent_engines.py @@ -2272,6 +2272,29 @@ def _set_package_spec( for class_method_spec in class_methods_spec_list ] + def _resolve_context_spec( + self, *, context_spec: Optional[types.ReasoningEngineContextSpecDict] = None + ) -> Optional[types.ReasoningEngineContextSpecDict]: + if context_spec is None: + return None + context_spec_obj = types.ReasoningEngineContextSpec(**context_spec) + if context_spec_obj.memory_bank_config is None: + return context_spec + if context_spec_obj.memory_bank_config.structured_memory_configs is None: + return context_spec + for schema in context_spec_obj.memory_bank_config.structured_memory_configs: + for schema_config in schema.schema_configs: + if not schema_config.memory_json_schema: + continue + # `from_json_schema` handles the resolution of `$ref` paths. + schema_config.memory_schema = genai_types.Schema.from_json_schema( + json_schema=genai_types.JSONSchema( + **schema_config.memory_json_schema + ) + ) + + return json.loads(context_spec_obj.model_dump_json()) + def _create_config( self, *, @@ -2338,6 +2361,7 @@ def _create_config( config["description"] = description if context_spec is not None: update_masks.append("context_spec") + context_spec = self._resolve_context_spec(context_spec=context_spec) config["context_spec"] = context_spec if encryption_spec is not None: update_masks.append("encryption_spec") @@ -2720,6 +2744,9 @@ def update( IOError: If `config.requirements` is a string that corresponds to a nonexistent file. """ + # Access context spec if available and convert to dict. + # "Fix" context spec if needed. + # Then run model_validate. if isinstance(config, dict): config = types.AgentEngineConfig.model_validate(config) elif not isinstance(config, types.AgentEngineConfig): diff --git a/agentplatform/_genai/types/common.py b/agentplatform/_genai/types/common.py index a13952014e..5aab3da876 100644 --- a/agentplatform/_genai/types/common.py +++ b/agentplatform/_genai/types/common.py @@ -7540,6 +7540,10 @@ class StructuredMemorySchemaConfig(_common.BaseModel): default=None, description="""Optional. Represents the type of the structured memories associated with the schema. If not set, then `STRUCTURED_PROFILE` will be used.""", ) + memory_json_schema: Optional[Any] = Field( + default=None, + description="""Required. Represents the OpenAPI schema of the structured memories.""", + ) class StructuredMemorySchemaConfigDict(TypedDict, total=False): @@ -7554,6 +7558,9 @@ class StructuredMemorySchemaConfigDict(TypedDict, total=False): memory_type: Optional[MemoryType] """Optional. Represents the type of the structured memories associated with the schema. If not set, then `STRUCTURED_PROFILE` will be used.""" + memory_json_schema: Optional[Any] + """Required. Represents the OpenAPI schema of the structured memories.""" + StructuredMemorySchemaConfigOrDict = Union[ StructuredMemorySchemaConfig, StructuredMemorySchemaConfigDict diff --git a/tests/unit/agentplatform/genai/replays/test_structured_memories.py b/tests/unit/agentplatform/genai/replays/test_structured_memories.py index 5d1c2a39a5..6205f4414e 100644 --- a/tests/unit/agentplatform/genai/replays/test_structured_memories.py +++ b/tests/unit/agentplatform/genai/replays/test_structured_memories.py @@ -14,15 +14,43 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.agentplatform.genai.replays import pytest_helper from agentplatform._genai import types +from tests.unit.agentplatform.genai.replays import pytest_helper +import pydantic def test_generate_and_retrieve_profile(client): - # TODO: Use prod once available. - client._api_client._http_options.base_url = ( - "https://us-central1-autopush-aiplatform.sandbox.googleapis.com" - ) + + class ProfileSchema(pydantic.BaseModel): + + class DemographicDetails(pydantic.BaseModel): + hometown: str + + name: str = pydantic.Field(description="User's name") + demographics: DemographicDetails + + expected_schema = { + "title": "ProfileSchema", + "type": "object", + "required": ["name", "demographics"], + "properties": { + "name": { + "title": "Name", + "description": "User's name", + "type": "string", + }, + "demographics": { + "type": "object", + "properties": { + "hometown": { + "type": "string", + }, + }, + "required": ["hometown"], + }, + }, + } + customization_config = {"disable_natural_language_memories": True} memory_bank_customization_config = types.MemoryBankCustomizationConfig( **customization_config @@ -31,14 +59,17 @@ def test_generate_and_retrieve_profile(client): "scope_keys": ["user_id"], "schema_configs": [ { - "id": "user-profile", - "memory_schema": { - "properties": { - "name": {"description": "User's name", "type": "string"} - }, - "type": "object", - }, - } + "id": "user-profile-1", + "memory_json_schema": ProfileSchema.model_json_schema(), + }, + { + "id": "user-profile-2", + "memory_schema": expected_schema, + }, + { + "id": "user-profile-3", + "memory_json_schema": expected_schema, + }, ], } structured_memory_config_obj = types.StructuredMemoryConfig( @@ -61,8 +92,25 @@ def test_generate_and_retrieve_profile(client): assert memory_bank_config.customization_configs == [ memory_bank_customization_config ] + assert memory_bank_config.structured_memory_configs == [ - structured_memory_config_obj + types.StructuredMemoryConfig( + scope_keys=["user_id"], + schema_configs=[ + types.StructuredMemorySchemaConfig( + id="user-profile-1", + memory_schema=expected_schema, + ), + types.StructuredMemorySchemaConfig( + id="user-profile-2", + memory_schema=expected_schema, + ), + types.StructuredMemorySchemaConfig( + id="user-profile-3", + memory_schema=expected_schema, + ), + ], + ) ] scope = {"user_id": "123"} @@ -86,7 +134,8 @@ def test_generate_and_retrieve_profile(client): response = client.agent_engines.memories.retrieve_profiles( name=agent_engine.api_resource.name, scope=scope ) - assert len(response.profiles) == 1 + # One profile is generated for each schema config. + assert len(response.profiles) == 3 finally: # Clean up resources.