diff --git a/encodings/fastlanes/src/bitpacking/compute/filter.rs b/encodings/fastlanes/src/bitpacking/compute/filter.rs
index aab597e9ff3..06e50e09bf0 100644
--- a/encodings/fastlanes/src/bitpacking/compute/filter.rs
+++ b/encodings/fastlanes/src/bitpacking/compute/filter.rs
@@ -226,9 +226,7 @@ mod test {
         let mut ctx = LEGACY_SESSION.create_execution_ctx();
         let unpacked = PrimitiveArray::from_iter((0..4096).map(|i| (i % 63) as u8));
         let bitpacked = BitPackedData::encode(&unpacked.into_array(), 6, &mut ctx).unwrap();
-        let filtered = bitpacked
-            .filter(Mask::from_indices(4096, (0..1024).collect()))
-            .unwrap();
+        let filtered = bitpacked.filter(Mask::from_indices(4096, 0..1024)).unwrap();
         let filtered_prim = filtered.execute::<PrimitiveArray>(&mut ctx).unwrap();
         assert_arrays_eq!(
             filtered_prim,
@@ -243,7 +241,7 @@ mod test {
         let unpacked = PrimitiveArray::new(values.clone(), Validity::NonNullable);
         let bitpacked = BitPackedData::encode(&unpacked.into_array(), 9, &mut ctx).unwrap();
         let filtered = bitpacked
-            .filter(Mask::from_indices(values.len(), (0..250).collect()))
+            .filter(Mask::from_indices(values.len(), 0..250))
            .unwrap()
             .execute::<PrimitiveArray>(&mut ctx)
             .unwrap();
diff --git a/vortex-array/benches/filter_bool.rs b/vortex-array/benches/filter_bool.rs
index f947518d6cd..5699fdb19cb 100644
--- a/vortex-array/benches/filter_bool.rs
+++ b/vortex-array/benches/filter_bool.rs
@@ -110,7 +110,7 @@ fn make_dense_runs(len: usize, false_rate: f64, rng: &mut StdRng) -> Mask {
 fn make_single_slice(len: usize, density: f64) -> Mask {
     let true_count = (len as f64 * density) as usize;
     let start = (len - true_count) / 2;
-    Mask::from_indices(len, (start..start + true_count).collect())
+    Mask::from_indices(len, start..start + true_count)
 }
 
 // --- Source array generators ---
diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs
index f1fd0a32e9b..1e9cf3ab0ce 100644
--- a/vortex-array/src/patches.rs
+++ b/vortex-array/src/patches.rs
@@ -1629,7 +1629,7 @@ mod test {
             .unwrap();
 
         // Keep all indices (mask with indices 0-9)
-        let mask = Mask::from_indices(10, (0..10).collect());
+        let mask = Mask::from_indices(10, 0..10);
         let filtered = patches
             .filter(&mask, &mut LEGACY_SESSION.create_execution_ctx())
             .unwrap()
diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs
index 577e3eb3d2f..046a75bc49a 100644
--- a/vortex-array/src/scalar_fn/fns/case_when.rs
+++ b/vortex-array/src/scalar_fn/fns/case_when.rs
@@ -1187,8 +1187,8 @@ mod tests {
         let n = 100usize;
 
         // Branch 0: even rows → 0, Branch 1: odd rows → 1, Else: never reached.
-        let branch0_mask = Mask::from_indices(n, (0..n).step_by(2).collect());
-        let branch1_mask = Mask::from_indices(n, (1..n).step_by(2).collect());
+        let branch0_mask = Mask::from_indices(n, (0..n).step_by(2));
+        let branch1_mask = Mask::from_indices(n, (1..n).step_by(2));
 
         let result = merge_case_branches(
             vec![
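The call sites above all change the same way: the `.collect()` into an intermediate `Vec<usize>` disappears because `Mask::from_indices` (see the `vortex-mask/src/lib.rs` hunk further down) now accepts any `impl IntoIterator<Item = usize>`. A minimal sketch of the call shapes this enables, assuming only the `Mask` API as changed in this patch:

```rust
use vortex_mask::Mask;

fn main() {
    // Ranges, stepped ranges, and filtered iterators all satisfy the bound,
    // so no intermediate Vec<usize> is allocated at the call site.
    let dense = Mask::from_indices(4096, 0..1024);
    let evens = Mask::from_indices(100, (0..100).step_by(2));
    let sparse = Mask::from_indices(200, (0..200).filter(|i| i % 3 != 0));

    assert_eq!(dense.true_count(), 1024);
    assert_eq!(evens.true_count(), 50);
    assert_eq!(sparse.true_count(), 133); // 200 minus the 67 multiples of 3
}
```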
diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs
index ddb662459a7..6403509f5a8 100644
--- a/vortex-array/src/scalar_fn/fns/zip/mod.rs
+++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs
@@ -432,7 +432,7 @@ mod tests {
             builder.finish()
         };
 
-        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());
+        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0));
         let mask_array = mask.clone().into_array();
 
         let mut ctx = LEGACY_SESSION.create_execution_ctx();
diff --git a/vortex-buffer/public-api.lock b/vortex-buffer/public-api.lock
index f6c52f9df89..4e0b1190a7a 100644
--- a/vortex-buffer/public-api.lock
+++ b/vortex-buffer/public-api.lock
@@ -256,7 +256,7 @@
 pub fn vortex_buffer::BitBuffer::empty() -> Self
 
 pub fn vortex_buffer::BitBuffer::false_count(&self) -> usize
 
-pub fn vortex_buffer::BitBuffer::from_indices(usize, &[usize]) -> vortex_buffer::BitBuffer
+pub fn vortex_buffer::BitBuffer::from_indices(usize, impl core::iter::traits::collect::IntoIterator<Item = usize>) -> vortex_buffer::BitBuffer
 
 pub fn vortex_buffer::BitBuffer::full(bool, usize) -> Self
@@ -452,7 +452,7 @@
 pub fn vortex_buffer::BitBufferMut::freeze(self) -> vortex_buffer::BitBuffer
 
 pub fn vortex_buffer::BitBufferMut::from_buffer(vortex_buffer::ByteBufferMut, usize, usize) -> Self
 
-pub fn vortex_buffer::BitBufferMut::from_indices(usize, &[usize]) -> vortex_buffer::BitBufferMut
+pub fn vortex_buffer::BitBufferMut::from_indices(usize, impl core::iter::traits::collect::IntoIterator<Item = usize>) -> vortex_buffer::BitBufferMut
 
 pub fn vortex_buffer::BitBufferMut::full(bool, usize) -> Self
diff --git a/vortex-buffer/src/bit/buf.rs b/vortex-buffer/src/bit/buf.rs
index 61cc5e30044..d9c30ea1917 100644
--- a/vortex-buffer/src/bit/buf.rs
+++ b/vortex-buffer/src/bit/buf.rs
@@ -139,7 +139,7 @@ impl BitBuffer {
     }
 
     /// Create a bit buffer of `len` with `indices` set as true.
-    pub fn from_indices(len: usize, indices: &[usize]) -> BitBuffer {
+    pub fn from_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> BitBuffer {
         BitBufferMut::from_indices(len, indices).freeze()
     }
 
@@ -650,6 +650,24 @@ mod tests {
         assert_eq!(sliced.offset(), 2);
     }
 
+    #[test]
+    fn test_from_indices_dense_crosses_words() {
+        let len = 130;
+        let indices = (0..len).filter(|idx| idx % 3 != 1);
+        let buf = BitBuffer::from_indices(len, indices);
+
+        assert_eq!(buf.len(), len);
+        for idx in 0..len {
+            assert_eq!(buf.value(idx), idx % 3 != 1, "mismatch at {idx}");
+        }
+    }
+
+    #[test]
+    #[should_panic(expected = "index 5 exceeds len 5")]
+    fn test_from_indices_out_of_bounds() {
+        BitBuffer::from_indices(5, [0, 5]);
+    }
+
     #[rstest]
     #[case(5)]
     #[case(8)]
diff --git a/vortex-buffer/src/bit/buf_mut.rs b/vortex-buffer/src/bit/buf_mut.rs
index bf42e92d571..d1c96f069e8 100644
--- a/vortex-buffer/src/bit/buf_mut.rs
+++ b/vortex-buffer/src/bit/buf_mut.rs
@@ -165,11 +165,21 @@ impl BitBufferMut {
     }
 
     /// Create a bit buffer of `len` with `indices` set as true.
-    pub fn from_indices(len: usize, indices: &[usize]) -> BitBufferMut {
-        let mut buf = BitBufferMut::new_unset(len);
-        // TODO(ngates): for dense indices, we can do better by collecting into u64s.
-        indices.iter().for_each(|&idx| buf.set(idx));
-        buf
+    pub fn from_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> BitBufferMut {
+        let mut buffer = BufferMut::<u64>::zeroed(len.div_ceil(64));
+        for idx in indices {
+            assert!(idx < len, "index {idx} exceeds len {len}");
+            buffer.as_mut_slice()[idx / 64] |= 1 << (idx % 64);
+        }
+
+        let mut buffer = buffer.into_byte_buffer();
+        buffer.truncate(len.div_ceil(8));
+
+        Self {
+            buffer,
+            offset: 0,
+            len,
+        }
     }
 
     /// Invokes `f` with indexes `0..len` collecting the boolean results into a new `BitBufferMut`
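The `buf_mut.rs` rewrite above replaces the old per-index `buf.set(idx)` loop with direct word packing, resolving the `TODO(ngates)` it deletes. A standalone sketch of just the arithmetic, with a hypothetical `pack_indices` helper standing in for `BitBufferMut::from_indices` so it runs without the vortex-buffer types:

```rust
/// Hypothetical stand-in for the diff's word packing: index `idx` lands in
/// word `idx / 64` at bit `idx % 64`, so dense index sets become a few ORs
/// per word instead of one `set` call per index.
fn pack_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Vec<u64> {
    let mut words = vec![0u64; len.div_ceil(64)];
    for idx in indices {
        assert!(idx < len, "index {idx} exceeds len {len}"); // same message as the diff
        words[idx / 64] |= 1 << (idx % 64);
    }
    words
}

fn main() {
    let words = pack_indices(130, [0, 63, 64, 129]);
    assert_eq!(words[0], 1 | (1 << 63)); // bits 0 and 63 of the first word
    assert_eq!(words[1], 1); // index 64 is bit 0 of the second word
    assert_eq!(words[2], 1 << 1); // index 129 is bit 1 of the third word
}
```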
diff --git a/vortex-mask/benches/intersect_by_rank.rs b/vortex-mask/benches/intersect_by_rank.rs
index 31c9f69f40f..bd6c5f18c5b 100644
--- a/vortex-mask/benches/intersect_by_rank.rs
+++ b/vortex-mask/benches/intersect_by_rank.rs
@@ -68,12 +68,7 @@ fn create_random_mask(len: usize, selectivity: f64) -> Mask {
 fn create_random_indices_mask(len: usize, selectivity: f64) -> Mask {
     #[expect(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
     let threshold = (selectivity * 1000.0) as usize;
-    Mask::from_indices(
-        len,
-        (0..len)
-            .filter(|&i| (i * 7 + 13) % 1000 < threshold)
-            .collect(),
-    )
+    Mask::from_indices(len, (0..len).filter(|&i| (i * 7 + 13) % 1000 < threshold))
 }
 
 fn create_runs_mask(len: usize, run_len: usize, gap_len: usize) -> Mask {
diff --git a/vortex-mask/public-api.lock b/vortex-mask/public-api.lock
index d61e008ae72..006bdafaa64 100644
--- a/vortex-mask/public-api.lock
+++ b/vortex-mask/public-api.lock
@@ -56,7 +56,7 @@
 pub fn vortex_mask::Mask::from_buffer(vortex_buffer::bit::buf::BitBuffer) -> Self
 
 pub fn vortex_mask::Mask::from_excluded_indices(usize, impl core::iter::traits::collect::IntoIterator<Item = usize>) -> Self
 
-pub fn vortex_mask::Mask::from_indices(usize, alloc::vec::Vec<usize>) -> Self
+pub fn vortex_mask::Mask::from_indices(usize, impl core::iter::traits::collect::IntoIterator<Item = usize>) -> Self
 
 pub fn vortex_mask::Mask::from_intersection_indices(usize, impl core::iter::traits::iterator::Iterator<Item = usize>, impl core::iter::traits::iterator::Iterator<Item = usize>) -> Self
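The `intersect_by_rank` rewrite that follows is built around a PDEP-style "deposit": scatter the low bits of `source` into the set-bit positions of `mask`. The portable fallback below is copied from the diff's `deposit_by_mask`, wrapped in a runnable example with one worked input:

```rust
// Software PDEP: the i-th low bit of `source` is deposited into the i-th set
// bit of `mask`, which is what `_pdep_u64` does in one instruction under BMI2.
fn deposit_by_mask(mut source: u64, mut mask: u64) -> u64 {
    let mut result = 0u64;
    while mask != 0 {
        let bit = mask & mask.wrapping_neg(); // isolate lowest set bit of mask
        if source & 1 != 0 {
            result |= bit;
        }
        source >>= 1;
        mask &= mask - 1; // clear that bit
    }
    result
}

fn main() {
    // mask has set bits at positions {1, 3, 6}; source 0b101 keeps ranks 0 and 2,
    // so output bits land at positions 1 and 6.
    assert_eq!(deposit_by_mask(0b101, 0b0100_1010), 0b0100_0010);
}
```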
diff --git a/vortex-mask/src/intersect_by_rank.rs b/vortex-mask/src/intersect_by_rank.rs
index efce6f17dbd..7b126233249 100644
--- a/vortex-mask/src/intersect_by_rank.rs
+++ b/vortex-mask/src/intersect_by_rank.rs
@@ -1,17 +1,606 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright the Vortex contributors
 
-use crate::AllOr;
+use std::sync::Arc;
+
+use vortex_buffer::BitBuffer;
+use vortex_buffer::BufferMut;
+
 use crate::Mask;
+use crate::MaskValues;
+
+trait DepositBits {
+    /// Whether the implementation benefits from short-circuiting on `rank_bits == 0`
+    /// and `self_chunk == u64::MAX`. The portable path loops `popcount(mask)` times,
+    /// so an all-ones mask is genuinely expensive; BMI2 PDEP is constant-time and
+    /// the branches just add mispredict cost.
+    const PREFER_BRANCHES: bool;
+
+    fn deposit_bits(source: u64, mask: u64, mask_count: usize) -> u64;
+}
+
+trait SelectBit {
+    /// Position (0..63) of the `rank`-th set bit in `word`. Caller ensures
+    /// `rank < word.count_ones()`.
+    fn select_bit_position(word: u64, rank: usize) -> usize;
+}
+
+struct PortableDeposit;
+
+impl DepositBits for PortableDeposit {
+    const PREFER_BRANCHES: bool = true;
+
+    #[inline]
+    fn deposit_bits(source: u64, mask: u64, mask_count: usize) -> u64 {
+        if mask_count >= 16 && count_ones(source) * 8 < mask_count {
+            return deposit_sparse_source(source, mask);
+        }
+
+        deposit_by_mask(source, mask)
+    }
+}
+
+struct PortableSelect;
+
+impl SelectBit for PortableSelect {
+    #[inline]
+    fn select_bit_position(word: u64, rank: usize) -> usize {
+        select_bit_position_portable(word, rank)
+    }
+}
+
+#[inline]
+fn deposit_by_mask(mut source: u64, mut mask: u64) -> u64 {
+    let mut result = 0u64;
+    while mask != 0 {
+        let bit = mask & mask.wrapping_neg();
+        if source & 1 != 0 {
+            result |= bit;
+        }
+        source >>= 1;
+        mask &= mask - 1;
+    }
+    result
+}
+
+#[inline]
+fn deposit_sparse_source(mut source: u64, mask: u64) -> u64 {
+    let mut result = 0u64;
+    while source != 0 {
+        result |= select_set_bit(mask, trailing_zeros(source));
+        source &= source - 1;
+    }
+    result
+}
+
+#[inline]
+fn select_set_bit(word: u64, rank: usize) -> u64 {
+    1u64 << select_bit_position_portable(word, rank)
+}
+
+#[inline]
+fn select_bit_position_portable(word: u64, mut rank: usize) -> usize {
+    debug_assert!(rank < count_ones(word));
+    let mut bit_offset = 0usize;
+    for byte in word.to_le_bytes() {
+        let count = count_ones_byte(byte);
+        if rank < count {
+            let mut bits = byte;
+            for _ in 0..rank {
+                bits &= bits - 1;
+            }
+
+            return bit_offset + trailing_zeros_byte(bits);
+        }
+
+        rank -= count;
+        bit_offset += 8;
+    }
+
+    debug_assert!(false, "rank out of bounds");
+    0
+}
+
+#[inline]
+fn count_ones_byte(value: u8) -> usize {
+    value.count_ones() as usize
+}
+
+#[inline]
+fn trailing_zeros(value: u64) -> usize {
+    value.trailing_zeros() as usize
+}
+
+#[inline]
+fn trailing_zeros_byte(value: u8) -> usize {
+    value.trailing_zeros() as usize
+}
+
+struct BitWords<'a> {
+    buffer: &'a BitBuffer,
+    bytes: Option<&'a [u8]>,
+    chunk_len: usize,
+    remainder_len: usize,
+}
+
+impl<'a> BitWords<'a> {
+    fn new(buffer: &'a BitBuffer) -> Self {
+        Self {
+            buffer,
+            bytes: (buffer.offset() == 0)
+                .then(|| &buffer.inner().as_slice()[..buffer.len().div_ceil(8)]),
+            chunk_len: buffer.len() / 64,
+            remainder_len: buffer.len() % 64,
+        }
+    }
+
+    #[inline]
+    fn chunk_len(&self) -> usize {
+        self.chunk_len
+    }
+
+    #[inline]
+    fn remainder_len(&self) -> usize {
+        self.remainder_len
+    }
+
+    #[inline]
+    fn word(&self, chunk_idx: usize) -> u64 {
+        debug_assert!(chunk_idx < self.chunk_len);
+        if let Some(bytes) = self.bytes {
+            return read_u64(&bytes[chunk_idx * 8..chunk_idx * 8 + 8]);
+        }
+
+        self.pack_bits(chunk_idx * 64, 64)
+    }
+
+    #[inline]
+    fn remainder_bits(&self) -> u64 {
+        if self.remainder_len == 0 {
+            return 0;
+        }
+
+        if let Some(bytes) = self.bytes {
+            let start = self.chunk_len * 8;
+            return read_u64(&bytes[start..]) & low_bits(self.remainder_len);
+        }
+
+        self.pack_bits(self.chunk_len * 64, self.remainder_len)
+    }
+
+    #[inline]
+    fn iter(&self) -> BitWordsIter<'_, 'a> {
+        BitWordsIter {
+            words: self,
+            chunk_idx: 0,
+        }
+    }
+
+    #[inline]
+    fn pack_bits(&self, start: usize, bit_count: usize) -> u64 {
+        let mut word = 0u64;
+        for bit_idx in 0..bit_count {
+            word |= (self.buffer.value(start + bit_idx) as u64) << bit_idx;
+        }
+        word
+    }
+}
+
+struct BitWordsIter<'w, 'a> {
+    words: &'w BitWords<'a>,
+    chunk_idx: usize,
+}
+
+impl Iterator for BitWordsIter<'_, '_> {
+    type Item = u64;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.chunk_idx == self.words.chunk_len {
+            return None;
+        }
+
+        let word = self.words.word(self.chunk_idx);
+        self.chunk_idx += 1;
+        Some(word)
+    }
+}
+
+#[inline]
+fn read_u64(bytes: &[u8]) -> u64 {
+    let mut value = [0; 8];
+    value[..bytes.len().min(8)].copy_from_slice(&bytes[..bytes.len().min(8)]);
+    u64::from_le_bytes(value)
+}
+
+#[cfg(target_arch = "x86_64")]
+struct Bmi2Deposit;
+
+#[cfg(target_arch = "x86_64")]
+impl DepositBits for Bmi2Deposit {
+    const PREFER_BRANCHES: bool = false;
+
+    #[inline]
+    fn deposit_bits(source: u64, mask: u64, _mask_count: usize) -> u64 {
+        // SAFETY: callers only instantiate this implementation after checking BMI2 support.
+        unsafe { pdep_bmi2(source, mask) }
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+struct Bmi2Select;
+
+#[cfg(target_arch = "x86_64")]
+impl SelectBit for Bmi2Select {
+    #[inline]
+    fn select_bit_position(word: u64, rank: usize) -> usize {
+        // SAFETY: callers only instantiate this implementation after checking BMI2 support.
+        unsafe { select_bit_position_bmi2(word, rank) }
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "bmi2")]
+unsafe fn pdep_bmi2(source: u64, mask: u64) -> u64 {
+    core::arch::x86_64::_pdep_u64(source, mask)
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "bmi2")]
+unsafe fn select_bit_position_bmi2(word: u64, rank: usize) -> usize {
+    debug_assert!(rank < word.count_ones() as usize);
+    // PDEP places the rank-th bit of source into the rank-th set bit of mask, returning a single
+    // bit at the desired position.
+    let bit = core::arch::x86_64::_pdep_u64(1u64 << rank, word);
+    bit.trailing_zeros() as usize
+}
+
+/// Reader that pulls variable-length (0..=64 bit) groups from a [`BitBuffer`] sequentially.
+///
+/// Maintains a 128-bit window over two consecutive chunks (`current`, `next`) and uses a
+/// funnel shift via `u128` to extract bits at any offset without branching. The shift
+/// pattern compiles to a single funnel-shift / SHRD-style sequence on x86_64.
+struct RankBitReader<'a> {
+    chunks: BitWords<'a>,
+    chunk_idx: usize,
+    remainder: u64,
+    current: u64,
+    next: u64,
+    bit_offset: usize,
+    remainder_loaded: bool,
+}
+
+impl<'a> RankBitReader<'a> {
+    fn new(buffer: &'a BitBuffer) -> Self {
+        let chunks = BitWords::new(buffer);
+        let remainder = chunks.remainder_bits();
+        let mut chunk_idx = 0usize;
+        let mut remainder_loaded = false;
+
+        let current = if chunks.chunk_len() == 0 {
+            remainder_loaded = true;
+            remainder
+        } else {
+            let word = chunks.word(chunk_idx);
+            chunk_idx += 1;
+            word
+        };
+        let next = if chunk_idx < chunks.chunk_len() {
+            let word = chunks.word(chunk_idx);
+            chunk_idx += 1;
+            word
+        } else if !remainder_loaded {
+            remainder_loaded = true;
+            remainder
+        } else {
+            0
+        };
+
+        Self {
+            chunks,
+            chunk_idx,
+            remainder,
+            current,
+            next,
+            bit_offset: 0,
+            remainder_loaded,
+        }
+    }
+
+    #[inline]
+    fn fetch_next(&mut self) -> u64 {
+        if self.chunk_idx < self.chunks.chunk_len() {
+            let word = self.chunks.word(self.chunk_idx);
+            self.chunk_idx += 1;
+            word
+        } else if !self.remainder_loaded {
+            self.remainder_loaded = true;
+            self.remainder
+        } else {
+            0
+        }
+    }
+
+    #[inline]
+    fn read(&mut self, bit_count: usize) -> u64 {
+        debug_assert!(bit_count <= 64);
+
+        // Funnel shift: extract `bit_count` bits at `bit_offset` from the (next:current)
+        // 128-bit window. For bit_offset in 0..=63 this is a single SHRD-style instruction
+        // on x86_64; the u128 cast keeps it well-defined when bit_offset == 0.
+        let combined = ((self.next as u128) << 64) | (self.current as u128);
+        // The truncation is intentional: we want the low 64 bits of the funnel-shifted
+        // window, which is exactly what `as u64` produces.
+        #[expect(clippy::cast_possible_truncation)]
+        let bits = (combined >> self.bit_offset) as u64 & low_bits(bit_count);
+
+        let new_offset = self.bit_offset + bit_count;
+        if new_offset >= 64 {
+            self.current = self.next;
+            self.next = self.fetch_next();
+            self.bit_offset = new_offset - 64;
+        } else {
+            self.bit_offset = new_offset;
+        }
+
+        bits
+    }
+}
+
+#[inline]
+fn low_bits(bit_count: usize) -> u64 {
+    debug_assert!(bit_count <= 64);
+    if bit_count == 64 {
+        u64::MAX
+    } else {
+        (1u64 << bit_count) - 1
+    }
+}
+
+#[inline]
+fn count_ones(value: u64) -> usize {
+    value.count_ones() as usize
+}
+
+#[inline]
+fn mask_from_buffer(buffer: BitBuffer, true_count: usize) -> Mask {
+    let len = buffer.len();
+    if true_count == 0 {
+        return Mask::new_false(len);
+    }
+    if true_count == len {
+        return Mask::new_true(len);
+    }
+
+    Mask::Values(Arc::new(MaskValues {
+        buffer,
+        indices: Default::default(),
+        slices: Default::default(),
+        true_count,
+        density: true_count as f64 / len as f64,
+    }))
+}
+
+#[inline]
+fn push_result_chunk<D: DepositBits>(
+    result: &mut BufferMut<u64>,
+    self_chunk: u64,
+    self_count: usize,
+    rank_bits: u64,
+) {
+    // The portable deposit loops `popcount(self_chunk)` times, so an all-ones mask is
+    // genuinely 64x more expensive than the early returns; for BMI2 PDEP both inputs run
+    // in constant time and unpredictable branches just add mispredict overhead.
+    let chunk = if D::PREFER_BRANCHES {
+        if rank_bits == 0 {
+            0
+        } else if self_chunk == u64::MAX {
+            rank_bits
+        } else {
+            D::deposit_bits(rank_bits, self_chunk, self_count)
+        }
+    } else {
+        D::deposit_bits(rank_bits, self_chunk, self_count)
+    };
+
+    // SAFETY: callers allocate enough capacity for every output chunk.
+    unsafe { result.push_unchecked(chunk) };
+}
+
+fn intersect_bit_buffers<D: DepositBits>(
+    self_buffer: &BitBuffer,
+    mask_buffer: &BitBuffer,
+    true_count: usize,
+) -> Mask {
+    let len = self_buffer.len();
+    let mut result = BufferMut::with_capacity(len.div_ceil(64));
+    let mut reader = RankBitReader::new(mask_buffer);
+    let self_chunks = BitWords::new(self_buffer);
+
+    for self_chunk in self_chunks.iter() {
+        let self_count = count_ones(self_chunk);
+        let rank_bits = reader.read(self_count);
+        push_result_chunk::<D>(&mut result, self_chunk, self_count, rank_bits);
+    }
+
+    if self_chunks.remainder_len() != 0 {
+        let self_chunk = self_chunks.remainder_bits();
+        let self_count = count_ones(self_chunk);
+        let rank_bits = reader.read(self_count);
+        push_result_chunk::<D>(&mut result, self_chunk, self_count, rank_bits);
+    }
+
+    mask_from_buffer(
+        BitBuffer::new(result.freeze().into_byte_buffer(), len),
+        true_count,
+    )
+}
+
+fn intersect_bit_buffer_by_rank_indices<D: DepositBits>(
+    self_buffer: &BitBuffer,
+    mask_indices: &[usize],
+) -> Mask {
+    let len = self_buffer.len();
+    let mut result = BufferMut::with_capacity(len.div_ceil(64));
+    let self_chunks = BitWords::new(self_buffer);
+    let mut rank_base = 0usize;
+    let mut rank_idx = 0usize;
+
+    for self_chunk in self_chunks.iter() {
+        let self_count = count_ones(self_chunk);
+        let next_rank_base = rank_base + self_count;
+        let rank_bits = rank_bits_for_chunk(mask_indices, &mut rank_idx, rank_base, next_rank_base);
+        push_result_chunk::<D>(&mut result, self_chunk, self_count, rank_bits);
+        rank_base = next_rank_base;
+    }
+
+    if self_chunks.remainder_len() != 0 {
+        let self_chunk = self_chunks.remainder_bits();
+        let self_count = count_ones(self_chunk);
+        let next_rank_base = rank_base + self_count;
+        let rank_bits = rank_bits_for_chunk(mask_indices, &mut rank_idx, rank_base, next_rank_base);
+        push_result_chunk::<D>(&mut result, self_chunk, self_count, rank_bits);
+    }
+
+    debug_assert_eq!(rank_idx, mask_indices.len());
+
+    mask_from_buffer(
+        BitBuffer::new(result.freeze().into_byte_buffer(), len),
+        mask_indices.len(),
+    )
+}
+
+/// Walks `mask_indices` (global ranks into the set bits of `self_buffer`) and emits the
+/// corresponding positions in `self_buffer`. For each rank, advances `self_buffer`'s chunks via
+/// popcount skip-while, then locates the bit inside the current chunk with rank-select.
+///
+/// This dominates the chunk-scan paths when the mask is very sparse: total cost is
+/// `O(mask.true_count() + self.len() / 64)`, paying only a popcount per skipped chunk
+/// instead of a popcount plus deposit for every chunk.
+fn intersect_mask_driven<S, I>(self_buffer: &BitBuffer, mask_indices: I, true_count: usize) -> Mask
+where
+    S: SelectBit,
+    I: Iterator<Item = usize>,
+{
+    let len = self_buffer.len();
+    if true_count == 0 {
+        return Mask::new_false(len);
+    }
+
+    let chunks = BitWords::new(self_buffer);
+    let remainder = chunks.remainder_bits();
+    let mut chunk_iter = chunks.iter();
+
+    let (mut current_chunk, mut on_remainder) = match chunk_iter.next() {
+        Some(c) => (c, false),
+        None => (remainder, true),
+    };
+    let mut current_count = count_ones(current_chunk);
+    let mut current_chunk_idx = 0usize;
+    let mut rank_before = 0usize;
+
+    let mut output: Vec<usize> = Vec::with_capacity(true_count);
+
+    for global_rank in mask_indices {
+        while rank_before + current_count <= global_rank {
+            rank_before += current_count;
+            current_chunk_idx += 1;
+            current_chunk = match chunk_iter.next() {
+                Some(c) => c,
+                None if !on_remainder => {
+                    on_remainder = true;
+                    remainder
+                }
+                None => {
+                    debug_assert!(false, "mask index out of bounds");
+                    0
+                }
+            };
+            current_count = count_ones(current_chunk);
+        }
+
+        let local_rank = global_rank - rank_before;
+        let bit_pos = S::select_bit_position(current_chunk, local_rank);
+        output.push(current_chunk_idx * 64 + bit_pos);
+    }
+
+    debug_assert_eq!(output.len(), true_count);
+    Mask::from_indices(len, output)
+}
+
+#[inline]
+fn rank_bits_for_chunk(
+    mask_indices: &[usize],
+    rank_idx: &mut usize,
+    rank_base: usize,
+    next_rank_base: usize,
+) -> u64 {
+    let mut rank_bits = 0u64;
+    while let Some(&rank) = mask_indices.get(*rank_idx) {
+        if rank >= next_rank_base {
+            break;
+        }
+        rank_bits |= 1u64 << (rank - rank_base);
+        *rank_idx += 1;
+    }
+    rank_bits
+}
+
+fn intersect_by_rank_indices(len: usize, self_indices: &[usize], mask_indices: &[usize]) -> Mask {
+    Mask::from_indices(
+        len,
+        mask_indices.iter().map(|idx| {
+            // SAFETY: mask indices are ranks into self_indices, because
+            // mask.len() == self.true_count() == self_indices.len().
+            unsafe { *self_indices.get_unchecked(*idx) }
+        }),
+    )
+}
+
+#[inline]
+fn intersect_bit_buffers_dispatch(
+    self_buffer: &BitBuffer,
+    mask_buffer: &BitBuffer,
+    true_count: usize,
+) -> Mask {
+    #[cfg(target_arch = "x86_64")]
+    if std::arch::is_x86_feature_detected!("bmi2") {
+        return intersect_bit_buffers::<Bmi2Deposit>(self_buffer, mask_buffer, true_count);
+    }
+
+    intersect_bit_buffers::<PortableDeposit>(self_buffer, mask_buffer, true_count)
+}
+
+#[inline]
+fn intersect_rank_indices_dispatch(self_buffer: &BitBuffer, mask_indices: &[usize]) -> Mask {
+    #[cfg(target_arch = "x86_64")]
+    if std::arch::is_x86_feature_detected!("bmi2") {
+        return intersect_bit_buffer_by_rank_indices::<Bmi2Deposit>(self_buffer, mask_indices);
+    }
+
+    intersect_bit_buffer_by_rank_indices::<PortableDeposit>(self_buffer, mask_indices)
+}
+
+#[inline]
+fn intersect_mask_driven_dispatch<I>(
+    self_buffer: &BitBuffer,
+    mask_indices: I,
+    true_count: usize,
+) -> Mask
+where
+    I: Iterator<Item = usize>,
+{
+    #[cfg(target_arch = "x86_64")]
+    if std::arch::is_x86_feature_detected!("bmi2") {
+        return intersect_mask_driven::<Bmi2Select, I>(self_buffer, mask_indices, true_count);
+    }
+
+    intersect_mask_driven::<PortableSelect, I>(self_buffer, mask_indices, true_count)
+}
 
 impl Mask {
     /// Take the intersection of the `mask` with the set of true values in `self`.
     ///
-    /// We are more interested in low selectivity `self` (as indices) with a boolean buffer mask,
-    /// so we don't optimize for other cases, yet.
-    ///
-    /// Note: we might be able to accelerate this function on x86 with BMI, see:
-    ///
+    /// The hot path keeps bit-buffer-backed masks as bit buffers. It scans the set bits of `self`
+    /// by rank and deposits selected rank bits into their original positions.
     ///
     /// # Examples
     ///
@@ -29,22 +618,70 @@ impl Mask {
     pub fn intersect_by_rank(&self, mask: &Mask) -> Mask {
         assert_eq!(self.true_count(), mask.len());
 
-        match (self.indices(), mask.indices()) {
-            (AllOr::All, _) => mask.clone(),
-            (_, AllOr::All) => self.clone(),
-            (AllOr::None, _) | (_, AllOr::None) => Self::new_false(self.len()),
-
-            (AllOr::Some(self_indices), AllOr::Some(mask_indices)) => {
-                Self::from_indices(
-                    self.len(),
-                    mask_indices
-                        .iter()
-                        .map(|idx|
-                            // This is verified as safe because we know that the indices are less than the
-                            // mask.len() and we known mask.len() <= self.len(),
-                            // implied by `self.true_count() == mask.len()`.
-                            unsafe{*self_indices.get_unchecked(*idx)})
-                        .collect(),
+        match (self, mask) {
+            (Self::AllTrue(_), _) => mask.clone(),
+            (_, Self::AllTrue(_)) => self.clone(),
+            (Self::AllFalse(_), _) | (_, Self::AllFalse(_)) => Self::new_false(self.len()),
+            (Self::Values(self_values), Self::Values(mask_values)) => {
+                let self_is_very_sparse = self_values.true_count() < self.len().div_ceil(64);
+                // The mask-driven path becomes worthwhile around ~3% mask density: each set
+                // bit costs a select + push, but we save a per-self-chunk popcount + deposit.
+                let mask_is_very_sparse = mask_values.true_count().saturating_mul(32) < mask.len();
+
+                if let Some(mask_indices) = mask_values.indices.get() {
+                    if let Some(self_indices) = self_values.indices.get()
+                        && mask_indices.len() < self.len().div_ceil(64)
+                    {
+                        return intersect_by_rank_indices(self.len(), self_indices, mask_indices);
+                    }
+
+                    if self_is_very_sparse {
+                        return intersect_by_rank_indices(
+                            self.len(),
+                            self_values.indices(),
+                            mask_indices,
+                        );
+                    }
+
+                    if mask_is_very_sparse {
+                        return intersect_mask_driven_dispatch(
+                            self_values.bit_buffer(),
+                            mask_indices.iter().copied(),
+                            mask_values.true_count(),
+                        );
+                    }
+
+                    if mask_indices.len().saturating_mul(4) > mask.len() {
+                        return intersect_bit_buffers_dispatch(
+                            self_values.bit_buffer(),
+                            mask_values.bit_buffer(),
+                            mask_values.true_count(),
+                        );
+                    }
+
+                    return intersect_rank_indices_dispatch(self_values.bit_buffer(), mask_indices);
+                }
+
+                if self_is_very_sparse {
+                    return intersect_by_rank_indices(
+                        self.len(),
+                        self_values.indices(),
+                        mask_values.indices(),
+                    );
+                }
+
+                if mask_is_very_sparse {
+                    return intersect_mask_driven_dispatch(
+                        self_values.bit_buffer(),
+                        mask_values.bit_buffer().set_indices(),
+                        mask_values.true_count(),
+                    );
+                }
+
+                intersect_bit_buffers_dispatch(
+                    self_values.bit_buffer(),
+                    mask_values.bit_buffer(),
+                    mask_values.true_count(),
                 )
             }
         }
@@ -196,4 +833,112 @@ mod test {
             _ => panic!("Unexpected result"),
         }
     }
+
+    #[rstest]
+    // Larger sizes to push the bench-shaped buffer paths through the unit tests too.
+    #[case::dense_len_1024(1024, 31, 0.5, 0.5)]
+    // Very-sparse mask cases exercise the mask-driven dispatch path. These densities live in
+    // the half-open interval where `mask_is_very_sparse` is true.
+    #[case::sparse_mask_1pct(1024, 17, 0.5, 0.01)]
+    #[case::sparse_mask_2pct(2048, 0, 0.5, 0.02)]
+    #[case::very_sparse_mask_with_offsets(513, 5, 0.5, 0.005)]
+    fn test_intersect_by_rank_density_matrix(
+        #[case] base_len: usize,
+        #[case] base_offset: usize,
+        #[case] base_density: f64,
+        #[case] rank_density: f64,
+    ) {
+        #[expect(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
+        let base_threshold = (base_density * 1024.0) as usize;
+        #[expect(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
+        let rank_threshold = (rank_density * 1024.0) as usize;
+
+        let base_source: Vec<bool> = (0..base_len + base_offset + 16)
+            .map(|i| (i * 7 + 13) % 1024 < base_threshold)
+            .collect();
+        let base_bits = base_source[base_offset..base_offset + base_len].to_vec();
+        let base = Mask::from_buffer(
+            BitBuffer::from(base_source).slice(base_offset..base_offset + base_len),
+        );
+
+        let rank_len = base.true_count();
+        let rank_bits: Vec<bool> = (0..rank_len)
+            .map(|i| (i * 11 + 7) % 1024 < rank_threshold)
+            .collect();
+        let rank_from_buffer = Mask::from_buffer(BitBuffer::from(rank_bits.clone()));
+        let rank_indices_vec = rank_bits
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, &v)| v.then_some(idx))
+            .collect::<Vec<_>>();
+        let rank_from_indices = Mask::from_indices(rank_len, rank_indices_vec);
+
+        let expected = expected_intersect_by_rank(&base_bits, &rank_bits);
+
+        assert_eq!(
+            base.intersect_by_rank(&rank_from_buffer),
+            expected,
+            "uncached rank"
+        );
+        assert_eq!(
+            base.intersect_by_rank(&rank_from_indices),
+            expected,
+            "cached rank"
+        );
+    }
+
+    #[rstest]
+    #[case::short(37, 0, 0)]
+    #[case::base_offset(257, 5, 0)]
+    #[case::rank_offset(257, 0, 3)]
+    #[case::both_offsets(513, 6, 5)]
+    fn test_intersect_by_rank_bitbuffer_paths_with_offsets(
+        #[case] base_len: usize,
+        #[case] base_offset: usize,
+        #[case] rank_offset: usize,
+    ) {
+        let base_source: Vec<bool> = (0..base_len + base_offset + 16)
+            .map(|i| (i % 3 == 0) ^ (i % 11 == 0) ^ (i % 17 == 0))
+            .collect();
+        let base_bits = base_source[base_offset..base_offset + base_len].to_vec();
+        let base = Mask::from_buffer(
+            BitBuffer::from(base_source).slice(base_offset..base_offset + base_len),
+        );
+
+        let rank_len = base.true_count();
+        let rank_bits: Vec<bool> = (0..rank_len)
+            .map(|i| (i % 5 == 0) || (i % 13 == 3))
+            .collect();
+        let mut rank_source = vec![false; rank_offset];
+        rank_source.extend(rank_bits.iter().copied());
+        rank_source.extend([true, false, true, false, true, false, true, false]);
+
+        let rank_from_buffer = Mask::from_buffer(
+            BitBuffer::from(rank_source).slice(rank_offset..rank_offset + rank_len),
+        );
+        let rank_indices = rank_bits
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, &value)| value.then_some(idx))
+            .collect::<Vec<_>>();
+        let rank_from_indices = Mask::from_indices(rank_len, rank_indices);
+
+        let expected = expected_intersect_by_rank(&base_bits, &rank_bits);
+
+        assert_eq!(base.intersect_by_rank(&rank_from_buffer), expected);
+        assert_eq!(base.intersect_by_rank(&rank_from_indices), expected);
+    }
+
+    fn expected_intersect_by_rank(base_bits: &[bool], rank_bits: &[bool]) -> Mask {
+        let mut rank = 0usize;
+        Mask::from_iter(base_bits.iter().map(|&is_set| {
+            if is_set {
+                let keep = rank_bits[rank];
+                rank += 1;
+                keep
+            } else {
+                false
+            }
+        }))
+    }
 }
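The dispatch in `intersect_by_rank` above chooses a strategy from two cheap density tests. Restating them as free functions (hypothetical names) to make the thresholds concrete:

```rust
// `self` is "very sparse" when it has fewer set bits than 64-bit chunks, so
// walking cached indices beats scanning every chunk.
fn self_is_very_sparse(self_true_count: usize, self_len: usize) -> bool {
    self_true_count < self_len.div_ceil(64)
}

// The mask is "very sparse" below ~3% density (true_count * 32 < len), the
// point where per-set-bit select + push wins over per-chunk popcount + deposit.
fn mask_is_very_sparse(mask_true_count: usize, mask_len: usize) -> bool {
    mask_true_count.saturating_mul(32) < mask_len
}

fn main() {
    // A 4096-bit `self` spans 64 chunks, so 63 set bits tips to the index path.
    assert!(self_is_very_sparse(63, 4096));
    assert!(!self_is_very_sparse(64, 4096));

    // 1% mask density: 40 * 32 = 1280 < 4096, so take the mask-driven path.
    assert!(mask_is_very_sparse(40, 4096));
    // 50% density: 2048 * 32 >= 4096, so deposit whole chunks instead.
    assert!(!mask_is_very_sparse(2048, 4096));
}
```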
diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs
index 898aa767a3a..29b2f5f83e4 100644
--- a/vortex-mask/src/lib.rs
+++ b/vortex-mask/src/lib.rs
@@ -205,15 +205,17 @@ impl Mask {
         }))
     }
 
-    /// Create a new [`Mask`] from a [`Vec<usize>`].
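-    // TODO(ngates): this should take an IntoIterator.
-    pub fn from_indices(len: usize, indices: Vec<usize>) -> Self {
-        let true_count = indices.len();
+    /// Create a new [`Mask`] from sorted, unique indices.
+    pub fn from_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
+        let indices = indices.into_iter().collect::<Vec<_>>();
         assert!(indices.is_sorted(), "Mask indices must be sorted");
         assert!(
-            indices.last().is_none_or(|&idx| idx < len),
-            "Mask indices must be in bounds (len={len})"
+            indices.windows(2).all(|w| w[0] != w[1]),
+            "Mask indices must be unique"
         );
+        let buffer = BitBuffer::from_indices(len, indices.iter().copied());
+        debug_assert_eq!(buffer.len(), len);
+        let true_count = buffer.true_count();
 
         if true_count == 0 {
             return Self::AllFalse(len);
@@ -222,13 +224,8 @@ impl Mask {
             return Self::AllTrue(len);
         }
 
-        let mut buf = BitBufferMut::new_unset(len);
-        // TODO(ngates): for dense indices, we can do better by collecting into u64s.
-        indices.iter().for_each(|&idx| buf.set(idx));
-        debug_assert_eq!(buf.len(), len);
-
         Self::Values(Arc::new(MaskValues {
-            buffer: buf.freeze(),
+            buffer,
             indices: OnceLock::from(indices),
             slices: Default::default(),
             true_count,

With this change, `Mask::from_indices` asserts sortedness and uniqueness itself, while bounds checking moves into `BitBufferMut::from_indices`. A small sketch of the resulting contract (panic messages quoted from the asserts in this patch):

```rust
use vortex_mask::Mask;

fn main() {
    // Sorted, unique, in-bounds indices construct a mask as before.
    let mask = Mask::from_indices(8, [1, 3, 5]);
    assert_eq!(mask.true_count(), 3);

    // Each of the following would now panic:
    // Mask::from_indices(8, [3, 1]); // "Mask indices must be sorted"
    // Mask::from_indices(8, [1, 1]); // "Mask indices must be unique"
    // Mask::from_indices(8, [0, 8]); // "index 8 exceeds len 8" (from BitBufferMut::from_indices)
}
```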
diff --git a/vortex-mask/src/tests.rs b/vortex-mask/src/tests.rs
index ff0bf04f76a..d1496fcb30f 100644
--- a/vortex-mask/src/tests.rs
+++ b/vortex-mask/src/tests.rs
@@ -443,6 +443,12 @@ fn test_mask_from_indices_unsorted() {
     Mask::from_indices(5, vec![2, 0, 3]); // Not sorted
 }
 
+#[test]
+#[should_panic]
+fn test_mask_from_indices_duplicate() {
+    Mask::from_indices(5, vec![0, 2, 2]); // Not unique
+}
+
 #[test]
 #[should_panic]
 fn test_mask_from_indices_out_of_bounds() {
diff --git a/vortex-scan/src/selection.rs b/vortex-scan/src/selection.rs
index 1374851bec0..5f0bbd79939 100644
--- a/vortex-scan/src/selection.rs
+++ b/vortex-scan/src/selection.rs
@@ -77,8 +77,7 @@ impl Selection {
                     .filter_map(|idx| {
                         // Only include indices that fit in usize
                         usize::try_from(idx).ok()
-                    })
-                    .collect(),
+                    }),
             )
         })
         .unwrap_or_else(|| Mask::new_false(range_len));
@@ -117,8 +116,7 @@ impl Selection {
                 .filter_map(|idx| {
                     // Only include indices that fit in usize
                     usize::try_from(idx).ok()
-                })
-                .collect(),
+                }),
         );
 
         RowMask::new(range.start, mask)
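One mechanism worth isolating from `intersect_by_rank.rs` above is `RankBitReader::read`'s funnel shift: the next/current pair forms a 128-bit window, and a single `u128` shift extracts a bit group at any offset, including groups that straddle the 64-bit word boundary. A self-contained sketch with a hypothetical `extract` helper:

```rust
// Extract `bit_count` bits starting at `bit_offset` from the (next:current)
// 128-bit window, mirroring the diff's RankBitReader::read. The u128 shift is
// well-defined at every offset in 0..64, where a two-shift u64 version
// would hit an undefined shift-by-64 at offset 0.
fn extract(current: u64, next: u64, bit_offset: usize, bit_count: usize) -> u64 {
    debug_assert!(bit_offset < 64 && bit_count <= 64);
    let combined = ((next as u128) << 64) | (current as u128);
    let low = if bit_count == 64 { u64::MAX } else { (1u64 << bit_count) - 1 };
    (combined >> bit_offset) as u64 & low
}

fn main() {
    // An 8-bit group straddling the boundary: 4 bits from the top of `current`
    // (0b1010) and 4 bits from the bottom of `next` (0b1111).
    let current = 0xA000_0000_0000_0000_u64;
    let next = 0x0000_0000_0000_000F_u64;
    assert_eq!(extract(current, next, 60, 8), 0b1111_1010);
}
```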