Skip to content

Commit

Permalink
fix: Fix accidental raise on shape 1 (#18748)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Sep 15, 2024
1 parent 5a262db commit 4894e24
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
6 changes: 6 additions & 0 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,12 @@ impl DataFrame {
self.shape().0
}

/// Returns the size as number of rows * number of columns
pub fn size(&self) -> usize {
let s = self.shape();
s.0 * s.1
}

/// Returns `true` if the [`DataFrame`] contains no rows.
///
/// # Example
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-mem-engine/src/executors/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl StackExec {
// possibly mismatching column lengths.
unsafe { df.get_columns_mut() }.extend(res.into_iter().map(Column::from));
} else {
let height = df.height();
let (df_height, df_width) = df.shape();

// When we have CSE we cannot verify scalars yet.
let verify_scalar = if !df.get_columns().is_empty() {
Expand All @@ -78,11 +78,11 @@ impl StackExec {
};
for (i, c) in res.iter().enumerate() {
let len = c.len();
if verify_scalar && len != height && len == 1 {
if verify_scalar && len != df_height && len == 1 && df_width > 0 {
polars_ensure!(self.exprs[i].is_scalar(),
InvalidOperation: "Series {}, length {} doesn't match the DataFrame height of {}\n\n\
If you want this Series to be broadcasted, ensure it is a scalar (for instance by adding '.first()').",
c.name(), len, height
c.name(), len, df_height
);
}
}
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/dataframe/test_extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,14 @@ def test_extend_column_name_mismatch() -> None:

with pytest.raises(ShapeError):
df1.extend(df2)


def test_initialize_df_18736() -> None:
# Completely empty initialization
df = pl.DataFrame()
s_0 = pl.Series([])
s_1 = pl.Series([None])
s_2 = pl.Series([None, None])
assert df.with_columns(s_0).shape == (0, 1)
assert df.with_columns(s_1).shape == (1, 1)
assert df.with_columns(s_2).shape == (2, 1)

0 comments on commit 4894e24

Please sign in to comment.