Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 154 additions & 7 deletions crates/lance-context-core/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ const DEFAULT_MANIFEST_SCAN_BATCH_SIZE: usize = 16;
// NOTE(review): presumably the smoothing constant `k` used for reciprocal-rank
// fusion of ranked result lists — confirm at the use site.
const RRF_K: f32 = 60.0;
// NOTE(review): looks like the name given to the scalar index built on the
// `id` column — confirm against the index-creation code.
const ID_INDEX_NAME: &str = "id_idx";
// NOTE(review): presumably the name of the schema column that stores record
// relationships — confirm against the schema builder.
const RELATIONSHIPS_COLUMN: &str = "relationships";
/// Schema-metadata key under which the configured [`DistanceMetric`] is persisted
/// so it round-trips on `open` without being re-specified by the caller.
const DISTANCE_METRIC_METADATA_KEY: &str = "lance-context:distance_metric";

/// Configuration for background compaction.
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -143,6 +146,17 @@ impl DistanceMetric {
Self::Dot => dot_distance(query, candidate),
}
}

/// Canonical lowercase name of this metric.
///
/// This is the value written into the dataset's schema metadata when the
/// metric is persisted; it is meant to round-trip through
/// [`DistanceMetric::parse`].
#[must_use]
pub fn as_str(self) -> &'static str {
    match self {
        Self::Dot => "dot",
        Self::Cosine => "cosine",
        Self::L2 => "l2",
    }
}
}

/// Statistics about compaction status and history.
Expand Down Expand Up @@ -198,7 +212,12 @@ pub struct ContextStoreOptions {
/// Type of scalar index to create on the `id` column.
pub id_index_type: IdIndexType,
/// Distance metric used to rank vector-search results.
pub distance_metric: DistanceMetric,
///
/// For newly-created datasets this is persisted in the schema metadata and
/// becomes the dataset's metric. For existing datasets the persisted metric
/// is used; passing a different metric here is an error. `None` defaults to
/// the persisted metric (or `L2` for datasets created before persistence).
pub distance_metric: Option<DistanceMetric>,
}

impl ContextStoreOptions {
Expand Down Expand Up @@ -282,6 +301,7 @@ impl ContextStore {
storage_options,
&blob_columns,
requested_embedding_dim,
options.distance_metric.unwrap_or_default(),
)
.await?;
(dataset, true)
Expand All @@ -296,6 +316,18 @@ impl ContextStore {
embedding_dim, requested_embedding_dim
))));
}
let distance_metric = distance_metric_from_schema(&arrow_schema)?;
if !created {
if let Some(requested) = options.distance_metric {
if requested != distance_metric {
return Err(LanceError::from(ArrowError::InvalidArgumentError(format!(
"existing context distance metric '{}' does not match requested metric '{}'",
distance_metric.as_str(),
requested.as_str()
))));
}
}
}

