Skip to content

Commit

Permalink
batch splitting tests
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Nov 2, 2023
1 parent 01d8a2b commit a487a0e
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 32 deletions.
140 changes: 139 additions & 1 deletion datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,9 @@ mod tests {
use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder};
use arrow::datatypes::{DataType, Field, Schema};

use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue};
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue,
};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::Literal;
use hashbrown::raw::RawTable;
Expand Down Expand Up @@ -2973,6 +2975,142 @@ mod tests {
}
}

#[tokio::test]
async fn join_splitted_batch() {
let left = build_table(
("a1", &vec![1, 2, 3, 4]),
("b1", &vec![1, 1, 1, 1]),
("c1", &vec![0, 0, 0, 0]),
);
let right = build_table(
("a2", &vec![10, 20, 30, 40, 50]),
("b2", &vec![1, 1, 1, 1, 1]),
("c2", &vec![0, 0, 0, 0, 0]),
);
let on = vec![(
Column::new_with_schema("b1", &left.schema()).unwrap(),
Column::new_with_schema("b2", &right.schema()).unwrap(),
)];

let join_types = vec![
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::Full,
JoinType::RightSemi,
JoinType::RightAnti,
JoinType::LeftSemi,
JoinType::LeftAnti,
];
let expected_resultset_records = 20;
let common_result = [
"+----+----+----+----+----+----+",
"| a1 | b1 | c1 | a2 | b2 | c2 |",
"+----+----+----+----+----+----+",
"| 4 | 1 | 0 | 10 | 1 | 0 |",
"| 3 | 1 | 0 | 10 | 1 | 0 |",
"| 2 | 1 | 0 | 10 | 1 | 0 |",
"| 1 | 1 | 0 | 10 | 1 | 0 |",
"| 4 | 1 | 0 | 20 | 1 | 0 |",
"| 3 | 1 | 0 | 20 | 1 | 0 |",
"| 2 | 1 | 0 | 20 | 1 | 0 |",
"| 1 | 1 | 0 | 20 | 1 | 0 |",
"| 4 | 1 | 0 | 30 | 1 | 0 |",
"| 3 | 1 | 0 | 30 | 1 | 0 |",
"| 2 | 1 | 0 | 30 | 1 | 0 |",
"| 1 | 1 | 0 | 30 | 1 | 0 |",
"| 4 | 1 | 0 | 40 | 1 | 0 |",
"| 3 | 1 | 0 | 40 | 1 | 0 |",
"| 2 | 1 | 0 | 40 | 1 | 0 |",
"| 1 | 1 | 0 | 40 | 1 | 0 |",
"| 4 | 1 | 0 | 50 | 1 | 0 |",
"| 3 | 1 | 0 | 50 | 1 | 0 |",
"| 2 | 1 | 0 | 50 | 1 | 0 |",
"| 1 | 1 | 0 | 50 | 1 | 0 |",
"+----+----+----+----+----+----+",
];
let left_batch = [
"+----+----+----+",
"| a1 | b1 | c1 |",
"+----+----+----+",
"| 1 | 1 | 0 |",
"| 2 | 1 | 0 |",
"| 3 | 1 | 0 |",
"| 4 | 1 | 0 |",
"+----+----+----+",
];
let right_batch = [
"+----+----+----+",
"| a2 | b2 | c2 |",
"+----+----+----+",
"| 10 | 1 | 0 |",
"| 20 | 1 | 0 |",
"| 30 | 1 | 0 |",
"| 40 | 1 | 0 |",
"| 50 | 1 | 0 |",
"+----+----+----+",
];
let right_empty = [
"+----+----+----+",
"| a2 | b2 | c2 |",
"+----+----+----+",
"+----+----+----+",
];
let left_empty = [
"+----+----+----+",
"| a1 | b1 | c1 |",
"+----+----+----+",
"+----+----+----+",
];

// validation of partial join results output for different batch_size setting
for join_type in join_types {
for batch_size in (1..21).rev() {
let session_config = SessionConfig::default().with_batch_size(batch_size);
let task_ctx = TaskContext::default().with_session_config(session_config);
let task_ctx = Arc::new(task_ctx);

let join =
join(left.clone(), right.clone(), on.clone(), &join_type, false)
.unwrap();

let stream = join.execute(0, task_ctx).unwrap();
let batches = common::collect(stream).await.unwrap();

// For inner/right join expected batch count equals ceil_div result,
// as there is no need to append non-joined build side data.
// For other join types it'll be ceil_div + 1 -- for additional batch
// containing not visited build side rows (empty in this test case).
let expected_batch_count = match join_type {
JoinType::Inner
| JoinType::Right
| JoinType::RightSemi
| JoinType::RightAnti => {
(expected_resultset_records + batch_size - 1) / batch_size
}
_ => (expected_resultset_records + batch_size - 1) / batch_size + 1,
};
assert_eq!(
batches.len(),
expected_batch_count,
"expected {} output batches for {} join with batch_size = {}",
expected_batch_count,
join_type,
batch_size
);

let expected = match join_type {
JoinType::RightSemi => right_batch.to_vec(),
JoinType::RightAnti => right_empty.to_vec(),
JoinType::LeftSemi => left_batch.to_vec(),
JoinType::LeftAnti => left_empty.to_vec(),
_ => common_result.to_vec(),
};
assert_batches_eq!(expected, &batches);
}
}
}

