Skip to content

Commit

Permalink
[ENH] Use binary search for gt/gte/lt/lte
Browse files Browse the repository at this point in the history
  • Loading branch information
Sicheng Pan committed Oct 3, 2024
1 parent 91ba2c9 commit 0a15a86
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 288 deletions.
269 changes: 136 additions & 133 deletions rust/blockstore/src/arrow/block/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,88 +66,119 @@ impl Block {
delta
}

fn binary_search<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
/// Binary search the blockfile to find the partition point of the specified prefix and key.
/// The implementation is based on [`std::slice::partition_point`].
///
/// `(prefix, key)` serves as the search key, and it is sorted in ascending order.
/// The partition is defined by: `|x| x < (prefix, key)`.
/// The code is a result of inlining this predicate in [`std::slice::partition_point`].
/// If the key is unspecified (i.e. [`None`]), we find the first index of the prefix.
#[inline]
fn binary_search_index<'me, K: ArrowReadableKey<'me>>(
&'me self,
prefix: &str,
key: K,
) -> Option<V> {
// Copied from std lib https://doc.rust-lang.org/src/core/slice/mod.rs.html#2786
// and modified for two nested level comparisons.
key: Option<&K>,
) -> usize {
let mut size = self.len();
let mut left = 0;
let mut right = size;
let prefix_arr = self
if size == 0 {
return 0;
}

let prefix_array = self
.data
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
while left < right {
let mid = left + size / 2;

let prefix_cmp = prefix_arr.value(mid).cmp(prefix);
// This control flow produces conditional moves, which results in
// fewer branches and instructions than if/else or matching on
// cmp::Ordering.
// This is x86 asm for u8: https://rust.godbolt.org/z/698eYffTx.
left = if prefix_cmp == Less { mid + 1 } else { left };
right = if prefix_cmp == Greater { mid } else { right };
if prefix_cmp == Equal {
let key_cmp = K::get(self.data.column(1), mid)
.partial_cmp(&key)
.expect("NaN not expected"); // NaN not expected
left = if key_cmp == Less { mid + 1 } else { left };
right = if key_cmp == Greater { mid } else { right };
if key_cmp == Equal {
return Some(V::get(self.data.column(2), mid));
}
let mut base = 0;

// This loop intentionally doesn't have an early exit if the comparison
// returns Equal. We want the number of loop iterations to depend *only*
// on the size of the input slice so that the CPU can reliably predict
// the loop count.
while size > 1 {
let half = size / 2;
let mid = base + half;

// SAFETY: the call is made safe by the following inconstants:
// - `mid >= 0`: by definition
// - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
let mut cmp = prefix_array.value(mid).cmp(prefix);

// Continue to compare the key if prefix matches
if let (Equal, Some(k)) = (cmp, key) {
cmp = K::get(self.data.column(1), mid)
.partial_cmp(k)
.expect("Array values should be comparable.");
}

size = right - left;
base = if cmp == Less { mid } else { base };
size -= half;
}

// SAFETY: `base` is always in [0, size) because `base <= mid`.
// `base` should be the last index where the element is smaller than the target,
// or 0 if the first element is already larger than the target.
match prefix_array.value(base).cmp(prefix) {
Less => base + 1,
Equal => match key {
Some(k) => match K::get(self.data.column(1), base).partial_cmp(k) {
Some(Less) => base + 1,
_ => base,
},
None => base,
},
Greater => base,
}
None
}

fn binary_search_prefix<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
#[inline]
fn match_prefix_key_at_index<'me, K: ArrowReadableKey<'me>>(
&'me self,
prefix: &str,
) -> Option<Vec<(&str, K, V)>> {
let prefix_arr = self
key: &K,
index: usize,
) -> bool {
let prefix_array = self
.data
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let mut size = prefix_arr.len();
let mut left = 0;
let mut right = size;
while left < right {
let mid = left + size / 2;

let predicate = prefix_arr.value(mid) < prefix;

// This control flow produces conditional moves, which results in
// fewer branches and instructions than if/else or matching on
// boolean.
// This is x86 asm for u8: https://rust.godbolt.org/z/698eYffTx.
left = if predicate { mid + 1 } else { left };
right = if !predicate { mid } else { right };

size = right - left;
}

let mut start_idx = left;
let mut res = vec![];
while start_idx < prefix_arr.len() && prefix_arr.value(start_idx) == prefix {
res.push((
prefix_arr.value(start_idx),
K::get(self.data.column(1), start_idx),
V::get(self.data.column(2), start_idx),
));
start_idx += 1;
index < self.len()
&& matches!(
(
prefix_array.value(index).cmp(prefix),
K::get(self.data.column(1), index).partial_cmp(key),
),
(Equal, Some(Equal))
)
}

#[inline]
fn scan_prefix<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
&'me self,
prefix: &str,
range: impl Iterator<Item = usize>,
) -> Vec<(K, V)> {
let prefix_array = self
.data
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.expect("The prefix array should be a string arrary.");
let mut result = Vec::new();
for index in range {
if prefix_array.value(index) == prefix {
result.push((
K::get(self.data.column(1), index),
V::get(self.data.column(2), index),
));
} else {
break;
}
}

Some(res)
result
}

/*
Expand All @@ -162,7 +193,12 @@ impl Block {
prefix: &str,
key: K,
) -> Option<V> {
self.binary_search(prefix, key)
let index = self.binary_search_index(prefix, Some(&key));
if self.match_prefix_key_at_index(prefix, &key, index) {
Some(V::get(self.data.column(2), index))
} else {
None
}
}

/// Get all the values for a given prefix in the block
Expand All @@ -171,8 +207,11 @@ impl Block {
pub fn get_prefix<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
&'me self,
prefix: &str,
) -> Option<Vec<(&str, K, V)>> {
self.binary_search_prefix(prefix)
) -> Vec<(K, V)> {
self.scan_prefix(
prefix,
self.binary_search_index(prefix, Option::<&K>::None)..self.len(),
)
}

/// Get all the values for a given prefix in the block where the key is greater than the given key
Expand All @@ -182,97 +221,61 @@ impl Block {
&'me self,
prefix: &str,
key: K,
) -> Option<Vec<(&str, K, V)>> {
let prefix_array = self
.data
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let mut res: Vec<(&str, K, V)> = vec![];
for i in 0..self.data.num_rows() {
let curr_prefix = prefix_array.value(i);
let curr_key = K::get(self.data.column(1), i);
if curr_prefix == prefix && curr_key > key {
res.push((curr_prefix, curr_key, V::get(self.data.column(2), i)));
}
) -> Vec<(K, V)> {
let index = self.binary_search_index(prefix, Some(&key));
if self.match_prefix_key_at_index(prefix, &key, index) {
self.scan_prefix(prefix, index + 1..self.len())
} else {
self.scan_prefix(prefix, index..self.len())
}
Some(res)
}

/// Get all the values for a given prefix in the block where the key is less than the given key
/// Get all the values for a given prefix in the block where the key is greater than or equal to the given key
/// ### Panics
/// - If the underlying data types are not the same as the types specified in the function signature
pub fn get_lt<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
pub fn get_gte<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
&'me self,
prefix: &str,
key: K,
) -> Option<Vec<(&str, K, V)>> {
let prefix_array = self
.data
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let mut res: Vec<(&str, K, V)> = vec![];
for i in 0..self.data.num_rows() {
let curr_prefix = prefix_array.value(i);
let curr_key = K::get(self.data.column(1), i);
if curr_prefix == prefix && curr_key < key {
res.push((curr_prefix, curr_key, V::get(self.data.column(2), i)));
}
}
Some(res)
) -> Vec<(K, V)> {
self.scan_prefix(
prefix,
self.binary_search_index(prefix, Some(&key))..self.len(),
)
}

/// Get all the values for a given prefix in the block where the key is less than or equal to the given key
/// Get all the values for a given prefix in the block where the key is less than the given key
/// ### Panics
/// - If the underlying data types are not the same as the types specified in the function signature
pub fn get_lte<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
pub fn get_lt<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
&'me self,
prefix: &str,
key: K,
) -> Option<Vec<(&str, K, V)>> {
let prefix_array = self
.data
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let mut res: Vec<(&str, K, V)> = vec![];
for i in 0..self.data.num_rows() {
let curr_prefix = prefix_array.value(i);
let curr_key = K::get(self.data.column(1), i);
if curr_prefix == prefix && curr_key <= key {
res.push((curr_prefix, curr_key, V::get(self.data.column(2), i)));
}
}
Some(res)
) -> Vec<(K, V)> {
let mut result = self.scan_prefix(
prefix,
(0..self.binary_search_index(prefix, Some(&key))).rev(),
);
result.reverse();
result
}

/// Get all the values for a given prefix in the block where the key is greater than or equal to the given key
/// Get all the values for a given prefix in the block where the key is less than or equal to the given key
/// ### Panics
/// - If the underlying data types are not the same as the types specified in the function signature
pub fn get_gte<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
pub fn get_lte<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
&'me self,
prefix: &str,
key: K,
) -> Option<Vec<(&str, K, V)>> {
let prefix_array = self
.data
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let mut res: Vec<(&str, K, V)> = vec![];
for i in 0..self.data.num_rows() {
let curr_prefix = prefix_array.value(i);
let curr_key = K::get(self.data.column(1), i);
if curr_prefix == prefix && curr_key >= key {
res.push((curr_prefix, curr_key, V::get(self.data.column(2), i)));
}
}
Some(res)
) -> Vec<(K, V)> {
let index = self.binary_search_index(prefix, Some(&key));
let mut result = if self.match_prefix_key_at_index(prefix, &key, index) {
self.scan_prefix(prefix, (0..=index).rev())
} else {
self.scan_prefix(prefix, (0..index).rev())
};
result.reverse();
result
}

/// Get all the values for a given prefix in the block where the key is between the given keys
Expand Down
Loading

0 comments on commit 0a15a86

Please sign in to comment.