Skip to content

Commit

Permalink
Fix order-preservation in cudf-polars groupby
Browse files Browse the repository at this point in the history
When we are requested to maintain order in groupby aggregations we
must post-process the result by computing a permutation between the
wanted order (of the input keys) and the order returned by the groupby
aggregation. To do this, we can perform a join between the two unique
key tables. Previously, we assumed that the gather map returned in
this join for the left (wanted order) table was the identity. However,
this is not guaranteed; in addition to computing the match between the
wanted key order and the key order we have, we must also apply the
permutation between the left gather map order and the identity.

- Closes #16893
  • Loading branch information
wence- committed Sep 27, 2024
1 parent 0632538 commit 1199246
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
31 changes: 23 additions & 8 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,24 +603,39 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
req.evaluate(result_subs, mapping=mapping) for req in self.agg_requests
]
broadcasted = broadcast(*result_keys, *results)
result_keys = broadcasted[: len(result_keys)]
results = broadcasted[len(result_keys) :]
# Handle order preservation of groups
# like cudf classic does
# https://github.com/rapidsai/cudf/blob/5780c4d8fb5afac2e04988a2ff5531f94c22d3a3/python/cudf/cudf/core/groupby/groupby.py#L723-L743
if self.maintain_order and not sorted:
left = plc.stream_compaction.stable_distinct(
# The order we want
want = plc.stream_compaction.stable_distinct(
plc.Table([k.obj for k in keys]),
list(range(group_keys.num_columns())),
plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
plc.types.NullEquality.EQUAL,
plc.types.NanEquality.ALL_EQUAL,
)
right = plc.Table([key.obj for key in result_keys])
_, indices = plc.join.left_join(left, right, plc.types.NullEquality.EQUAL)
# The order we have
have = plc.Table([key.obj for key in broadcasted[: len(keys)]])

# We know an inner join is OK because by construction
# want and have are permutations of each other.
left_order, right_order = plc.join.inner_join(
want, have, plc.types.NullEquality.EQUAL
)
# Now left_order is an arbitrary permutation of the ordering we
# want, and right_order is a matching permutation of the ordering
# we have. To get to the original ordering, we need
# left_order == iota(nrows), with right_order permuted
# appropriately. This can be obtained by sorting
# right_order by left_order.
(right_order,) = plc.sorting.sort_by_key(
plc.Table([right_order]),
plc.Table([left_order]),
[plc.types.Order.ASCENDING],
[plc.types.NullOrder.AFTER],
).columns()
ordered_table = plc.copying.gather(
plc.Table([col.obj for col in broadcasted]),
indices,
right_order,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
)
broadcasted = [
Expand Down
22 changes: 22 additions & 0 deletions python/cudf_polars/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import itertools

import numpy as np
import pytest

import polars as pl
Expand Down Expand Up @@ -191,3 +192,24 @@ def test_groupby_literal_in_agg(df, key, expr):
def test_groupby_unary_non_pointwise_raises(df, expr):
    # Unary non-pointwise expressions inside a groupby agg are not
    # supported by cudf-polars; translating the query to IR should
    # raise NotImplementedError (and hence fall back on CPU).
    query = df.group_by("key1").agg(expr)
    assert_ir_translation_raises(query, NotImplementedError)


@pytest.mark.parametrize("nrows", [30, 300, 300_000])
@pytest.mark.parametrize("nkeys", [1, 2, 4])
def test_groupby_maintain_order_random(nrows, nkeys, with_nulls):
    """Compare maintain_order=True groupby on the GPU against polars CPU.

    Builds a frame of random integer keys and values, optionally nulls
    out one key value, and checks that an order-preserving groupby sum
    matches the polars result.
    """
    # Seed the generator (keyed on the parametrization) so that any
    # failure is reproducible across runs; the legacy unseeded
    # np.random.* API made this test nondeterministic.
    rng = np.random.default_rng(seed=[nrows, nkeys])
    key_names = [f"key{key}" for key in range(nkeys)]
    key_values = [rng.integers(100, size=nrows) for _ in key_names]
    value = rng.integers(-100, 100, size=nrows)
    df = pl.DataFrame(dict(zip(key_names, key_values, strict=True), value=value))
    if with_nulls:
        # Replace the key value 1 with null in every key column so the
        # null-handling path of the order restoration is exercised too.
        df = df.with_columns(
            *(
                pl.when(pl.col(name) == 1)
                .then(None)
                .otherwise(pl.col(name))
                .alias(name)
                for name in key_names
            )
        )
    q = df.lazy().group_by(key_names, maintain_order=True).agg(pl.col("value").sum())
    assert_gpu_result_equal(q)

0 comments on commit 1199246

Please sign in to comment.