#[tokio::test]
async fn single_partition_join_overallocation() -> Result<()> {
let left = build_table(
Expand Down
6 changes: 3 additions & 3 deletions datafusion/physical-plan/src/joins/nested_loop_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,20 +648,20 @@ fn adjust_indices_by_join_type(
// matched
// unmatched left row will be produced in this batch
let left_unmatched_indices =
get_anti_u64_indices(count_left_batch, &left_indices);
get_anti_u64_indices(0..count_left_batch, &left_indices);
// combine the matched and unmatched left result together
append_left_indices(left_indices, right_indices, left_unmatched_indices)
}
JoinType::LeftSemi => {
// need to remove the duplicated record in the left side
let left_indices = get_semi_u64_indices(count_left_batch, &left_indices);
let left_indices = get_semi_u64_indices(0..count_left_batch, &left_indices);
// the right_indices will not be used later for the `left semi` join
(left_indices, right_indices)
}
JoinType::LeftAnti => {
// need to remove the duplicated record in the left side
// get the anti index for the left side
let left_indices = get_anti_u64_indices(count_left_batch, &left_indices);
let left_indices = get_anti_u64_indices(0..count_left_batch, &left_indices);
// the right_indices will not be used later for the `left anti` join
(left_indices, right_indices)
}
Expand Down
60 changes: 32 additions & 28 deletions datafusion/physical-plan/src/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -921,81 +921,85 @@ pub(crate) fn append_right_indices(

/// Get unmatched and deduplicated indices for specified range of indices
pub(crate) fn get_anti_indices(
rg: Range<usize>,
range: Range<usize>,
input_indices: &UInt32Array,
) -> UInt32Array {
let mut bitmap = BooleanBufferBuilder::new(rg.len());
bitmap.append_n(rg.len(), false);
let mut bitmap = BooleanBufferBuilder::new(range.len());
bitmap.append_n(range.len(), false);
input_indices
.iter()
.flatten()
.map(|v| v as usize)
.filter(|v| rg.contains(v))
.filter(|v| range.contains(v))
.for_each(|v| {
bitmap.set_bit(v - rg.start, true);
bitmap.set_bit(v - range.start, true);
});

let offset = rg.start;
let offset = range.start;

// get the anti index
(rg).filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u32))
(range).filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u32))
.collect::<UInt32Array>()
}

/// Get unmatched and deduplicated indices
pub(crate) fn get_anti_u64_indices(
row_count: usize,
range: Range<usize>,
input_indices: &UInt64Array,
) -> UInt64Array {
let mut bitmap = BooleanBufferBuilder::new(row_count);
bitmap.append_n(row_count, false);
input_indices.iter().flatten().for_each(|v| {
bitmap.set_bit(v as usize, true);
let mut bitmap = BooleanBufferBuilder::new(range.len());
bitmap.append_n(range.len(), false);
input_indices.iter().flatten().map(|v| v as usize).filter(|v| range.contains(v)).for_each(|v| {
bitmap.set_bit(v - range.start, true);
});

let offset = range.start;

// get the anti index
(0..row_count)
.filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u64))
(range)
.filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u64))
.collect::<UInt64Array>()
}

/// Get matched and deduplicated indices for specified range of indices
pub(crate) fn get_semi_indices(
rg: Range<usize>,
range: Range<usize>,
input_indices: &UInt32Array,
) -> UInt32Array {
let mut bitmap = BooleanBufferBuilder::new(rg.len());
bitmap.append_n(rg.len(), false);
let mut bitmap = BooleanBufferBuilder::new(range.len());
bitmap.append_n(range.len(), false);
input_indices
.iter()
.flatten()
.map(|v| v as usize)
.filter(|v| rg.contains(v))
.filter(|v| range.contains(v))
.for_each(|v| {
bitmap.set_bit(v - rg.start, true);
bitmap.set_bit(v - range.start, true);
});

let offset = rg.start;
let offset = range.start;

// get the semi index
(rg).filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u32))
(range).filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u32))
.collect::<UInt32Array>()
}

/// Get matched and deduplicated indices
pub(crate) fn get_semi_u64_indices(
row_count: usize,
range: Range<usize>,
input_indices: &UInt64Array,
) -> UInt64Array {
let mut bitmap = BooleanBufferBuilder::new(row_count);
bitmap.append_n(row_count, false);
input_indices.iter().flatten().for_each(|v| {
bitmap.set_bit(v as usize, true);
let mut bitmap = BooleanBufferBuilder::new(range.len());
bitmap.append_n(range.len(), false);
input_indices.iter().flatten().map(|v| v as usize).filter(|v| range.contains(v)).for_each(|v| {
bitmap.set_bit(v - range.start, true);
});

let offset = range.start;

// get the semi index
(0..row_count)
.filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u64))
(range)
.filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u64))
.collect::<UInt64Array>()
}

Expand Down

0 comments on commit a487a0e

Please sign in to comment.