From 0287b6a193cfceca108adccca4ec40389fd53605 Mon Sep 17 00:00:00 2001 From: dcfocus Date: Fri, 12 Jun 2026 04:25:30 +0000 Subject: [PATCH] feat: persist distance_metric in dataset schema metadata The configurable vector-search distance metric (#74/#77) was a runtime-only option: callers had to re-pass `distance_metric` on every open, and if they forgot, the store silently fell back to `l2` and ranked results differently from how the dataset was intended to be queried. Persist the metric in the Lance schema metadata (key `lance-context:distance_metric`) on create, the same mechanism already used for blob encoding. On open it is recovered automatically, mirroring how `embedding_dim` round-trips via the schema. An explicitly passed metric that disagrees with the persisted one now errors, reusing the `embedding_dim` mismatch-validation pattern. Datasets created before this change carry no key and default to `l2`, preserving existing behavior. `ContextStoreOptions.distance_metric` becomes `Option` (None = use the persisted/default metric), matching `embedding_dim: Option`. Closes #80 Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/lance-context-core/src/store.rs | 161 +++++++++++++++++- .../src/routes/contexts.rs | 8 +- crates/lance-context/src/unified.rs | 8 +- python/src/lib.rs | 4 +- python/tests/test_distance_metric.py | 17 ++ 5 files changed, 182 insertions(+), 16 deletions(-) diff --git a/crates/lance-context-core/src/store.rs b/crates/lance-context-core/src/store.rs index abe6cc6..d1c9013 100644 --- a/crates/lance-context-core/src/store.rs +++ b/crates/lance-context-core/src/store.rs @@ -47,6 +47,9 @@ const DEFAULT_MANIFEST_SCAN_BATCH_SIZE: usize = 16; const RRF_K: f32 = 60.0; const ID_INDEX_NAME: &str = "id_idx"; const RELATIONSHIPS_COLUMN: &str = "relationships"; +/// Schema-metadata key under which the configured [`DistanceMetric`] is persisted +/// so it round-trips on `open` without being re-specified by the caller. +const DISTANCE_METRIC_METADATA_KEY: &str = "lance-context:distance_metric"; /// Configuration for background compaction. #[derive(Debug, Clone)] @@ -143,6 +146,17 @@ impl DistanceMetric { Self::Dot => dot_distance(query, candidate), } } + + /// Stable string identifier for this metric, used when persisting it in + /// dataset schema metadata. Round-trips through [`DistanceMetric::parse`]. + #[must_use] + pub fn as_str(self) -> &'static str { + match self { + Self::L2 => "l2", + Self::Cosine => "cosine", + Self::Dot => "dot", + } + } } /// Statistics about compaction status and history. @@ -198,7 +212,12 @@ pub struct ContextStoreOptions { /// Type of scalar index to create on the `id` column. pub id_index_type: IdIndexType, /// Distance metric used to rank vector-search results. - pub distance_metric: DistanceMetric, + /// + /// For newly-created datasets this is persisted in the schema metadata and + /// becomes the dataset's metric. For existing datasets the persisted metric + /// is used; passing a different metric here is an error. `None` defaults to + /// the persisted metric (or `L2` for datasets created before persistence). + pub distance_metric: Option, } impl ContextStoreOptions { @@ -282,6 +301,7 @@ impl ContextStore { storage_options, &blob_columns, requested_embedding_dim, + options.distance_metric.unwrap_or_default(), ) .await?; (dataset, true) @@ -296,6 +316,18 @@ impl ContextStore { embedding_dim, requested_embedding_dim )))); } + let distance_metric = distance_metric_from_schema(&arrow_schema)?; + if !created { + if let Some(requested) = options.distance_metric { + if requested != distance_metric { + return Err(LanceError::from(ArrowError::InvalidArgumentError(format!( + "existing context distance metric '{}' does not match requested metric '{}'", + distance_metric.as_str(), + requested.as_str() + )))); + } + } + } let mut store = Self { dataset, @@ -310,7 +342,7 @@ impl ContextStore { blob_columns, id_index_type: options.id_index_type, embedding_dim, - distance_metric: options.distance_metric, + distance_metric, }; // Ensure id index if configured @@ -1105,7 +1137,15 @@ impl ContextStore { /// Lance schema for a context store using a caller-selected embedding width. pub fn schema_with_embedding_dim(blob_columns: &HashSet, embedding_dim: i32) -> Schema { - Self::schema_with_options(blob_columns, true, true, true, true, embedding_dim) + Self::schema_with_options( + blob_columns, + true, + true, + true, + true, + embedding_dim, + DistanceMetric::default(), + ) } fn schema_with_options( @@ -1115,6 +1155,7 @@ impl ContextStore { include_relationships: bool, include_lifecycle: bool, embedding_dim: i32, + distance_metric: DistanceMetric, ) -> Schema { let mut id_metadata = HashMap::new(); id_metadata.insert( @@ -1209,7 +1250,12 @@ impl ContextStore { ), ]); - Schema::new(fields) + let schema_metadata = HashMap::from([( + DISTANCE_METRIC_METADATA_KEY.to_string(), + distance_metric.as_str().to_string(), + )]); + + Schema::new_with_metadata(fields, schema_metadata) } async fn load_with_options( @@ -1231,8 +1277,17 @@ impl ContextStore { storage_options: Option>, blob_columns: &HashSet, embedding_dim: i32, + distance_metric: DistanceMetric, ) -> LanceResult { - let schema = Arc::new(Self::schema_with_embedding_dim(blob_columns, embedding_dim)); + let schema = Arc::new(Self::schema_with_options( + blob_columns, + true, + true, + true, + true, + embedding_dim, + distance_metric, + )); let empty_batch = RecordBatch::new_empty(schema.clone()); let batches = RecordBatchIterator::new( vec![Ok::(empty_batch)].into_iter(), @@ -2086,6 +2141,17 @@ fn embedding_dim_from_schema(schema: &Schema) -> LanceResult { Ok(*embedding_dim) } +/// Read the persisted [`DistanceMetric`] from the dataset's schema metadata. +/// +/// Datasets created before metric persistence (no key present) default to +/// [`DistanceMetric::L2`], preserving historical ranking behavior. +fn distance_metric_from_schema(schema: &Schema) -> LanceResult { + match schema.metadata.get(DISTANCE_METRIC_METADATA_KEY) { + Some(value) => DistanceMetric::parse(value), + None => Ok(DistanceMetric::default()), + } +} + /// Dot product of two vectors. fn dot_product(left: &[f32], right: &[f32]) -> f32 { left.iter() @@ -2297,7 +2363,7 @@ mod tests { // Cosine: `aligned` should rank first despite the larger L2 distance. let cos_dir = TempDir::new().unwrap(); let cos_opts = ContextStoreOptions { - distance_metric: DistanceMetric::Cosine, + distance_metric: Some(DistanceMetric::Cosine), ..Default::default() }; let mut cos_store = @@ -2317,7 +2383,7 @@ mod tests { // Dot: `aligned` has the largest inner product -> first. let dot_dir = TempDir::new().unwrap(); let dot_opts = ContextStoreOptions { - distance_metric: DistanceMetric::Dot, + distance_metric: Some(DistanceMetric::Dot), ..Default::default() }; let mut dot_store = @@ -2336,6 +2402,86 @@ mod tests { }); } + #[test] + fn distance_metric_persists_across_reopen() { + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let dir = TempDir::new().unwrap(); + let uri = dir.path().to_string_lossy().to_string(); + let query = make_embedding2(1.0, 0.0); + let aligned = make_embedding2(10.0, 0.0); + let near = make_embedding2(1.0, 1.0); + + // Create with cosine and write records. + { + let opts = ContextStoreOptions { + distance_metric: Some(DistanceMetric::Cosine), + ..Default::default() + }; + let mut store = ContextStore::open_with_options(&uri, opts).await.unwrap(); + store + .add(&[ + text_record_with("aligned", aligned.clone()), + text_record_with("near", near.clone()), + ]) + .await + .unwrap(); + } + + // Reopen WITHOUT passing the metric: it must be recovered from the + // schema, so cosine ranking (`aligned` first) still applies. + let store = ContextStore::open(&uri).await.unwrap(); + assert_eq!(store.distance_metric, DistanceMetric::Cosine); + let results = store.search(&query, Some(2)).await.unwrap(); + assert_eq!(results[0].record.id, "aligned"); + }); + } + + #[test] + fn distance_metric_mismatch_errors() { + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let dir = TempDir::new().unwrap(); + let uri = dir.path().to_string_lossy().to_string(); + ContextStore::open_with_options( + &uri, + ContextStoreOptions { + distance_metric: Some(DistanceMetric::Cosine), + ..Default::default() + }, + ) + .await + .unwrap(); + + let result = ContextStore::open_with_options( + &uri, + ContextStoreOptions { + distance_metric: Some(DistanceMetric::Dot), + ..Default::default() + }, + ) + .await; + let err = match result { + Ok(_) => panic!("expected a distance-metric mismatch error"), + Err(err) => err, + }; + assert!( + err.to_string().contains("distance metric"), + "unexpected error: {err}" + ); + }); + } + + #[test] + fn distance_metric_from_schema_defaults_l2_when_absent() { + // Datasets created before metric persistence carry no metadata key. + let schema = Schema::new(vec![Field::new("id", DataType::Utf8, false)]); + assert_eq!( + distance_metric_from_schema(&schema).unwrap(), + DistanceMetric::L2 + ); + } + #[test] fn retrieve_fuses_text_and_vector_channels() { let dir = TempDir::new().unwrap(); @@ -2663,6 +2809,7 @@ mod tests { false, true, DEFAULT_EMBEDDING_DIM, + DistanceMetric::default(), )); let empty_batch = RecordBatch::new_empty(schema.clone()); let batches = RecordBatchIterator::new( diff --git a/crates/lance-context-server/src/routes/contexts.rs b/crates/lance-context-server/src/routes/contexts.rs index 3be2330..56ece0b 100644 --- a/crates/lance-context-server/src/routes/contexts.rs +++ b/crates/lance-context-server/src/routes/contexts.rs @@ -38,10 +38,10 @@ pub async fn create_context( let blob_columns: HashSet = req.blob_columns.unwrap_or_default().into_iter().collect(); let distance_metric = match req.distance_metric.as_deref() { - Some(value) => { - DistanceMetric::parse(value).map_err(|e| AppError::InvalidRequest(e.to_string()))? - } - None => DistanceMetric::default(), + Some(value) => Some( + DistanceMetric::parse(value).map_err(|e| AppError::InvalidRequest(e.to_string()))?, + ), + None => None, }; let uri = state.context_uri(&req.name); diff --git a/crates/lance-context/src/unified.rs b/crates/lance-context/src/unified.rs index 5b92374..af32972 100644 --- a/crates/lance-context/src/unified.rs +++ b/crates/lance-context/src/unified.rs @@ -44,9 +44,11 @@ impl ContextStore { } }; let metric = match distance_metric { - Some(value) => DistanceMetric::parse(value) - .map_err(|e| ContextError::InvalidRequest(e.to_string()))?, - None => DistanceMetric::default(), + Some(value) => Some( + DistanceMetric::parse(value) + .map_err(|e| ContextError::InvalidRequest(e.to_string()))?, + ), + None => None, }; let options = ContextStoreOptions { storage_options, diff --git a/python/src/lib.rs b/python/src/lib.rs index 771ef8a..992af3c 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -197,8 +197,8 @@ impl Context { }; let metric = match distance_metric.as_deref() { - Some(value) => DistanceMetric::parse(value).map_err(to_py_err)?, - None => DistanceMetric::default(), + Some(value) => Some(DistanceMetric::parse(value).map_err(to_py_err)?), + None => None, }; let options = ContextStoreOptions { diff --git a/python/tests/test_distance_metric.py b/python/tests/test_distance_metric.py index ab38cba..4685a99 100644 --- a/python/tests/test_distance_metric.py +++ b/python/tests/test_distance_metric.py @@ -48,6 +48,23 @@ def test_cosine_metric_changes_ranking(tmp_path: Path) -> None: assert [h["external_id"] for h in hits][0] == "aligned" +def test_metric_persists_when_reopened_without_option(tmp_path: Path) -> None: + # Create with cosine, then reopen WITHOUT specifying the metric: it must be + # recovered from the dataset so ranking still uses cosine. + uri = str(tmp_path / "persist.lance") + _make(uri, distance_metric="cosine") + reopened = Context(uri) + hits = reopened.search(QUERY, limit=2) + assert [h["external_id"] for h in hits][0] == "aligned" + + +def test_metric_mismatch_on_reopen_rejected(tmp_path: Path) -> None: + uri = str(tmp_path / "mismatch.lance") + _make(uri, distance_metric="cosine") + with pytest.raises(RuntimeError, match="distance metric"): + Context.create(uri, distance_metric="dot") + + def test_dot_metric_ranks_by_inner_product(tmp_path: Path) -> None: ctx = _make(str(tmp_path / "dot.lance"), distance_metric="dot") hits = ctx.search(QUERY, limit=2)