Sort columns (#236)
frosforever authored and imarios committed Feb 9, 2018
1 parent 576eb67 commit ba9abbe
Showing 3 changed files with 273 additions and 1 deletion.
37 changes: 37 additions & 0 deletions dataset/src/main/scala/frameless/TypedColumn.scala
@@ -305,6 +305,20 @@ abstract class AbstractTypedColumn[T, U]
*/
def /(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, Double] = typed(self.untyped.divide(u))

/** Returns a descending ordering used in sorting.
*
* apache/spark
*/
def desc(implicit catalystOrdered: CatalystOrdered[U]): SortedTypedColumn[T, U] =
new SortedTypedColumn[T, U](untyped.desc)

/** Returns an ascending ordering used in sorting.
*
* apache/spark
*/
def asc(implicit catalystOrdered: CatalystOrdered[U]): SortedTypedColumn[T, U] =
new SortedTypedColumn[T, U](untyped.asc)

/**
* Bitwise AND this expression and another expression.
* {{{
@@ -602,6 +616,29 @@ abstract class AbstractTypedColumn[T, U]
}


sealed class SortedTypedColumn[T, U](val expr: Expression)(
implicit
val uencoder: TypedEncoder[U]
) extends UntypedExpression[T] {

def this(column: Column)(implicit e: TypedEncoder[U]) =
  this(FramelessInternals.expr(column))

def untyped: Column = new Column(expr)
}

object SortedTypedColumn {
implicit def defaultAscending[T, U : CatalystOrdered](typedColumn: TypedColumn[T, U]): SortedTypedColumn[T, U] =
new SortedTypedColumn[T, U](typedColumn.untyped.asc)(typedColumn.uencoder)

object defaultAscendingPoly extends Poly1 {
implicit def caseTypedColumn[T, U : CatalystOrdered] = at[TypedColumn[T, U]](c => defaultAscending(c))
implicit def caseTypeSortedColumn[T, U] = at[SortedTypedColumn[T, U]](identity)
}
}


object TypedColumn {
/**
* Evidence that type `T` has column `K` with type `V`.
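Taken together, these additions let a TypedColumn be lifted into a SortedTypedColumn either explicitly, via .desc / .asc (available only when a CatalystOrdered instance exists for the column's type), or implicitly through SortedTypedColumn.defaultAscending, which falls back to ascending order. A minimal usage sketch, assuming a hypothetical Person case class plus the implicit SparkSession frameless needs in scope:

case class Person(name: String, age: Int) // hypothetical schema, for illustration only

val people: TypedDataset[Person] =
  TypedDataset.create(Seq(Person("alice", 34), Person("bob", 29)))

// Explicit ordering: .desc compiles because CatalystOrdered[Int] exists.
val byAgeDesc: SortedTypedColumn[Person, Int] = people('age).desc

// No explicit ordering: the defaultAscending implicit conversion applies.
val byName: SortedTypedColumn[Person, String] = people('name)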
76 changes: 75 additions & 1 deletion dataset/src/main/scala/frameless/TypedDataset.scala
@@ -11,7 +11,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter}
import org.apache.spark.sql.types.StructType
import shapeless._
import shapeless.labelled.FieldType
import shapeless.ops.hlist.{Diff, IsHCons, Prepend, ToTraversable, Tupler}
import shapeless.ops.hlist.{Diff, IsHCons, Mapper, Prepend, ToTraversable, Tupler}
import shapeless.ops.record.{Keys, Remover, Values}

/** [[TypedDataset]] is a safer interface for working with `Dataset`.
@@ -710,6 +710,80 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
}
}

/** Sort each partition in the dataset using the column selected. */
def sortWithinPartitions[A: CatalystOrdered](ca: SortedTypedColumn[T, A]): TypedDataset[T] =
sortWithinPartitionsMany(ca)

/** Sort each partition in the dataset using the columns selected. */
def sortWithinPartitions[A: CatalystOrdered, B: CatalystOrdered](
ca: SortedTypedColumn[T, A],
cb: SortedTypedColumn[T, B]
): TypedDataset[T] = sortWithinPartitionsMany(ca, cb)

/** Sort each partition in the dataset using the columns selected. */
def sortWithinPartitions[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered](
ca: SortedTypedColumn[T, A],
cb: SortedTypedColumn[T, B],
cc: SortedTypedColumn[T, C]
): TypedDataset[T] = sortWithinPartitionsMany(ca, cb, cc)

/** Sort each partition in the dataset by the given column expressions.
* Default sort order is ascending.
* {{{
* d.sortWithinPartitionsMany(d('a), d('b).desc, d('c).asc)
* }}}
*/
object sortWithinPartitionsMany extends ProductArgs {
def applyProduct[U <: HList, O <: HList](columns: U)
(implicit
i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O],
i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]]
): TypedDataset[T] = {
val sorted = dataset.toDF()
.sortWithinPartitions(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped):_*)
.as[T](TypedExpressionEncoder[T])

TypedDataset.create[T](sorted)
}
}

/** Orders the TypedDataset using the column selected. */
def orderBy[A: CatalystOrdered](ca: SortedTypedColumn[T, A]): TypedDataset[T] =
orderByMany(ca)

/** Orders the TypedDataset using the columns selected. */
def orderBy[A: CatalystOrdered, B: CatalystOrdered](
ca: SortedTypedColumn[T, A],
cb: SortedTypedColumn[T, B]
): TypedDataset[T] = orderByMany(ca, cb)

/** Orders the TypedDataset using the columns selected. */
def orderBy[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered](
ca: SortedTypedColumn[T, A],
cb: SortedTypedColumn[T, B],
cc: SortedTypedColumn[T, C]
): TypedDataset[T] = orderByMany(ca, cb, cc)

/** Sort the dataset by any number of column expressions.
* Default sort order is ascending.
* {{{
* d.orderByMany(d('a), d('b).desc, d('c).asc)
* }}}
*/
object orderByMany extends ProductArgs {
def applyProduct[U <: HList, O <: HList](columns: U)
(implicit
i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O],
i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]]
): TypedDataset[T] = {
val sorted = dataset.toDF()
.orderBy(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped):_*)
.as[T](TypedExpressionEncoder[T])

TypedDataset.create[T](sorted)
}
}

/** Returns a new Dataset as a tuple with the specified
* column dropped.
* Does not allow for dropping from a single column TypedDataset
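The new methods mirror Dataset.orderBy and Dataset.sortWithinPartitions: orderBy produces a total ordering over the whole dataset (at the cost of a shuffle), while sortWithinPartitions only orders rows within each partition. Fixed-arity overloads cover one to three columns; the *Many variants accept any number of columns through ProductArgs, and defaultAscendingPoly lets bare TypedColumns be mixed in with an ascending default. A hedged sketch, reusing the hypothetical people dataset from the previous section:

// Total sort: age descending, ties broken by name ascending.
val sorted: TypedDataset[Person] =
  people.orderBy(people('age).desc, people('name).asc)

// Per-partition sort: the bare people('age) column defaults to ascending.
val partitionSorted: TypedDataset[Person] =
  people.sortWithinPartitionsMany(people('age), people('name).desc)

sorted.collect().run() // materializes the ordered Seq[Person] on the driver

This mixing of default and explicit orderings is exercised directly by the tests below.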
161 changes: 161 additions & 0 deletions dataset/src/test/scala/frameless/OrderByTests.scala
@@ -0,0 +1,161 @@
package frameless

import org.scalacheck.Prop
import org.scalacheck.Prop._
import org.scalatest.Matchers
import shapeless.test.illTyped
import org.apache.spark.sql.Column

class OrderByTests extends TypedDatasetSuite with Matchers {
def sortings[A : CatalystOrdered, T]: Seq[(TypedColumn[T, A] => SortedTypedColumn[T, A], Column => Column)] = Seq(
(_.desc, _.desc),
(_.asc, _.asc),
(t => t, t => t) //default ascending
)

test("single column non nullable orderBy") {
def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = {
val ds = TypedDataset.create(data)

sortings[A, X1[A]].map { case (typ, untyp) =>
ds.dataset.orderBy(untyp(ds.dataset.col("a"))).collect().toVector.?=(
ds.orderBy(typ(ds('a))).collect().run().toVector)
}.reduce(_ && _)
}

check(forAll(prop[Int] _))
check(forAll(prop[Boolean] _))
check(forAll(prop[Byte] _))
check(forAll(prop[Short] _))
check(forAll(prop[Long] _))
check(forAll(prop[Float] _))
check(forAll(prop[Double] _))
check(forAll(prop[SQLDate] _))
check(forAll(prop[SQLTimestamp] _))
check(forAll(prop[String] _))
}

test("single column non nullable partition sorting") {
def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = {
val ds = TypedDataset.create(data)

sortings[A, X1[A]].map { case (typ, untyp) =>
ds.dataset.sortWithinPartitions(untyp(ds.dataset.col("a"))).collect().toVector.?=(
ds.sortWithinPartitions(typ(ds('a))).collect().run().toVector)
}.reduce(_ && _)
}

check(forAll(prop[Int] _))
check(forAll(prop[Boolean] _))
check(forAll(prop[Byte] _))
check(forAll(prop[Short] _))
check(forAll(prop[Long] _))
check(forAll(prop[Float] _))
check(forAll(prop[Double] _))
check(forAll(prop[SQLDate] _))
check(forAll(prop[SQLTimestamp] _))
check(forAll(prop[String] _))
}

test("two columns non nullable orderBy") {
def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = {
val ds = TypedDataset.create(data)

sortings[A, X2[A, B]].reverse.zip(sortings[B, X2[A, B]]).map { case ((typA, untypA), (typB, untypB)) =>
val vanillaSpark = ds.dataset.orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))).collect().toVector
vanillaSpark.?=(ds.orderBy(typA(ds('a)), typB(ds('b))).collect().run().toVector).&&(
vanillaSpark ?= ds.orderByMany(typA(ds('a)), typB(ds('b))).collect().run().toVector
)
}.reduce(_ && _)
}

check(forAll(prop[SQLDate, Long] _))
check(forAll(prop[String, Boolean] _))
check(forAll(prop[SQLTimestamp, Long] _))
}

test("two columns non nullable partition sorting") {
def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = {
val ds = TypedDataset.create(data)

sortings[A, X2[A, B]].reverse.zip(sortings[B, X2[A, B]]).map { case ((typA, untypA), (typB, untypB)) =>
val vanillaSpark = ds.dataset.sortWithinPartitions(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))).collect().toVector
vanillaSpark.?=(ds.sortWithinPartitions(typA(ds('a)), typB(ds('b))).collect().run().toVector).&&(
vanillaSpark ?= ds.sortWithinPartitionsMany(typA(ds('a)), typB(ds('b))).collect().run().toVector
)
}.reduce(_ && _)
}

check(forAll(prop[SQLDate, Long] _))
check(forAll(prop[String, Boolean] _))
check(forAll(prop[SQLTimestamp, Long] _))
}

test("three columns non nullable orderBy") {
def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X3[A,B,A]]): Prop = {
val ds = TypedDataset.create(data)

sortings[A, X3[A, B, A]].reverse
.zip(sortings[B, X3[A, B, A]])
.zip(sortings[A, X3[A, B, A]])
.map { case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) =>
val vanillaSpark = ds.dataset
.orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b")), untypA2(ds.dataset.col("c")))
.collect().toVector

vanillaSpark.?=(ds.orderBy(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector).&&(
vanillaSpark ?= ds.orderByMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector
)
}.reduce(_ && _)
}

check(forAll(prop[SQLDate, Long] _))
check(forAll(prop[String, Boolean] _))
check(forAll(prop[SQLTimestamp, Long] _))
}

test("three columns non nullable partition sorting") {
def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X3[A,B,A]]): Prop = {
val ds = TypedDataset.create(data)

sortings[A, X3[A, B, A]].reverse
.zip(sortings[B, X3[A, B, A]])
.zip(sortings[A, X3[A, B, A]])
.map { case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) =>
val vanillaSpark = ds.dataset
.sortWithinPartitions(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b")), untypA2(ds.dataset.col("c")))
.collect().toVector

vanillaSpark.?=(ds.sortWithinPartitions(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector).&&(
vanillaSpark ?= ds.sortWithinPartitionsMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector
)
}.reduce(_ && _)
}

check(forAll(prop[SQLDate, Long] _))
check(forAll(prop[String, Boolean] _))
check(forAll(prop[SQLTimestamp, Long] _))
}

test("sort support for mixed default and explicit ordering") {
def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A, B]]): Prop = {
val ds = TypedDataset.create(data)

ds.dataset.orderBy(ds.dataset.col("a"), ds.dataset.col("b").desc).collect().toVector.?=(
ds.orderByMany(ds('a), ds('b).desc).collect().run().toVector) &&
ds.dataset.sortWithinPartitions(ds.dataset.col("a"), ds.dataset.col("b").desc).collect().toVector.?=(
ds.sortWithinPartitionsMany(ds('a), ds('b).desc).collect().run().toVector)
}

check(forAll(prop[SQLDate, Long] _))
check(forAll(prop[String, Boolean] _))
check(forAll(prop[SQLTimestamp, Long] _))
}

test("fail when selected column is not sortable") {
val d = TypedDataset.create(X2(1, Map(1 -> 2)) :: X2(2, Map(2 -> 2)) :: Nil)
d.orderBy(d('a).desc)
illTyped("""d.orderBy(d('b).desc)""")
illTyped("""d.sortWithinPartitions(d('b).desc)""")
}
}
