diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala
index c494a0c2..e4458e25 100644
--- a/dataset/src/main/scala/frameless/TypedColumn.scala
+++ b/dataset/src/main/scala/frameless/TypedColumn.scala
@@ -305,6 +305,20 @@ abstract class AbstractTypedColumn[T, U]
     */
   def /(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, Double] =
     typed(self.untyped.divide(u))
+  /** Returns a descending ordering used in sorting
+    *
+    * apache/spark
+    */
+  def desc(implicit catalystOrdered: CatalystOrdered[U]): SortedTypedColumn[T, U] =
+    new SortedTypedColumn[T, U](untyped.desc)
+
+  /** Returns an ascending ordering used in sorting
+    *
+    * apache/spark
+    */
+  def asc(implicit catalystOrdered: CatalystOrdered[U]): SortedTypedColumn[T, U] =
+    new SortedTypedColumn[T, U](untyped.asc)
+
   /**
     * Bitwise AND this expression and another expression.
     * {{{
@@ -602,6 +616,29 @@ abstract class AbstractTypedColumn[T, U]
 }
 
 
+sealed class SortedTypedColumn[T, U](val expr: Expression)(
+  implicit
+  val uencoder: TypedEncoder[U]
+) extends UntypedExpression[T] {
+
+  def this(column: Column)(implicit e: TypedEncoder[U]) {
+    this(FramelessInternals.expr(column))
+  }
+
+  def untyped: Column = new Column(expr)
+}
+
+object SortedTypedColumn {
+  implicit def defaultAscending[T, U : CatalystOrdered](typedColumn: TypedColumn[T, U]): SortedTypedColumn[T, U] =
+    new SortedTypedColumn[T, U](typedColumn.untyped.asc)(typedColumn.uencoder)
+
+  object defaultAscendingPoly extends Poly1 {
+    implicit def caseTypedColumn[T, U : CatalystOrdered] = at[TypedColumn[T, U]](c => defaultAscending(c))
+    implicit def caseTypeSortedColumn[T, U] = at[SortedTypedColumn[T, U]](identity)
+  }
+}
+
+
 object TypedColumn {
   /**
     * Evidence that type `T` has column `K` with type `V`.
diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala
index 7b05b0f2..f8e149d0 100644
--- a/dataset/src/main/scala/frameless/TypedDataset.scala
+++ b/dataset/src/main/scala/frameless/TypedDataset.scala
@@ -11,7 +11,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter}
 import org.apache.spark.sql.types.StructType
 import shapeless._
 import shapeless.labelled.FieldType
-import shapeless.ops.hlist.{Diff, IsHCons, Prepend, ToTraversable, Tupler}
+import shapeless.ops.hlist.{Diff, IsHCons, Mapper, Prepend, ToTraversable, Tupler}
 import shapeless.ops.record.{Keys, Remover, Values}
 
 /** [[TypedDataset]] is a safer interface for working with `Dataset`.
@@ -710,6 +710,80 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
     }
   }
 
+  /** Sort each partition in the dataset using the columns selected. */
+  def sortWithinPartitions[A: CatalystOrdered](ca: SortedTypedColumn[T, A]): TypedDataset[T] =
+    sortWithinPartitionsMany(ca)
+
+  /** Sort each partition in the dataset using the columns selected. */
+  def sortWithinPartitions[A: CatalystOrdered, B: CatalystOrdered](
+    ca: SortedTypedColumn[T, A],
+    cb: SortedTypedColumn[T, B]
+  ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb)
+
+  /** Sort each partition in the dataset using the columns selected. */
+  def sortWithinPartitions[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered](
+    ca: SortedTypedColumn[T, A],
+    cb: SortedTypedColumn[T, B],
+    cc: SortedTypedColumn[T, C]
+  ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb, cc)
+
+  /** Sort each partition in the dataset by the given column expressions
+    * Default sort order is ascending.
+    * {{{
+    *   d.sortWithinPartitionsMany(d('a), d('b).desc, d('c).asc)
+    * }}}
+    */
+  object sortWithinPartitionsMany extends ProductArgs {
+    def applyProduct[U <: HList, O <: HList](columns: U)
+      (implicit
+        i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O],
+        i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]]
+      ): TypedDataset[T] = {
+      val sorted = dataset.toDF()
+        .sortWithinPartitions(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped):_*)
+        .as[T](TypedExpressionEncoder[T])
+
+      TypedDataset.create[T](sorted)
+    }
+  }
+
+  /** Orders the TypedDataset using the column selected. */
+  def orderBy[A: CatalystOrdered](ca: SortedTypedColumn[T, A]): TypedDataset[T] =
+    orderByMany(ca)
+
+  /** Orders the TypedDataset using the columns selected. */
+  def orderBy[A: CatalystOrdered, B: CatalystOrdered](
+    ca: SortedTypedColumn[T, A],
+    cb: SortedTypedColumn[T, B]
+  ): TypedDataset[T] = orderByMany(ca, cb)
+
+  /** Orders the TypedDataset using the columns selected. */
+  def orderBy[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered](
+    ca: SortedTypedColumn[T, A],
+    cb: SortedTypedColumn[T, B],
+    cc: SortedTypedColumn[T, C]
+  ): TypedDataset[T] = orderByMany(ca, cb, cc)
+
+  /** Sort the dataset by any number of column expressions.
+    * Default sort order is ascending.
+    * {{{
+    *   d.orderByMany(d('a), d('b).desc, d('c).asc)
+    * }}}
+    */
+  object orderByMany extends ProductArgs {
+    def applyProduct[U <: HList, O <: HList](columns: U)
+      (implicit
+        i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O],
+        i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]]
+      ): TypedDataset[T] = {
+      val sorted = dataset.toDF()
+        .orderBy(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped):_*)
+        .as[T](TypedExpressionEncoder[T])
+
+      TypedDataset.create[T](sorted)
+    }
+  }
+
   /** Returns a new Dataset as a tuple with the specified
     * column dropped.
     * Does not allow for dropping from a single column TypedDataset
diff --git a/dataset/src/test/scala/frameless/OrderByTests.scala b/dataset/src/test/scala/frameless/OrderByTests.scala
new file mode 100644
index 00000000..839bd6f2
--- /dev/null
+++ b/dataset/src/test/scala/frameless/OrderByTests.scala
@@ -0,0 +1,161 @@
+package frameless
+
+import org.scalacheck.Prop
+import org.scalacheck.Prop._
+import org.scalatest.Matchers
+import shapeless.test.illTyped
+import org.apache.spark.sql.Column
+
+class OrderByTests extends TypedDatasetSuite with Matchers {
+  def sortings[A : CatalystOrdered, T]: Seq[(TypedColumn[T, A] => SortedTypedColumn[T, A], Column => Column)] = Seq(
+    (_.desc, _.desc),
+    (_.asc, _.asc),
+    (t => t, t => t) //default ascending
+  )
+
+  test("single column non nullable orderBy") {
+    def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      sortings[A, X1[A]].map { case (typ, untyp) =>
+        ds.dataset.orderBy(untyp(ds.dataset.col("a"))).collect().toVector.?=(
+          ds.orderBy(typ(ds('a))).collect().run().toVector)
+      }.reduce(_ && _)
+    }
+
+    check(forAll(prop[Int] _))
+    check(forAll(prop[Boolean] _))
+    check(forAll(prop[Byte] _))
+    check(forAll(prop[Short] _))
+    check(forAll(prop[Long] _))
+    check(forAll(prop[Float] _))
+    check(forAll(prop[Double] _))
+    check(forAll(prop[SQLDate] _))
+    check(forAll(prop[SQLTimestamp] _))
+    check(forAll(prop[String] _))
+  }
+
+  test("single column non nullable partition sorting") {
+    def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      sortings[A, X1[A]].map { case (typ, untyp) =>
+        ds.dataset.sortWithinPartitions(untyp(ds.dataset.col("a"))).collect().toVector.?=(
+          ds.sortWithinPartitions(typ(ds('a))).collect().run().toVector)
+      }.reduce(_ && _)
+    }
+
+    check(forAll(prop[Int] _))
+    check(forAll(prop[Boolean] _))
+    check(forAll(prop[Byte] _))
+    check(forAll(prop[Short] _))
+    check(forAll(prop[Long] _))
+    check(forAll(prop[Float] _))
+    check(forAll(prop[Double] _))
+    check(forAll(prop[SQLDate] _))
+    check(forAll(prop[SQLTimestamp] _))
+    check(forAll(prop[String] _))
+  }
+
+  test("two columns non nullable orderBy") {
+    def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      sortings[A, X2[A, B]].reverse.zip(sortings[B, X2[A, B]]).map { case ((typA, untypA), (typB, untypB)) =>
+        val vanillaSpark = ds.dataset.orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))).collect().toVector
+        vanillaSpark.?=(ds.orderBy(typA(ds('a)), typB(ds('b))).collect().run().toVector).&&(
+          vanillaSpark ?= ds.orderByMany(typA(ds('a)), typB(ds('b))).collect().run().toVector
+        )
+      }.reduce(_ && _)
+    }
+
+    check(forAll(prop[SQLDate, Long] _))
+    check(forAll(prop[String, Boolean] _))
+    check(forAll(prop[SQLTimestamp, Long] _))
+  }
+
+  test("two columns non nullable partition sorting") {
+    def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      sortings[A, X2[A, B]].reverse.zip(sortings[B, X2[A, B]]).map { case ((typA, untypA), (typB, untypB)) =>
+        val vanillaSpark = ds.dataset.sortWithinPartitions(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))).collect().toVector
+        vanillaSpark.?=(ds.sortWithinPartitions(typA(ds('a)), typB(ds('b))).collect().run().toVector).&&(
+          vanillaSpark ?= ds.sortWithinPartitionsMany(typA(ds('a)), typB(ds('b))).collect().run().toVector
+        )
+      }.reduce(_ && _)
+    }
+
+    check(forAll(prop[SQLDate, Long] _))
+    check(forAll(prop[String, Boolean] _))
+    check(forAll(prop[SQLTimestamp, Long] _))
+  }
+
+  test("three columns non nullable orderBy") {
+    def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X3[A,B,A]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      sortings[A, X3[A, B, A]].reverse
+        .zip(sortings[B, X3[A, B, A]])
+        .zip(sortings[A, X3[A, B, A]])
+        .map { case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) =>
+          val vanillaSpark = ds.dataset
+            .orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b")), untypA2(ds.dataset.col("c")))
+            .collect().toVector
+
+          vanillaSpark.?=(ds.orderBy(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector).&&(
+            vanillaSpark ?= ds.orderByMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector
+          )
+        }.reduce(_ && _)
+    }
+
+    check(forAll(prop[SQLDate, Long] _))
+    check(forAll(prop[String, Boolean] _))
+    check(forAll(prop[SQLTimestamp, Long] _))
+  }
+
+  test("three columns non nullable partition sorting") {
+    def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X3[A,B,A]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      sortings[A, X3[A, B, A]].reverse
+        .zip(sortings[B, X3[A, B, A]])
+        .zip(sortings[A, X3[A, B, A]])
+        .map { case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) =>
+          val vanillaSpark = ds.dataset
+            .sortWithinPartitions(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b")), untypA2(ds.dataset.col("c")))
+            .collect().toVector
+
+          vanillaSpark.?=(ds.sortWithinPartitions(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector).&&(
+            vanillaSpark ?= ds.sortWithinPartitionsMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector
+          )
+        }.reduce(_ && _)
+    }
+
+    check(forAll(prop[SQLDate, Long] _))
+    check(forAll(prop[String, Boolean] _))
+    check(forAll(prop[SQLTimestamp, Long] _))
+  }
+
+  test("sort support for mixed default and explicit ordering") {
+    def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A, B]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      ds.dataset.orderBy(ds.dataset.col("a"), ds.dataset.col("b").desc).collect().toVector.?=(
+        ds.orderByMany(ds('a), ds('b).desc).collect().run().toVector) &&
+      ds.dataset.sortWithinPartitions(ds.dataset.col("a"), ds.dataset.col("b").desc).collect().toVector.?=(
+        ds.sortWithinPartitionsMany(ds('a), ds('b).desc).collect().run().toVector)
+    }
+
+    check(forAll(prop[SQLDate, Long] _))
+    check(forAll(prop[String, Boolean] _))
+    check(forAll(prop[SQLTimestamp, Long] _))
+  }
+
+  test("fail when selected column is not sortable") {
+    val d = TypedDataset.create(X2(1, Map(1 -> 2)) :: X2(2, Map(2 -> 2)) :: Nil)
+    d.orderBy(d('a).desc)
+    illTyped("""d.orderBy(d('b).desc)""")
+    illTyped("""d.sortWithinPartitions(d('b).desc)""")
+  }
+}
\ No newline at end of file
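
For reviewers, a minimal usage sketch of the API this patch introduces. Only the orderBy, orderByMany, sortWithinPartitions, asc, and desc calls come from the diff above; the Person case class, the SparkSession/SQLContext wiring, and the object name are illustrative assumptions, and the exact implicits required by TypedDataset.create may differ between frameless versions.

import org.apache.spark.sql.{SQLContext, SparkSession}
import frameless.TypedDataset

object OrderByUsageExample extends App {
  // Illustrative setup only; any existing SparkSession works.
  val spark = SparkSession.builder().master("local[*]").appName("orderBy-example").getOrCreate()
  // Assumption: this frameless version takes an implicit SQLContext for TypedDataset.create.
  implicit val sqlContext: SQLContext = spark.sqlContext

  // Hypothetical domain type, not part of the patch.
  case class Person(name: String, age: Int)

  val people = TypedDataset.create(Seq(Person("Ada", 36), Person("Bob", 29), Person("Eve", 41)))

  // Explicit per-column orderings; desc/asc require a CatalystOrdered instance at compile time.
  val byAgeDesc     = people.orderBy(people('age).desc)
  val byNameThenAge = people.orderBy(people('name).asc, people('age).desc)

  // Plain columns default to ascending via the implicit SortedTypedColumn conversion.
  val mixedOrdering = people.orderByMany(people('name), people('age).desc)

  // Per-partition sorting mirrors Dataset.sortWithinPartitions.
  val partitionSorted = people.sortWithinPartitions(people('age).asc)

  // A column whose type has no CatalystOrdered instance (e.g. a Map) is rejected at compile time,
  // as exercised by the illTyped checks in OrderByTests.
}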