Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 130 additions & 1 deletion rust/lance-index/benches/rq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,138 @@ fn ex_bulk_paths(c: &mut Criterion) {
}
}

/// Benchmarks top-k collection over the gated raw-query multi-bit path.
///
/// Each iteration covers the binary FastScan, the per-row lower-bound pruning
/// scan, and the exact rerank of whichever rows survive. The batch carries an
/// error-factor column so that the gating is enabled. The benchmark is run
/// once per approximation mode (`Normal` and `Accurate`).
fn heap_topk(c: &mut Criterion) {
    use std::collections::BinaryHeap;
    use std::sync::Arc;

    use arrow_array::{ArrayRef, FixedSizeListArray, Float32Array, UInt8Array, UInt64Array};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_index::vector::ApproxMode;
    use lance_index::vector::bq::transform::{
        ERROR_FACTORS_COLUMN, EX_ADD_FACTORS_COLUMN, EX_SCALE_FACTORS_COLUMN,
    };
    use lance_index::vector::storage::DistanceCalculatorOptions;

    const NUM_BITS: u8 = 5;
    const TOPK_DIM: usize = 1536;
    const TOPK_ROWS: usize = 4096;
    const TOPK_K: usize = 10;
    // One bit goes to the binary code; the rest form the extended code.
    let ex_bits = NUM_BITS - 1;

    // Fixed seed keeps the synthetic data identical from run to run.
    let mut rng = SmallRng::seed_from_u64(99);
    let quantizer = RabitQuantizer::new_with_rotation::<Float32Type>(
        NUM_BITS,
        TOPK_DIM as i32,
        RQRotationType::Fast,
    );
    let metadata = quantizer.metadata(None);

    // Random code bytes: binary codes first, then the blocked extended codes.
    let mut random_bytes = |count: usize| -> Vec<u8> { (0..count).map(|_| rng.random()).collect() };
    let code_len = TOPK_DIM / 8;
    let binary_codes = random_bytes(TOPK_ROWS * code_len);
    let ex_code_len = blocked_ex_code_bytes(TOPK_DIM, ex_bits);
    let ex_codes = random_bytes(TOPK_ROWS * ex_code_len);

    // Wraps a flat byte buffer as a fixed-size-list column of the given width.
    let fixed_size_bytes = |bytes: Vec<u8>, width: usize| {
        Arc::new(
            FixedSizeListArray::try_new_from_values(UInt8Array::from(bytes), width as i32)
                .unwrap(),
        ) as ArrayRef
    };

    // The factor ranges are picked so that the add factors dominate the spread
    // of the lower bounds. With a full heap the threshold then discards the
    // vast majority of rows, like a production multi-partition scan.
    let mut uniform_column = |low: f32, high: f32| {
        let values: Vec<f32> = (0..TOPK_ROWS)
            .map(|_| rng.random_range(low..high))
            .collect();
        Arc::new(Float32Array::from(values)) as ArrayRef
    };

    // Columns are built in schema order so the random draws keep their sequence.
    let row_ids = Arc::new(UInt64Array::from_iter_values(0..TOPK_ROWS as u64)) as ArrayRef;
    let binary_column = fixed_size_bytes(binary_codes, code_len);
    let add_factors = uniform_column(0.0, 1.0);
    let scale_factors = uniform_column(0.0005, 0.0015);
    let error_factors = uniform_column(0.0, 0.01);
    let ex_column = fixed_size_bytes(ex_codes, ex_code_len);
    let ex_add_factors = uniform_column(0.0, 1.0);
    let ex_scale_factors = uniform_column(0.00003, 0.0001);

    let batch = arrow_array::RecordBatch::try_from_iter(vec![
        (ROW_ID, row_ids),
        (RABIT_CODE_COLUMN, binary_column),
        (ADD_FACTORS_COLUMN, add_factors),
        (SCALE_FACTORS_COLUMN, scale_factors),
        (ERROR_FACTORS_COLUMN, error_factors),
        (RABIT_BLOCKED_EX_CODE_COLUMN, ex_column),
        (EX_ADD_FACTORS_COLUMN, ex_add_factors),
        (EX_SCALE_FACTORS_COLUMN, ex_scale_factors),
    ])
    .unwrap();
    let storage =
        RabitQuantizationStorage::try_from_batch(batch, &metadata, DistanceType::L2, None).unwrap();

    let query_values: Vec<f32> = (0..TOPK_DIM)
        .map(|_| rng.random_range(-1.0f32..1.0))
        .collect();
    let query: ArrayRef = Arc::new(Float32Array::from(query_values));

    let modes = [
        ("normal", ApproxMode::Normal),
        ("accurate", ApproxMode::Accurate),
    ];
    for (label, approx_mode) in modes {
        let mut f32_scratch = Vec::new();
        let calc = storage.dist_calculator_with_scratch(
            query.clone(),
            1.0,
            None,
            &mut f32_scratch,
            DistanceCalculatorOptions { approx_mode },
        );

        // The heap and scratch buffers are created once per mode and reused by
        // every iteration.
        let mut heap = BinaryHeap::with_capacity(TOPK_K + 1);
        let mut dists = Vec::new();
        let mut u16_scratch = Vec::new();
        let mut u8_scratch = Vec::new();
        let mut u32_scratch = Vec::new();

        let bench_name = format!(
            "RQ heap topk ({label}): num_bits={NUM_BITS}, DIM={TOPK_DIM}, rows={TOPK_ROWS}, k={TOPK_K}"
        );
        c.bench_function(&bench_name, |bencher| {
            bencher.iter(|| {
                // Start each iteration from an empty heap.
                heap.clear();
                calc.accumulate_topk_with_scratch(
                    TOPK_K,
                    None,
                    None,
                    |id| id as u64,
                    &mut heap,
                    &mut dists,
                    &mut u16_scratch,
                    &mut u8_scratch,
                    &mut u32_scratch,
                );
                black_box(heap.len())
            })
        });
    }
}

// Declares the Criterion benchmark group `benches`, using the default
// configuration with the measurement time set to 10 seconds, and then
// generates the benchmark entry point for that group.
// NOTE(review): two `targets = ...` lines appear back to back below. They look
// like the removed and added sides of a diff hunk: the second line is the
// first plus `heap_topk`. Presumably only the second belongs in the real
// file; confirm against the repository.
criterion_group!(
name=benches;
config = Criterion::default().measurement_time(Duration::from_secs(10));
targets = construct_dist_table, compute_distances, ex_dot_kernels, ex_code_storage_load, ex_bulk_paths);
targets = construct_dist_table, compute_distances, ex_dot_kernels, ex_code_storage_load, ex_bulk_paths, heap_topk);

criterion_main!(benches);
1 change: 1 addition & 0 deletions rust/lance-index/src/vector/bq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::vector::quantizer::QuantizerBuildParams;
pub mod builder;
pub(crate) mod dist_table_quant;
pub mod ex_dot;
pub mod prune;
pub mod rotation;
pub mod storage;
pub mod transform;
Expand Down
Loading
Loading