From 86de715ba43afe41082ec8204bd9e067e42eaf53 Mon Sep 17 00:00:00 2001 From: Darshan Mehta Date: Fri, 23 Jan 2026 09:29:45 -0800 Subject: [PATCH] feat: RAG - Add Serverless and Spanner modes in preview. PiperOrigin-RevId: 860139967 --- .../vertex_rag/test_rag_constants_preview.py | 63 +++++ .../unit/vertex_rag/test_rag_data_preview.py | 240 ++++++++++++++++++ vertexai/preview/rag/__init__.py | 4 + vertexai/preview/rag/utils/_gapic_utils.py | 68 ++++- vertexai/preview/rag/utils/resources.py | 24 +- 5 files changed, 392 insertions(+), 7 deletions(-) diff --git a/tests/unit/vertex_rag/test_rag_constants_preview.py b/tests/unit/vertex_rag/test_rag_constants_preview.py index 594a4c3cad..0c0f3c810c 100644 --- a/tests/unit/vertex_rag/test_rag_constants_preview.py +++ b/tests/unit/vertex_rag/test_rag_constants_preview.py @@ -69,10 +69,12 @@ RankService, Ranking, Scaled, + Serverless, SharePointSource, SharePointSources, SlackChannel, SlackChannelsSource, + Spanner, Unprovisioned, VertexAiSearchConfig, VertexFeatureStore, @@ -584,6 +586,34 @@ TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME = ( f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragEngineConfig" ) +TEST_RAG_ENGINE_CONFIG_SERVERLESS = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig(mode=Serverless()), +) +TEST_RAG_ENGINE_CONFIG_SPANNER_BASIC = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig( + mode=Spanner(tier=Basic()), + ), +) +TEST_RAG_ENGINE_CONFIG_SPANNER_SCALED = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig( + mode=Spanner(tier=Scaled()), + ), +) +TEST_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig( + mode=Spanner(tier=Unprovisioned()), + ), +) +TEST_RAG_ENGINE_CONFIG_SPANNER_NO_TIER = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig( + mode=Spanner(), + ), +) TEST_RAG_ENGINE_CONFIG_BASIC = RagEngineConfig( name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, rag_managed_db_config=RagManagedDbConfig(tier=Basic()), @@ -604,6 +634,39 @@ name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, rag_managed_db_config=None, ) +TEST_BAD_RAG_ENGINE_CONFIG_WITH_MODE_AND_TIER = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig( + mode=Spanner(tier=Basic()), + tier=Scaled(), + ), +) +TEST_GAPIC_RAG_ENGINE_CONFIG_SERVERLESS = GapicRagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=GapicRagManagedDbConfig( + serverless=GapicRagManagedDbConfig.Serverless() + ), +) +TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_BASIC = GapicRagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=GapicRagManagedDbConfig( + spanner=GapicRagManagedDbConfig.Spanner(basic=GapicRagManagedDbConfig.Basic()) + ), +) +TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_SCALED = GapicRagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=GapicRagManagedDbConfig( + spanner=GapicRagManagedDbConfig.Spanner(scaled=GapicRagManagedDbConfig.Scaled()) + ), +) +TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED = GapicRagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=GapicRagManagedDbConfig( + spanner=GapicRagManagedDbConfig.Spanner( + unprovisioned=GapicRagManagedDbConfig.Unprovisioned() + ) + ), +) TEST_GAPIC_RAG_ENGINE_CONFIG_BASIC = GapicRagEngineConfig( name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, rag_managed_db_config=GapicRagManagedDbConfig( diff --git a/tests/unit/vertex_rag/test_rag_data_preview.py b/tests/unit/vertex_rag/test_rag_data_preview.py index b887d40f52..b1e7d4c3b0 100644 --- a/tests/unit/vertex_rag/test_rag_data_preview.py +++ b/tests/unit/vertex_rag/test_rag_data_preview.py @@ -492,6 +492,91 @@ def update_rag_engine_config_enterprise_mock(): yield update_rag_engine_config_enterprise_mock +@pytest.fixture() +def update_rag_engine_config_serverless_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_serverless_mock: + update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation) + update_rag_engine_config_lro_mock.done.return_value = True + update_rag_engine_config_lro_mock.result.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SERVERLESS + ) + update_rag_engine_config_serverless_mock.return_value = ( + update_rag_engine_config_lro_mock + ) + yield update_rag_engine_config_serverless_mock + + +@pytest.fixture() +def update_rag_engine_config_spanner_basic_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_spanner_basic_mock: + update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation) + update_rag_engine_config_lro_mock.done.return_value = True + update_rag_engine_config_lro_mock.result.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_BASIC + ) + update_rag_engine_config_spanner_basic_mock.return_value = ( + update_rag_engine_config_lro_mock + ) + yield update_rag_engine_config_spanner_basic_mock + + +@pytest.fixture() +def update_rag_engine_config_spanner_scaled_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_spanner_scaled_mock: + update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation) + update_rag_engine_config_lro_mock.done.return_value = True + update_rag_engine_config_lro_mock.result.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_SCALED + ) + update_rag_engine_config_spanner_scaled_mock.return_value = ( + update_rag_engine_config_lro_mock + ) + yield update_rag_engine_config_spanner_scaled_mock + + +@pytest.fixture() +def update_rag_engine_config_spanner_no_tier_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_spanner_no_tier_mock: + update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation) + update_rag_engine_config_lro_mock.done.return_value = True + update_rag_engine_config_lro_mock.result.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_SCALED + ) + update_rag_engine_config_spanner_no_tier_mock.return_value = ( + update_rag_engine_config_lro_mock + ) + yield update_rag_engine_config_spanner_no_tier_mock + + +@pytest.fixture() +def update_rag_engine_config_spanner_unprovisioned_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_spanner_unprovisioned_mock: + update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation) + update_rag_engine_config_lro_mock.done.return_value = True + update_rag_engine_config_lro_mock.result.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED + ) + update_rag_engine_config_spanner_unprovisioned_mock.return_value = ( + update_rag_engine_config_lro_mock + ) + yield update_rag_engine_config_spanner_unprovisioned_mock + + @pytest.fixture() def update_rag_engine_config_scaled_mock(): with mock.patch.object( @@ -584,6 +669,54 @@ def get_rag_engine_enterprise_config_mock(): yield get_rag_engine_enterprise_config_mock +@pytest.fixture() +def get_rag_engine_spanner_basic_config_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "get_rag_engine_config", + ) as get_rag_engine_spanner_basic_config_mock: + get_rag_engine_spanner_basic_config_mock.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_BASIC + ) + yield get_rag_engine_spanner_basic_config_mock + + +@pytest.fixture() +def get_rag_engine_spanner_scaled_config_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "get_rag_engine_config", + ) as get_rag_engine_spanner_scaled_config_mock: + get_rag_engine_spanner_scaled_config_mock.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_SCALED + ) + yield get_rag_engine_spanner_scaled_config_mock + + +@pytest.fixture() +def get_rag_engine_spanner_unprovisioned_config_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "get_rag_engine_config", + ) as get_rag_engine_spanner_unprovisioned_config_mock: + get_rag_engine_spanner_unprovisioned_config_mock.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED + ) + yield get_rag_engine_spanner_unprovisioned_config_mock + + +@pytest.fixture() +def get_rag_engine_serverless_config_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "get_rag_engine_config", + ) as get_rag_engine_serverless_config_mock: + get_rag_engine_serverless_config_mock.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SERVERLESS + ) + yield get_rag_engine_serverless_config_mock + + @pytest.fixture() def get_rag_engine_config_mock_exception(): with mock.patch.object( @@ -1765,6 +1898,73 @@ def test_update_rag_engine_config_unprovisioned_success( test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_UNPROVISIONED, ) + def test_update_rag_engine_config_spanner_basic_success( + self, update_rag_engine_config_spanner_basic_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_BASIC, + ) + assert update_rag_engine_config_spanner_basic_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_BASIC, + ) + + def test_update_rag_engine_config_spanner_scaled_success( + self, update_rag_engine_config_spanner_scaled_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_SCALED, + ) + assert update_rag_engine_config_spanner_scaled_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_SCALED, + ) + + def test_update_rag_engine_config_spanner_unprovisioned_success( + self, update_rag_engine_config_spanner_unprovisioned_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED, + ) + assert update_rag_engine_config_spanner_unprovisioned_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED, + ) + + def test_update_rag_engine_config_spanner_no_tier_success( + self, update_rag_engine_config_spanner_no_tier_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_NO_TIER, + ) + assert update_rag_engine_config_spanner_no_tier_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_SCALED, + ) + + def test_update_rag_engine_config_serverless_success( + self, update_rag_engine_config_serverless_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SERVERLESS, + ) + assert update_rag_engine_config_serverless_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SERVERLESS, + ) + + def test_update_rag_engine_config_with_mode_and_tier_failure(self): + with pytest.raises(ValueError) as e: + rag.update_rag_engine_config( + rag_engine_config=test_rag_constants_preview.TEST_BAD_RAG_ENGINE_CONFIG_WITH_MODE_AND_TIER, + ) + e.match("mode and tier both cannot be set at the same time") + @pytest.mark.usefixtures("update_rag_engine_config_mock_exception") def test_update_rag_engine_config_failure(self): with pytest.raises(RuntimeError) as e: @@ -1786,6 +1986,46 @@ def test_update_rag_engine_config_bad_input( test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_BASIC, ) + @pytest.mark.usefixtures("get_rag_engine_serverless_config_mock") + def test_get_rag_engine_config_serverless_success(self): + rag_config = rag.get_rag_engine_config( + name=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + ) + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SERVERLESS, + ) + + @pytest.mark.usefixtures("get_rag_engine_spanner_basic_config_mock") + def test_get_rag_engine_config_spanner_basic_success(self): + rag_config = rag.get_rag_engine_config( + name=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + ) + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_BASIC, + ) + + @pytest.mark.usefixtures("get_rag_engine_spanner_scaled_config_mock") + def test_get_rag_engine_config_spanner_scaled_success(self): + rag_config = rag.get_rag_engine_config( + name=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + ) + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_SCALED, + ) + + @pytest.mark.usefixtures("get_rag_engine_spanner_unprovisioned_config_mock") + def test_get_rag_engine_config_spanner_unprovisioned_success(self): + rag_config = rag.get_rag_engine_config( + name=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + ) + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED, + ) + @pytest.mark.usefixtures("get_rag_engine_basic_config_mock") def test_get_rag_engine_config_success(self): rag_config = rag.get_rag_engine_config( diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py index 44d9644901..86d7f1ef45 100644 --- a/vertexai/preview/rag/__init__.py +++ b/vertexai/preview/rag/__init__.py @@ -66,10 +66,12 @@ RankService, Ranking, Scaled, + Serverless, SharePointSource, SharePointSources, SlackChannel, SlackChannelsSource, + Spanner, TransformationConfig, Unprovisioned, VertexAiSearchConfig, @@ -111,10 +113,12 @@ "RankService", "Retrieval", "Scaled", + "Serverless", "SharePointSource", "SharePointSources", "SlackChannel", "SlackChannelsSource", + "Spanner", "TransformationConfig", "Unprovisioned", "VertexAiSearchConfig", diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index a44dbf16a3..4150111c60 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -64,8 +64,10 @@ Basic, Enterprise, Scaled, + Serverless, SharePointSources, SlackChannelsSource, + Spanner, TransformationConfig, Unprovisioned, VertexAiSearchConfig, @@ -1011,6 +1013,22 @@ def set_backend_config( ) +def _convert_gapic_to_spanner( + gapic_spanner: GapicRagManagedDbConfig.Spanner, +) -> Spanner: + """Converts a GapicRagManagedDbConfig.Spanner to a Spanner.""" + spanner = Spanner() + if gapic_spanner.__contains__("scaled"): + spanner.tier = Scaled() + elif gapic_spanner.__contains__("basic"): + spanner.tier = Basic() + elif gapic_spanner.__contains__("unprovisioned"): + spanner.tier = Unprovisioned() + else: + raise ValueError("At least one of scaled, basic, or unprovisioned must be set.") + return spanner + + def convert_gapic_to_rag_engine_config( response: GapicRagEngineConfig, ) -> RagEngineConfig: @@ -1018,7 +1036,13 @@ def convert_gapic_to_rag_engine_config( rag_managed_db_config = RagManagedDbConfig() # If future fields are added with similar names, beware that __contains__ # may match them. - if response.rag_managed_db_config.__contains__("enterprise"): + if response.rag_managed_db_config.__contains__("spanner"): + rag_managed_db_config.mode = _convert_gapic_to_spanner( + response.rag_managed_db_config.spanner + ) + elif response.rag_managed_db_config.__contains__("serverless"): + rag_managed_db_config.mode = Serverless() + elif response.rag_managed_db_config.__contains__("enterprise"): rag_managed_db_config.tier = Enterprise() elif response.rag_managed_db_config.__contains__("basic"): rag_managed_db_config.tier = Basic() @@ -1027,27 +1051,59 @@ def convert_gapic_to_rag_engine_config( elif response.rag_managed_db_config.__contains__("scaled"): rag_managed_db_config.tier = Scaled() else: - raise ValueError("At least one of rag_managed_db_config must be set.") + raise ValueError("At least one of rag_managed_db_config mode must be set.") return RagEngineConfig( name=response.name, rag_managed_db_config=rag_managed_db_config, ) +def _convert_spanner_to_gapic( + spanner: Spanner, +) -> GapicRagManagedDbConfig.Spanner: + """Converts a Spanner to a GapicRagManagedDbConfig.Spanner.""" + gapic_spanner = GapicRagManagedDbConfig.Spanner() + if isinstance(spanner.tier, Scaled): + gapic_spanner.scaled = GapicRagManagedDbConfig.Scaled() + elif isinstance(spanner.tier, Basic): + gapic_spanner.basic = GapicRagManagedDbConfig.Basic() + elif isinstance(spanner.tier, Unprovisioned): + gapic_spanner.unprovisioned = GapicRagManagedDbConfig.Unprovisioned() + return gapic_spanner + + def convert_rag_engine_config_to_gapic( rag_engine_config: RagEngineConfig, ) -> GapicRagEngineConfig: """Converts a RagEngineConfig to a GapicRagEngineConfig.""" rag_managed_db_config = GapicRagManagedDbConfig() if ( - rag_engine_config.rag_managed_db_config is None - or rag_engine_config.rag_managed_db_config.tier is None + rag_engine_config.rag_managed_db_config is not None + and rag_engine_config.rag_managed_db_config.mode is not None + and rag_engine_config.rag_managed_db_config.tier is not None + ): + raise ValueError( + "mode and tier both cannot be set at the same time. Please set" + " the tier inside the Spanner mode." + ) + + if rag_engine_config.rag_managed_db_config is None or ( + rag_engine_config.rag_managed_db_config.tier is None + and rag_engine_config.rag_managed_db_config.mode is None ): rag_managed_db_config = GapicRagManagedDbConfig( - basic=GapicRagManagedDbConfig.Basic() + spanner=GapicRagManagedDbConfig.Spanner( + basic=GapicRagManagedDbConfig.Basic() + ) ) else: - if isinstance(rag_engine_config.rag_managed_db_config.tier, Enterprise): + if isinstance(rag_engine_config.rag_managed_db_config.mode, Serverless): + rag_managed_db_config.serverless = GapicRagManagedDbConfig.Serverless() + elif isinstance(rag_engine_config.rag_managed_db_config.mode, Spanner): + rag_managed_db_config.spanner = _convert_spanner_to_gapic( + rag_engine_config.rag_managed_db_config.mode + ) + elif isinstance(rag_engine_config.rag_managed_db_config.tier, Enterprise): rag_managed_db_config.enterprise = GapicRagManagedDbConfig.Enterprise() elif isinstance(rag_engine_config.rag_managed_db_config.tier, Basic): rag_managed_db_config.basic = GapicRagManagedDbConfig.Basic() diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index 613db44eec..140906d796 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -627,6 +627,24 @@ class Unprovisioned: """ +@dataclasses.dataclass +class Spanner: + """Switches RAG Engine to use Spanner/RagManagedDb as the backend. + + Attributes: + tier: The tier of the RagManagedDb. The default tier is Basic. + + NOTE: This is the default mode if not explicitly chosen. + """ + + tier: Optional[Union[Basic, Scaled, Unprovisioned]] = None + + +@dataclasses.dataclass +class Serverless: + """Switches RAG Engine to use serverless mode as the backend.""" + + @dataclasses.dataclass class RagManagedDbConfig: """RagManagedDbConfig. @@ -634,9 +652,13 @@ class RagManagedDbConfig: The config of the RagManagedDb used by RagEngine. Attributes: - tier: The tier of the RagManagedDb. The default tier is Basic. + mode: The choice of backend for your RAG Engine. The default mode is + Spanner with Basic tier. + tier: The tier of the RagManagedDb. NOTE: This field is deprecated. Use + `mode` instead to set the tier under Spanner. """ + mode: Optional[Union[Spanner, Serverless]] = None tier: Optional[Union[Enterprise, Basic, Scaled, Unprovisioned]] = None