let mut store = Self {
dataset,
Expand All @@ -310,7 +342,7 @@ impl ContextStore {
blob_columns,
id_index_type: options.id_index_type,
embedding_dim,
distance_metric: options.distance_metric,
distance_metric,
};

// Ensure id index if configured
Expand Down Expand Up @@ -1105,7 +1137,15 @@ impl ContextStore {

/// Lance schema for a context store using a caller-selected embedding width.
pub fn schema_with_embedding_dim(blob_columns: &HashSet<String>, embedding_dim: i32) -> Schema {
Self::schema_with_options(blob_columns, true, true, true, true, embedding_dim)
Self::schema_with_options(
blob_columns,
true,
true,
true,
true,
embedding_dim,
DistanceMetric::default(),
)
}

fn schema_with_options(
Expand All @@ -1115,6 +1155,7 @@ impl ContextStore {
include_relationships: bool,
include_lifecycle: bool,
embedding_dim: i32,
distance_metric: DistanceMetric,
) -> Schema {
let mut id_metadata = HashMap::new();
id_metadata.insert(
Expand Down Expand Up @@ -1209,7 +1250,12 @@ impl ContextStore {
),
]);

Schema::new(fields)
let schema_metadata = HashMap::from([(
DISTANCE_METRIC_METADATA_KEY.to_string(),
distance_metric.as_str().to_string(),
)]);

Schema::new_with_metadata(fields, schema_metadata)
}

async fn load_with_options(
Expand All @@ -1231,8 +1277,17 @@ impl ContextStore {
storage_options: Option<HashMap<String, String>>,
blob_columns: &HashSet<String>,
embedding_dim: i32,
distance_metric: DistanceMetric,
) -> LanceResult<Dataset> {
let schema = Arc::new(Self::schema_with_embedding_dim(blob_columns, embedding_dim));
let schema = Arc::new(Self::schema_with_options(
blob_columns,
true,
true,
true,
true,
embedding_dim,
distance_metric,
));
let empty_batch = RecordBatch::new_empty(schema.clone());
let batches = RecordBatchIterator::new(
vec![Ok::<RecordBatch, ArrowError>(empty_batch)].into_iter(),
Expand Down Expand Up @@ -2086,6 +2141,17 @@ fn embedding_dim_from_schema(schema: &Schema) -> LanceResult<i32> {
Ok(*embedding_dim)
}

/// Resolve the [`DistanceMetric`] recorded in a dataset's schema metadata.
///
/// When the metadata carries [`DISTANCE_METRIC_METADATA_KEY`], its value is
/// handed to [`DistanceMetric::parse`] and the parse result (including any
/// error) is returned as-is. When the key is absent — the case for datasets
/// written before the metric was persisted — [`DistanceMetric::default`] is
/// returned so such datasets keep their historical ranking behavior.
fn distance_metric_from_schema(schema: &Schema) -> LanceResult<DistanceMetric> {
    if let Some(persisted) = schema.metadata.get(DISTANCE_METRIC_METADATA_KEY) {
        return DistanceMetric::parse(persisted);
    }
    // No key recorded: fall back to the default metric.
    Ok(DistanceMetric::default())
}

/// Dot product of two vectors.
fn dot_product(left: &[f32], right: &[f32]) -> f32 {
left.iter()
Expand Down Expand Up @@ -2297,7 +2363,7 @@ mod tests {
// Cosine: `aligned` should rank first despite the larger L2 distance.
let cos_dir = TempDir::new().unwrap();
let cos_opts = ContextStoreOptions {
distance_metric: DistanceMetric::Cosine,
distance_metric: Some(DistanceMetric::Cosine),
..Default::default()
};
let mut cos_store =
Expand All @@ -2317,7 +2383,7 @@ mod tests {
// Dot: `aligned` has the largest inner product -> first.
let dot_dir = TempDir::new().unwrap();
let dot_opts = ContextStoreOptions {
distance_metric: DistanceMetric::Dot,
distance_metric: Some(DistanceMetric::Dot),
..Default::default()
};
let mut dot_store =
Expand All @@ -2336,6 +2402,86 @@ mod tests {
});
}

/// A metric chosen at creation time must be recovered from the dataset when it
/// is reopened without any metric option.
#[test]
fn distance_metric_persists_across_reopen() {
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let tmp = TempDir::new().unwrap();
        let dataset_uri = tmp.path().to_string_lossy().to_string();
        let probe = make_embedding2(1.0, 0.0);
        let aligned_vec = make_embedding2(10.0, 0.0);
        let near_vec = make_embedding2(1.0, 1.0);

        // Phase 1: create the dataset with an explicit cosine metric and
        // populate it, then release the handle before reopening.
        let cosine_opts = ContextStoreOptions {
            distance_metric: Some(DistanceMetric::Cosine),
            ..Default::default()
        };
        let mut writer = ContextStore::open_with_options(&dataset_uri, cosine_opts)
            .await
            .unwrap();
        let seed_records = [
            text_record_with("aligned", aligned_vec.clone()),
            text_record_with("near", near_vec.clone()),
        ];
        writer.add(&seed_records).await.unwrap();
        drop(writer);

        // Phase 2: reopen WITHOUT naming a metric. The store has to pick
        // cosine back up from the schema, so `aligned` must still rank first.
        let reopened = ContextStore::open(&dataset_uri).await.unwrap();
        assert_eq!(reopened.distance_metric, DistanceMetric::Cosine);
        let hits = reopened.search(&probe, Some(2)).await.unwrap();
        assert_eq!(hits[0].record.id, "aligned");
    });
}

/// Reopening an existing dataset while requesting a metric other than the one
/// it was created with must fail with a distance-metric error.
#[test]
fn distance_metric_mismatch_errors() {
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let tmp = TempDir::new().unwrap();
        let dataset_uri = tmp.path().to_string_lossy().to_string();
        // Options that differ from the defaults only in the requested metric.
        let options_with = |metric: DistanceMetric| ContextStoreOptions {
            distance_metric: Some(metric),
            ..Default::default()
        };

        // Create the dataset with cosine; the handle is dropped right away.
        ContextStore::open_with_options(&dataset_uri, options_with(DistanceMetric::Cosine))
            .await
            .unwrap();

        // Asking for dot on the same dataset must be rejected.
        let reopened =
            ContextStore::open_with_options(&dataset_uri, options_with(DistanceMetric::Dot)).await;
        match reopened {
            Err(err) => assert!(
                err.to_string().contains("distance metric"),
                "unexpected error: {err}"
            ),
            Ok(_) => panic!("expected a distance-metric mismatch error"),
        }
    });
}

/// A schema with no distance-metric metadata key must resolve to L2.
#[test]
fn distance_metric_from_schema_defaults_l2_when_absent() {
    // A schema built without metadata stands in for a dataset created before
    // the metric was persisted.
    let legacy_schema = Schema::new(vec![Field::new("id", DataType::Utf8, false)]);
    let resolved = distance_metric_from_schema(&legacy_schema).unwrap();
    assert_eq!(resolved, DistanceMetric::L2);
}

#[test]
fn retrieve_fuses_text_and_vector_channels() {
let dir = TempDir::new().unwrap();
Expand Down Expand Up @@ -2663,6 +2809,7 @@ mod tests {
false,
true,
DEFAULT_EMBEDDING_DIM,
DistanceMetric::default(),
));
let empty_batch = RecordBatch::new_empty(schema.clone());
let batches = RecordBatchIterator::new(
Expand Down
8 changes: 4 additions & 4 deletions crates/lance-context-server/src/routes/contexts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ pub async fn create_context(
let blob_columns: HashSet<String> = req.blob_columns.unwrap_or_default().into_iter().collect();

let distance_metric = match req.distance_metric.as_deref() {
Some(value) => {
DistanceMetric::parse(value).map_err(|e| AppError::InvalidRequest(e.to_string()))?
}
None => DistanceMetric::default(),
Some(value) => Some(
DistanceMetric::parse(value).map_err(|e| AppError::InvalidRequest(e.to_string()))?,
),
None => None,
};

let uri = state.context_uri(&req.name);
Expand Down
8 changes: 5 additions & 3 deletions crates/lance-context/src/unified.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ impl ContextStore {
}
};
let metric = match distance_metric {
Some(value) => DistanceMetric::parse(value)
.map_err(|e| ContextError::InvalidRequest(e.to_string()))?,
None => DistanceMetric::default(),
Some(value) => Some(
DistanceMetric::parse(value)
.map_err(|e| ContextError::InvalidRequest(e.to_string()))?,
),
None => None,
};
let options = ContextStoreOptions {
storage_options,
Expand Down
4 changes: 2 additions & 2 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ impl Context {
};

let metric = match distance_metric.as_deref() {
Some(value) => DistanceMetric::parse(value).map_err(to_py_err)?,
None => DistanceMetric::default(),
Some(value) => Some(DistanceMetric::parse(value).map_err(to_py_err)?),
None => None,
};

let options = ContextStoreOptions {
Expand Down
17 changes: 17 additions & 0 deletions python/tests/test_distance_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,23 @@ def test_cosine_metric_changes_ranking(tmp_path: Path) -> None:
assert [h["external_id"] for h in hits][0] == "aligned"


def test_metric_persists_when_reopened_without_option(tmp_path: Path) -> None:
    """A cosine dataset reopened with no metric argument still ranks by cosine."""
    dataset_uri = str(tmp_path / "persist.lance")
    # Create the dataset with cosine; the returned handle is not needed.
    _make(dataset_uri, distance_metric="cosine")

    # Reopen WITHOUT specifying the metric: it must come back from the dataset.
    reopened = Context(dataset_uri)
    hits = reopened.search(QUERY, limit=2)
    ranked_ids = [hit["external_id"] for hit in hits]
    assert ranked_ids[0] == "aligned"


def test_metric_mismatch_on_reopen_rejected(tmp_path: Path) -> None:
    """Requesting a different metric for an existing dataset raises an error."""
    dataset_uri = str(tmp_path / "mismatch.lance")
    # The dataset is first created with cosine.
    _make(dataset_uri, distance_metric="cosine")

    # Asking for dot on the same URI must be rejected with a metric error.
    with pytest.raises(RuntimeError, match="distance metric"):
        Context.create(dataset_uri, distance_metric="dot")


def test_dot_metric_ranks_by_inner_product(tmp_path: Path) -> None:
ctx = _make(str(tmp_path / "dot.lance"), distance_metric="dot")
hits = ctx.search(QUERY, limit=2)
Expand Down
Loading