Skip to content

Commit a7e5e6a

Browse files
authored
fix: relax model name validation for Bedrock Embedders (#2625)
1 parent 59c8b0d commit a7e5e6a

File tree

3 files changed

+41
-76
lines changed

3 files changed

+41
-76
lines changed

integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Literal, Optional
2+
from typing import Any, Optional
33

44
from botocore.config import Config
55
from botocore.exceptions import ClientError
@@ -16,15 +16,6 @@
1616

1717
logger = logging.getLogger(__name__)
1818

19-
SUPPORTED_EMBEDDING_MODELS = [
20-
"amazon.titan-embed-text-v1",
21-
"amazon.titan-embed-text-v2:0",
22-
"amazon.titan-embed-image-v1",
23-
"cohere.embed-english-v3",
24-
"cohere.embed-multilingual-v3",
25-
"cohere.embed-v4:0",
26-
]
27-
2819

2920
@component
3021
class AmazonBedrockDocumentEmbedder:
@@ -58,14 +49,7 @@ class AmazonBedrockDocumentEmbedder:
5849

5950
def __init__(
6051
self,
61-
model: Literal[
62-
"amazon.titan-embed-text-v1",
63-
"amazon.titan-embed-text-v2:0",
64-
"amazon.titan-embed-image-v1",
65-
"cohere.embed-english-v3",
66-
"cohere.embed-multilingual-v3",
67-
"cohere.embed-v4:0",
68-
],
52+
model: str,
6953
aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008
7054
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
7155
"AWS_SECRET_ACCESS_KEY", strict=False
@@ -90,8 +74,13 @@ def __init__(
9074
constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`,
9175
and `aws_region_name`.
9276
93-
:param model: The embedding model to use. The model has to be specified in the format outlined in the Amazon
94-
Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).
77+
:param model: The embedding model to use.
78+
Amazon Titan and Cohere embedding models are supported, for example:
79+
"amazon.titan-embed-text-v1", "amazon.titan-embed-text-v2:0", "amazon.titan-embed-image-v1",
80+
"cohere.embed-english-v3", "cohere.embed-multilingual-v3", "cohere.embed-v4:0".
81+
To find all supported models, refer to the Amazon Bedrock
82+
[documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) and
83+
filter for "embedding", then select models from the Amazon Titan and Cohere series.
9584
:param aws_access_key_id: AWS access key ID.
9685
:param aws_secret_access_key: AWS secret access key.
9786
:param aws_session_token: AWS session token.
@@ -109,11 +98,8 @@ def __init__(
10998
:raises ValueError: If the model is not supported.
11099
:raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly.
111100
"""
112-
113-
if not model or model not in SUPPORTED_EMBEDDING_MODELS:
114-
msg = "Please provide a valid model from the list of supported models: " + ", ".join(
115-
SUPPORTED_EMBEDDING_MODELS
116-
)
101+
if "titan" not in model and "cohere" not in model:
102+
msg = f"Model {model} is not supported. Only Amazon Titan and Cohere embedding models are supported."
117103
raise ValueError(msg)
118104

119105
self.model = model
@@ -254,7 +240,7 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]:
254240
elif "titan" in self.model:
255241
documents_with_embeddings = self._embed_titan(documents=documents)
256242
else:
257-
msg = f"Model {self.model} is not supported. Supported models are: {', '.join(SUPPORTED_EMBEDDING_MODELS)}."
243+
msg = f"Model {self.model} is not supported. Only Amazon Titan and Cohere embedding models are supported."
258244
raise ValueError(msg)
259245

260246
return {"documents": documents_with_embeddings}

integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_image_embedder.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import json
66
from dataclasses import replace
7-
from typing import Any, Literal, Optional
7+
from typing import Any, Optional
88

99
from botocore.config import Config
1010
from botocore.exceptions import ClientError
@@ -27,13 +27,6 @@
2727

2828
logger = logging.getLogger(__name__)
2929

30-
SUPPORTED_EMBEDDING_MODELS = [
31-
"amazon.titan-embed-image-v1",
32-
"cohere.embed-english-v3",
33-
"cohere.embed-multilingual-v3",
34-
"cohere.embed-v4:0",
35-
]
36-
3730

3831
@component
3932
class AmazonBedrockDocumentImageEmbedder:
@@ -74,12 +67,7 @@ class AmazonBedrockDocumentImageEmbedder:
7467
def __init__(
7568
self,
7669
*,
77-
model: Literal[
78-
"amazon.titan-embed-image-v1",
79-
"cohere.embed-english-v3",
80-
"cohere.embed-multilingual-v3",
81-
"cohere.embed-v4:0",
82-
],
70+
model: str,
8371
aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008
8472
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
8573
"AWS_SECRET_ACCESS_KEY", strict=False
@@ -97,13 +85,13 @@ def __init__(
9785
"""
9886
Creates a AmazonBedrockDocumentImageEmbedder component.
9987
100-
:param model:
101-
The Bedrock model to use for calculating embeddings. Pass a valid model ID.
102-
Supported models:
103-
- "amazon.titan-embed-image-v1"
104-
- "cohere.embed-english-v3"
105-
- "cohere.embed-multilingual-v3"
106-
- "cohere.embed-v4:0"
88+
:param model: The embedding model to use.
89+
Amazon Titan and Cohere multimodal embedding models are supported, for example:
90+
"amazon.titan-embed-image-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3",
91+
"cohere.embed-v4:0".
92+
To find all supported models, refer to the Amazon Bedrock
93+
[documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) and
94+
filter for "embedding", then select multimodal models from the Amazon Titan and Cohere series.
10795
:param aws_access_key_id: AWS access key ID.
10896
:param aws_secret_access_key: AWS secret access key.
10997
:param aws_session_token: AWS session token.
@@ -125,9 +113,10 @@ def __init__(
125113
:raises ValueError: If the model is not supported.
126114
:raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly.
127115
"""
128-
if not model or model not in SUPPORTED_EMBEDDING_MODELS:
129-
msg = "Please provide a valid model from the list of supported models: " + ", ".join(
130-
SUPPORTED_EMBEDDING_MODELS
116+
if "titan" not in model and "cohere" not in model:
117+
msg = (
118+
f"Model {model} is not supported. "
119+
"Only Amazon Titan and Cohere multimodal embedding models are supported."
131120
)
132121
raise ValueError(msg)
133122

@@ -291,7 +280,10 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]:
291280
elif "titan" in self.model:
292281
embeddings = self._embed_titan(images=images_to_embed)
293282
else:
294-
msg = f"Model {self.model} is not supported. Supported models are: {', '.join(SUPPORTED_EMBEDDING_MODELS)}."
283+
msg = (
284+
f"Model {self.model} is not supported. "
285+
"Only Amazon Titan and Cohere multimodal embedding models are supported."
286+
)
295287
raise ValueError(msg)
296288

297289
docs_with_embeddings = []

integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Literal, Optional
2+
from typing import Any, Optional
33

44
from botocore.config import Config
55
from botocore.exceptions import ClientError
@@ -14,15 +14,6 @@
1414

1515
logger = logging.getLogger(__name__)
1616

17-
SUPPORTED_EMBEDDING_MODELS = [
18-
"amazon.titan-embed-text-v1",
19-
"amazon.titan-embed-text-v2:0",
20-
"amazon.titan-embed-image-v1",
21-
"cohere.embed-english-v3",
22-
"cohere.embed-multilingual-v3",
23-
"cohere.embed-v4:0",
24-
]
25-
2617

2718
@component
2819
class AmazonBedrockTextEmbedder:
@@ -51,14 +42,7 @@ class AmazonBedrockTextEmbedder:
5142

5243
def __init__(
5344
self,
54-
model: Literal[
55-
"amazon.titan-embed-text-v1",
56-
"amazon.titan-embed-text-v2:0",
57-
"amazon.titan-embed-image-v1",
58-
"cohere.embed-english-v3",
59-
"cohere.embed-multilingual-v3",
60-
"cohere.embed-v4:0",
61-
],
45+
model: str,
6246
aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008
6347
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
6448
"AWS_SECRET_ACCESS_KEY", strict=False
@@ -79,8 +63,13 @@ def __init__(
7963
constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`,
8064
and `aws_region_name`.
8165
82-
:param model: The embedding model to use. The model has to be specified in the format outlined in the Amazon
83-
Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).
66+
:param model: The embedding model to use.
67+
Amazon Titan and Cohere embedding models are supported, for example:
68+
"amazon.titan-embed-text-v1", "amazon.titan-embed-text-v2:0", "amazon.titan-embed-image-v1",
69+
"cohere.embed-english-v3", "cohere.embed-multilingual-v3", "cohere.embed-v4:0".
70+
To find all supported models, refer to the Amazon Bedrock
71+
[documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) and
72+
filter for "embedding", then select models from the Amazon Titan and Cohere series.
8473
:param aws_access_key_id: AWS access key ID.
8574
:param aws_secret_access_key: AWS secret access key.
8675
:param aws_session_token: AWS session token.
@@ -92,10 +81,8 @@ def __init__(
9281
:raises ValueError: If the model is not supported.
9382
:raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly.
9483
"""
95-
if not model or model not in SUPPORTED_EMBEDDING_MODELS:
96-
msg = "Please provide a valid model from the list of supported models: " + ", ".join(
97-
SUPPORTED_EMBEDDING_MODELS
98-
)
84+
if "titan" not in model and "cohere" not in model:
85+
msg = f"Model {model} is not supported. Only Amazon Titan and Cohere embedding models are supported."
9986
raise ValueError(msg)
10087

10188
self.model = model
@@ -179,7 +166,7 @@ def run(self, text: str) -> dict[str, list[float]]:
179166
elif "titan" in self.model:
180167
embedding = response_body["embedding"]
181168
else:
182-
msg = f"Unsupported model {self.model}. Supported models are: {', '.join(SUPPORTED_EMBEDDING_MODELS)}"
169+
msg = f"Model {self.model} is not supported. Only Amazon Titan and Cohere embedding models are supported."
183170
raise ValueError(msg)
184171

185172
return {"embedding": embedding}

0 commit comments

Comments
 (0)