Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Chain length methods #4166

Merged
merged 7 commits into from
Apr 9, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions core/src/main/scala-2.12/cats/compat/ChainCompat.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package cats.data

private[data] trait ChainCompat[+A] { _: Chain[A] =>

/**
* The number of elements in this chain, if it can be cheaply computed, -1 otherwise.
* Cheaply usually means: Not requiring a collection traversal.
*/
final def knownSize: Long =
this match {
case Chain.Empty => 0
case Chain.Singleton(_) => 1
case _ => -1
}
}
17 changes: 17 additions & 0 deletions core/src/main/scala-2.13+/cats/data/ChainCompat.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package cats
package data

private[data] trait ChainCompat[+A] { _: Chain[A] =>

/**
* The number of elements in this chain, if it can be cheaply computed, -1 otherwise.
* Cheaply usually means: Not requiring a collection traversal.
*/
final def knownSize: Long =
this match {
case Chain.Empty => 0
case Chain.Singleton(_) => 1
case Chain.Wrap(seq) => seq.knownSize.toLong
case _ => -1
}
}
34 changes: 11 additions & 23 deletions core/src/main/scala/cats/data/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import Chain.{
* O(1) `uncons`, such that walking the sequence via N successive `uncons`
* steps takes O(N).
*/
sealed abstract class Chain[+A] {
sealed abstract class Chain[+A] extends ChainCompat[A] {

/**
* Returns the head and tail of this Chain if non empty, none otherwise. Amortized O(1).
Expand Down Expand Up @@ -565,33 +565,21 @@ sealed abstract class Chain[+A] {
/**
* Returns the number of elements in this structure
*/
final def length: Long = {
// TODO: consider optimizing for `Chain.Wrap` case.
// Some underlying seq may not need enumerating all elements to calculate its size.
val iter = iterator
var i: Long = 0
while (iter.hasNext) { i += 1; iter.next(); }
i
}
final def length: Long =
Copy link
Contributor

@johnynek johnynek Apr 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually I think what we really might want is:

def foldMap[B](fn: A => B)(implicit B: Monoid[B]): B = {
  @annotation.tailrec
  def loop(chains: List[Chain[A]], acc: B): B =
    chains match {
      case Nil => acc
      case h :: tail =>
        h match {
          case Empty => loop(tail, acc)
          case Wrap(seq) => loop(tail, B.combine(B.combineAll(acc, seq.iterator.map(fn))))
          case Singleton(a) => loop(tail, B.combine(acc, fn(a)))
          case Append(l, r) => loop(l :: r :: tail, acc)
       }
  }
  // we need to be careful and test this with a non-commutative monoid to be sure we get the order
  // right compared to `toList.foldMap`
  loop(this :: Nil, B.empty)
}

That is stack safe and leverages associtivity (and we can override the foldMap in the Foldable instance).

With that, you can do: def length: Long = foldMap(Function.const(1L))

Or you could write a custom loop based on the same idea and make it stack safe.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note: using such foldMap for calculating length will undermine the original efforts for Wrap case optimization.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...although the method foldMap itself is nice to have for sure.

Copy link
Contributor

@johnynek johnynek Apr 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I mean you can make 3 specialized methods: combineAll, foldMap, length and each of them can have their own code.

If you just care about length right now, you can just add that one and change the Wrap(seq) line to be loop(tail, acc + seq.length.toLong) and the Singleton(_) line to be loop(tail, acc + 1L)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: is it going to be faster than a stack-safe implementation based on Eval? I know Eval uses closures, but this code also does some additional allocations of List items all the way while recursing. Wdyt?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eval also allocates lists to manage a stack. I think this will be faster since it won't also allocate and call lambdas.

Copy link
Contributor Author

@bplommer bplommer Apr 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option (what did I say about getting sucked down rabbit-holes...)

   private def foldMapCase[B](f: A => B, g: Seq[A] => B)(implicit M: Monoid[B]): B = {
    @tailrec def loop(chains: List[Chain[A]], acc: B): B = chains match {
      case Nil => acc
      case h :: t =>
        h match {
          case Empty        => loop(t, acc)
          case Singleton(a) => loop(t, M.combine(f(a), acc))
          case Wrap(as)     => loop(t, M.combine(g(as), acc))
          case Append(l, r) => loop(l :: r :: t, acc)
        }
    }

    loop(this :: Nil, M.empty)
  }

  final def foldMap[B](f: A => B)(implicit M: Monoid[B]): B = foldMapCase(f, Foldable[Seq].foldMap(_)(f))

  final def length: Long = foldMapCase(_ => 1L, _.length.toLong)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't sweat it. Just copy the code. The inlined copied code will also be more efficient since there is no Long boxing.

We prefer a copy of the code if we can get efficiency wins in library code. We aren't trying to make the internals of cats as beautiful as possible. IMO, we want the API to be beautiful, then we want the library to be stack safe, then we want it to be fast, then we want it to be implemented in a beautiful fashion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thinking more about this, I am not sure we actually want to implement foldMap the way I suggested. Doing so would likely defeat the optimizations of Monoid.combineAll. I think the right way is just Monoid.combineAll(toIterator.map(fn)) which is free to use internal mutability on the monoid. If we implement the way I suggested above, we force the monoid to continue to concatenate (which could make things like string concatenation quadratic vs linear).

this match {
case Empty => 0
case Singleton(_) => 1
case Wrap(seq) => seq.length.toLong

/**
// TODO: consider implementing this case as a stack-safe recursion.
case Append(_, _) => iterator.length.toLong
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will undo the optimization if there is any nesting of Wrap. I think since I've already written the code, it isn't such a heavy lift to copy and paste that version in (with the modifications).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ends up being implemented by counting through the iterator.

}

/*
* Alias for length
*/
final def size: Long = length

/**
* The number of elements in this chain, if it can be cheaply computed, -1 otherwise.
* Cheaply usually means: Not requiring a collection traversal.
*/
final def knownSize: Long =
// TODO: consider optimizing for `Chain.Wrap` case – call the underlying `knownSize` method.
// Note that `knownSize` was introduced since Scala 2.13 only.
this match {
case _ if isEmpty => 0
case Chain.Singleton(_) => 1
case _ => -1
}

/**
* Compares the length of this chain to a test value.
*
Expand Down