Skip to content

Commit

Permalink
feat: Delay deserialization of python function until physical plan (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 3, 2024
1 parent d4df3f2 commit 35946cf
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 60 deletions.
9 changes: 4 additions & 5 deletions crates/polars-error/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,17 +397,16 @@ macro_rules! polars_ensure {
pub fn to_compute_err(err: impl Display) -> PolarsError {
PolarsError::ComputeError(err.to_string().into())
}

#[macro_export]
macro_rules! feature_gated {
($feature:expr, $content:expr) => {{
#[cfg(feature = $feature)]
($($feature:literal);*, $content:expr) => {{
#[cfg(all($(feature = $feature),*))]
{
$content
}
#[cfg(not(feature = $feature))]
#[cfg(not(all($(feature = $feature),*)))]
{
panic!("activate '{}' feature", $feature)
panic!("activate '{}' feature", concat!($($feature, ", "),*))
}
}};
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ fn create_physical_expr_inner(

Ok(Arc::new(ApplyExpr::new(
input,
function.clone(),
function.clone().materialize()?,
node_to_expr(expression, expr_arena),
*options,
state.allow_threading,
Expand Down
48 changes: 47 additions & 1 deletion crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher};

use bytes::Bytes;
use polars_core::chunked_array::cast::CastOptions;
use polars_core::error::feature_gated;
use polars_core::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -157,7 +159,7 @@ pub enum Expr {
/// function arguments
input: Vec<Expr>,
/// function to apply
function: SpecialEq<Arc<dyn ColumnsUdf>>,
function: OpaqueColumnUdf,
/// output dtype of the function
output_type: GetOutput,
options: FunctionOptions,
Expand All @@ -172,6 +174,50 @@ pub enum Expr {
Selector(super::selector::Selector),
}

pub type OpaqueColumnUdf = LazySerde<SpecialEq<Arc<dyn ColumnsUdf>>>;
pub(crate) fn new_column_udf<F: ColumnsUdf + 'static>(func: F) -> OpaqueColumnUdf {
LazySerde::Deserialized(SpecialEq::new(Arc::new(func)))
}

#[derive(Clone)]
pub enum LazySerde<T: Clone> {
Deserialized(T),
Bytes(Bytes),
}

impl<T: PartialEq + Clone> PartialEq for LazySerde<T> {
fn eq(&self, other: &Self) -> bool {
use LazySerde as L;
match (self, other) {
(L::Deserialized(a), L::Deserialized(b)) => a == b,
(L::Bytes(a), L::Bytes(b)) => a.as_ptr() == b.as_ptr() && a.len() == b.len(),
_ => false,
}
}
}

impl<T: Clone> Debug for LazySerde<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Bytes(_) => write!(f, "lazy-serde<Bytes>"),
Self::Deserialized(_) => write!(f, "lazy-serde<T>"),
}
}
}

impl OpaqueColumnUdf {
pub fn materialize(self) -> PolarsResult<SpecialEq<Arc<dyn ColumnsUdf>>> {
match self {
Self::Deserialized(t) => Ok(t),
Self::Bytes(b) => {
feature_gated!("serde";"python", {
python_udf::PythonUdfExpression::try_deserialize(b.as_ref()).map(SpecialEq::new)
})
},
}
}
}

#[allow(clippy::derived_hash_with_manual_eq)]
impl Hash for Expr {
fn hash<H: Hasher>(&self, state: &mut H) {
Expand Down
37 changes: 26 additions & 11 deletions crates/polars-plan/src/dsl/expr_dyn_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ use std::ops::Deref;
use std::sync::Arc;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")]
use serde::{Deserializer, Serializer};
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use super::*;

Expand All @@ -20,14 +18,6 @@ pub trait ColumnsUdf: Send + Sync {
fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
polars_bail!(ComputeError: "serialization not supported for this 'opaque' function")
}

// Needed for python functions. After they are deserialized we first check if they
// have a function that generates an output
// This will be slower during optimization, so it is up to us to move
// all expression to the known function architecture.
fn get_output(&self) -> Option<GetOutput> {
None
}
}

#[cfg(feature = "serde")]
Expand All @@ -46,6 +36,31 @@ impl Serialize for SpecialEq<Arc<dyn ColumnsUdf>> {
}

#[cfg(feature = "serde")]
impl<T: Serialize + Clone> Serialize for LazySerde<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
Self::Deserialized(t) => t.serialize(serializer),
Self::Bytes(b) => serializer.serialize_bytes(b),
}
}
}

#[cfg(feature = "serde")]
impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'a>,
{
let buf = Vec::<u8>::deserialize(deserializer)?;
Ok(Self::Bytes(bytes::Bytes::from(buf)))
}
}

#[cfg(feature = "serde")]
// impl<T: Deserialize> Deserialize for crate::dsl::expr::LazySerde<T> {
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn ColumnsUdf>> {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
Expand Down
16 changes: 8 additions & 8 deletions crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ where
let mut exprs = exprs.as_ref().to_vec();
exprs.push(acc);

let function = SpecialEq::new(Arc::new(move |columns: &mut [Column]| {
let function = new_column_udf(move |columns: &mut [Column]| {
let mut columns = columns.to_vec();
let mut acc = columns.pop().unwrap();

Expand All @@ -38,7 +38,7 @@ where
}
}
Ok(Some(acc))
}) as Arc<dyn ColumnsUdf>);
});

Expr::AnonymousFunction {
input: exprs,
Expand Down Expand Up @@ -67,7 +67,7 @@ where
{
let exprs = exprs.as_ref().to_vec();

let function = SpecialEq::new(Arc::new(move |columns: &mut [Column]| {
let function = new_column_udf(move |columns: &mut [Column]| {
let mut c_iter = columns.iter();

match c_iter.next() {
Expand All @@ -83,7 +83,7 @@ where
},
None => Err(polars_err!(ComputeError: "`reduce` did not have any expressions to fold")),
}
}) as Arc<dyn ColumnsUdf>);
});

Expr::AnonymousFunction {
input: exprs,
Expand All @@ -109,7 +109,7 @@ where
{
let exprs = exprs.as_ref().to_vec();

let function = SpecialEq::new(Arc::new(move |columns: &mut [Column]| {
let function = new_column_udf(move |columns: &mut [Column]| {
let mut c_iter = columns.iter();

match c_iter.next() {
Expand All @@ -131,7 +131,7 @@ where
},
None => Err(polars_err!(ComputeError: "`reduce` did not have any expressions to fold")),
}
}) as Arc<dyn ColumnsUdf>);
});

Expr::AnonymousFunction {
input: exprs,
Expand All @@ -158,7 +158,7 @@ where
let mut exprs = exprs.as_ref().to_vec();
exprs.push(acc);

let function = SpecialEq::new(Arc::new(move |columns: &mut [Column]| {
let function = new_column_udf(move |columns: &mut [Column]| {
let mut columns = columns.to_vec();
let mut acc = columns.pop().unwrap();

Expand All @@ -177,7 +177,7 @@ where
}

StructChunked::from_columns(acc.name().clone(), &result).map(|ca| Some(ca.into_column()))
}) as Arc<dyn ColumnsUdf>);
});

Expr::AnonymousFunction {
input: exprs,
Expand Down
18 changes: 9 additions & 9 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ impl Expr {

Expr::AnonymousFunction {
input: vec![self],
function: SpecialEq::new(Arc::new(f)),
function: new_column_udf(f),
output_type,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
Expand Down Expand Up @@ -582,7 +582,7 @@ impl Expr {

Expr::AnonymousFunction {
input,
function: SpecialEq::new(Arc::new(function)),
function: new_column_udf(function),
output_type,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
Expand All @@ -607,7 +607,7 @@ impl Expr {

Expr::AnonymousFunction {
input: vec![self],
function: SpecialEq::new(Arc::new(f)),
function: new_column_udf(f),
output_type,
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyList,
Expand All @@ -631,7 +631,7 @@ impl Expr {

Expr::AnonymousFunction {
input: vec![self],
function: SpecialEq::new(Arc::new(f)),
function: new_column_udf(f),
output_type,
options,
}
Expand All @@ -654,7 +654,7 @@ impl Expr {

Expr::AnonymousFunction {
input: vec![self],
function: SpecialEq::new(Arc::new(f)),
function: new_column_udf(f),
output_type,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
Expand Down Expand Up @@ -687,7 +687,7 @@ impl Expr {

Expr::AnonymousFunction {
input,
function: SpecialEq::new(Arc::new(function)),
function: new_column_udf(function),
output_type,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
Expand Down Expand Up @@ -1983,7 +1983,7 @@ where

Expr::AnonymousFunction {
input,
function: SpecialEq::new(Arc::new(function)),
function: new_column_udf(function),
output_type,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
Expand All @@ -2009,7 +2009,7 @@ where

Expr::AnonymousFunction {
input,
function: SpecialEq::new(Arc::new(function)),
function: new_column_udf(function),
output_type,
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyList,
Expand Down Expand Up @@ -2047,7 +2047,7 @@ where

Expr::AnonymousFunction {
input,
function: SpecialEq::new(Arc::new(function)),
function: new_column_udf(function),
output_type,
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
Expand Down
16 changes: 1 addition & 15 deletions crates/polars-plan/src/dsl/python_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,20 +214,6 @@ impl ColumnsUdf for PythonUdfExpression {
Ok(())
})
}

fn get_output(&self) -> Option<GetOutput> {
let output_type = self.output_type.clone();
Some(GetOutput::map_field(move |fld| {
Ok(match output_type {
Some(ref dt) => Field::new(fld.name().clone(), dt.clone()),
None => {
let mut fld = fld.clone();
fld.coerce(DataType::Unknown(Default::default()));
fld
},
})
}))
}
}

/// Serializable version of [`GetOutput`] for Python UDFs.
Expand Down Expand Up @@ -301,7 +287,7 @@ impl Expr {

Expr::AnonymousFunction {
input: vec![self],
function: SpecialEq::new(Arc::new(func)),
function: new_column_udf(func),
output_type,
options: FunctionOptions {
collect_groups,
Expand Down
10 changes: 4 additions & 6 deletions crates/polars-plan/src/dsl/udf.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use std::sync::Arc;

use arrow::legacy::error::{polars_bail, PolarsResult};
use polars_core::prelude::Field;
use polars_core::schema::Schema;
use polars_utils::pl_str::PlSmallStr;

use super::{ColumnsUdf, Expr, GetOutput, SpecialEq};
use crate::prelude::{Context, FunctionOptions};
use super::{ColumnsUdf, Expr, GetOutput, OpaqueColumnUdf};
use crate::prelude::{new_column_udf, Context, FunctionOptions};

/// Represents a user-defined function
#[derive(Clone)]
Expand All @@ -18,7 +16,7 @@ pub struct UserDefinedFunction {
/// The function output type.
pub return_type: GetOutput,
/// The function implementation.
pub fun: SpecialEq<Arc<dyn ColumnsUdf>>,
pub fun: OpaqueColumnUdf,
/// Options for the function.
pub options: FunctionOptions,
}
Expand Down Expand Up @@ -46,7 +44,7 @@ impl UserDefinedFunction {
name,
input_fields,
return_type,
fun: SpecialEq::new(Arc::new(fun)),
fun: new_column_udf(fun),
options: FunctionOptions::default(),
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/aexpr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ pub enum AExpr {
},
AnonymousFunction {
input: Vec<ExprIR>,
function: SpecialEq<Arc<dyn ColumnsUdf>>,
function: OpaqueColumnUdf,
output_type: GetOutput,
options: FunctionOptions,
},
Expand Down
3 changes: 0 additions & 3 deletions crates/polars-plan/src/plans/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,11 @@ impl AExpr {
AnonymousFunction {
output_type,
input,
function,
options,
..
} => {
*nested = nested
.saturating_sub(options.flags.contains(FunctionFlags::RETURNS_SCALAR) as _);
let tmp = function.get_output();
let output_type = tmp.as_ref().unwrap_or(output_type);
let fields = func_args_to_fields(input, schema, arena, nested)?;
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str);
output_type.get_field(schema, Context::Default, &fields)
Expand Down

0 comments on commit 35946cf

Please sign in to comment.