diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index b4737dd488..db5f620d07 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -16,7 +16,7 @@ # from dataclasses import dataclass, field -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union, Any from google.auth import credentials as auth_credentials from google.cloud.aiplatform import base @@ -208,7 +208,8 @@ class MatchNeighbor: For example, values [1,2,3] with dimensions [4,5,6] means value 1 is of the 4th dimension, value 2 is of the 4th dimension, and value 3 is of the 6th dimension. - + embedding_metadata (Dict[str,Any]): + Optional. The embedding metadata of the matching datapoint. """ id: str @@ -220,6 +221,7 @@ class MatchNeighbor: numeric_restricts: Optional[List[NumericNamespace]] = None sparse_embedding_values: Optional[List[float]] = None sparse_embedding_dimensions: Optional[List[int]] = None + embedding_metadata: Optional[Dict[str,Any]] = None def from_index_datapoint( self, index_datapoint: gca_index_v1beta1.IndexDatapoint @@ -276,6 +278,8 @@ def from_index_datapoint( self.sparse_embedding_dimensions = ( index_datapoint.sparse_embedding.dimensions ) + if index_datapoint.embedding_metadata is not None: + self.embedding_metadata = dict(index_datapoint.embedding_metadata) return self def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighbor":