diff --git a/compiler/src/dotty/tools/dotc/cc/Capability.scala b/compiler/src/dotty/tools/dotc/cc/Capability.scala index 5c1de33aea0e..7b4b608ca012 100644 --- a/compiler/src/dotty/tools/dotc/cc/Capability.scala +++ b/compiler/src/dotty/tools/dotc/cc/Capability.scala @@ -482,27 +482,30 @@ object Capabilities: case info: OrType => viaInfo(info.tp1)(test) && viaInfo(info.tp2)(test) case _ => false + def trySubpath(y: TermRef, trySamePrefix: Boolean = true): Boolean = + y.prefix.match + case ypre: Capability => + this.subsumes(ypre) + || trySamePrefix + && this.match + case x @ TermRef(xpre: Capability, _) if x.symbol == y.symbol => + // To show `{x.f} <:< {y.f}`, it is important to prove `x` and `y` + // are equvalent, which means `x =:= y` in terms of subtyping, + // not just `{x} =:= {y}` in terms of subcapturing. + // It is possible to construct two singleton types `x` and `y`, + // which subsume each other, but are not equal references. + // See `tests/neg-custom-args/captures/path-prefix.scala` for example. + withMode(Mode.IgnoreCaptures): + TypeComparer.isSameRef(xpre, ypre) + case _ => + false + case _ => false + try (this eq y) || maxSubsumes(y, canAddHidden = !vs.isOpen) || y.match case y: TermRef => - y.prefix.match - case ypre: Capability => - this.subsumes(ypre) - || this.match - case x @ TermRef(xpre: Capability, _) if x.symbol == y.symbol => - // To show `{x.f} <:< {y.f}`, it is important to prove `x` and `y` - // are equvalent, which means `x =:= y` in terms of subtyping, - // not just `{x} =:= {y}` in terms of subcapturing. - // It is possible to construct two singleton types `x` and `y`, - // which subsume each other, but are not equal references. - // See `tests/neg-custom-args/captures/path-prefix.scala` for example. - withMode(Mode.IgnoreCaptures): - TypeComparer.isSameRef(xpre, ypre) - case _ => - false - case _ => false - || viaInfo(y.info)(subsumingRefs(this, _)) + trySubpath(y) || viaInfo(y.info)(subsumingRefs(this, _)) case Maybe(y1) => this.stripMaybe.subsumes(y1) case ReadOnly(y1) => this.stripReadOnly.subsumes(y1) case y: TypeRef if y.derivesFrom(defn.Caps_CapSet) => @@ -516,6 +519,12 @@ object Capabilities: this.subsumes(hi) case _ => y.captureSetOfInfo.elems.forall(this.subsumes) + case Reach(y1: TermRef) => + def isClassFunctionParam: Boolean = + def isClassParam = y1.symbol.is(ParamAccessor) + def isFunctionType = defn.isFunctionType(y1.widenDealias) + isClassParam && isFunctionType + isClassFunctionParam && trySubpath(y1, trySamePrefix = false) case _ => false || this.match case Reach(x1) => x1.subsumes(y.stripReach) diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index dccbd0a005d7..e6103d5af391 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -766,21 +766,7 @@ class CheckCaptures extends Recheck, SymTransformer: val appType = resultToFresh( super.recheckApplication(tree, qualType, funType, argTypes), Origin.ResultInstance(funType, tree.symbol)) - val qualCaptures = qualType.captureSet - val argCaptures = - for (argType, formal) <- argTypes.lazyZip(funType.paramInfos) yield - if formal.hasAnnotation(defn.UseAnnot) then argType.deepCaptureSet else argType.captureSet - appType match - case appType @ CapturingType(appType1, refs) - if qualType.exists - && !tree.fun.symbol.isConstructor - && qualCaptures.mightSubcapture(refs) - && argCaptures.forall(_.mightSubcapture(refs)) => - val callCaptures = argCaptures.foldLeft(qualCaptures)(_ ++ _) - appType.derivedCapturingType(appType1, callCaptures) - .showing(i"narrow $tree: $appType, refs = $refs, qual-cs = ${qualType.captureSet} = $result", capt) - case appType => - appType + appType private def isDistinct(xs: List[Type]): Boolean = xs match case x :: xs1 => xs1.isEmpty || !xs1.contains(x) && isDistinct(xs1) @@ -832,8 +818,12 @@ class CheckCaptures extends Recheck, SymTransformer: for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do val getter = cls.info.member(getterName).suchThat(_.isRefiningParamAccessor).symbol if !getter.is(Private) && getter.hasTrackedParts then - refined = refined.refinedOverride(getterName, argType.unboxed) // Yichen you might want to check this - allCaptures ++= argType.captureSet + refined = refined.refinedOverride(getterName, argType.unboxed) // TODO: This looks unsound. + // Try to find an counter-example. + if defn.isFunctionType(argType.widenDealias) then + allCaptures ++= argType.deepCaptureSet + else + allCaptures ++= argType.captureSet (refined, allCaptures) /** Augment result type of constructor with refinements and captures. diff --git a/scala2-library-cc/src/scala/collection/Iterable.scala b/scala2-library-cc/src/scala/collection/Iterable.scala index c5d10211e3ab..017e0ab3049c 100644 --- a/scala2-library-cc/src/scala/collection/Iterable.scala +++ b/scala2-library-cc/src/scala/collection/Iterable.scala @@ -684,9 +684,9 @@ trait IterableOps[+A, +CC[_], +C] extends Any with IterableOnce[A] with Iterable def map[B](f: A => B): CC[B]^{this, f} = iterableFactory.from(new View.Map(this, f)) - def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} = iterableFactory.from(new View.FlatMap(this, f)) + def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f*} = iterableFactory.from(new View.FlatMap(this, f)) - def flatten[B](implicit asIterable: A -> IterableOnce[B]): CC[B]^{this} = flatMap(asIterable) + def flatten[B](implicit asIterable: A -> IterableOnce[B]): CC[B]^{this, asIterable*} = flatMap(asIterable) def collect[B](pf: PartialFunction[A, B]^): CC[B]^{this, pf} = iterableFactory.from(new View.Collect(this, pf)) @@ -911,7 +911,7 @@ object IterableOps { def map[B](f: A => B): CC[B]^{this, f} = self.iterableFactory.from(new View.Map(filtered, f)) - def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} = + def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f*} = self.iterableFactory.from(new View.FlatMap(filtered, f)) def foreach[U](f: A => U): Unit = filtered.foreach(f) diff --git a/scala2-library-cc/src/scala/collection/IterableOnce.scala b/scala2-library-cc/src/scala/collection/IterableOnce.scala index 7ea62a9e1a65..91b3f8e51f6b 100644 --- a/scala2-library-cc/src/scala/collection/IterableOnce.scala +++ b/scala2-library-cc/src/scala/collection/IterableOnce.scala @@ -246,7 +246,7 @@ final class IterableOnceExtensionMethods[A](private val it: IterableOnce[A]) ext } @deprecated("Use .iterator.flatMap instead or consider requiring an Iterable", "2.13.0") - def flatMap[B](f: A => IterableOnce[B]^): IterableOnce[B]^{f} = it match { + def flatMap[B](f: A => IterableOnce[B]^): IterableOnce[B]^{f*} = it match { case it: Iterable[A] => it.flatMap(f) case _ => it.iterator.flatMap(f) } @@ -439,7 +439,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ => * @return a new $coll resulting from applying the given collection-valued function * `f` to each element of this $coll and concatenating the results. */ - def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} + def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f*} /** Converts this $coll of iterable collections into * a $coll formed by the elements of these iterable diff --git a/scala2-library-cc/src/scala/collection/Iterator.scala b/scala2-library-cc/src/scala/collection/Iterator.scala index 91a22caa288c..08fae1ec8df5 100644 --- a/scala2-library-cc/src/scala/collection/Iterator.scala +++ b/scala2-library-cc/src/scala/collection/Iterator.scala @@ -588,8 +588,8 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite def next() = f(self.next()) } - def flatMap[B](f: A => IterableOnce[B]^): Iterator[B]^{this, f} = new AbstractIterator[B] { - private[this] var cur: Iterator[B]^{f} = Iterator.empty + def flatMap[B](f: A => IterableOnce[B]^): Iterator[B]^{this, f*} = new AbstractIterator[B] { + private[this] var cur: Iterator[B]^{f*} = Iterator.empty /** Trillium logic boolean: -1 = unknown, 0 = false, 1 = true */ private[this] var _hasNext: Int = -1 @@ -623,7 +623,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite } } - def flatten[B](implicit ev: A -> IterableOnce[B]): Iterator[B]^{this} = + def flatten[B](implicit ev: A -> IterableOnce[B]): Iterator[B]^{this, ev*} = flatMap[B](ev) def concat[B >: A](xs: => IterableOnce[B]^): Iterator[B]^{this, xs} = new Iterator.ConcatIterator[B](self).concat(xs) diff --git a/scala2-library-cc/src/scala/collection/View.scala b/scala2-library-cc/src/scala/collection/View.scala index 72a073836e77..fe85ed3ef77b 100644 --- a/scala2-library-cc/src/scala/collection/View.scala +++ b/scala2-library-cc/src/scala/collection/View.scala @@ -57,8 +57,9 @@ object View extends IterableFactory[View] { * * @tparam A View element type */ - def fromIteratorProvider[A](it: () => Iterator[A]^): View[A]^{it} = new AbstractView[A] { - def iterator: Iterator[A]^{it} = it() + def fromIteratorProvider[A](it: () => Iterator[A]^): View[A]^{it*} = new AbstractView[A] { // TODO: this seems clearly unsound: not only `it*` but also `it` is used + // why it capture-checks? + def iterator: Iterator[A]^{it*} = it() } /** @@ -310,7 +311,7 @@ object View extends IterableFactory[View] { /** A view that flatmaps elements of the underlying collection. */ @SerialVersionUID(3L) class FlatMap[A, B](underlying: SomeIterableOps[A]^, f: A => IterableOnce[B]^) extends AbstractView[B] { - def iterator: Iterator[B]^{underlying, f} = underlying.iterator.flatMap(f) + def iterator: Iterator[B]^{underlying, f*} = underlying.iterator.flatMap(f) override def knownSize: Int = if (underlying.knownSize == 0) 0 else super.knownSize override def isEmpty: Boolean = iterator.isEmpty } diff --git a/scala2-library-cc/src/scala/collection/WithFilter.scala b/scala2-library-cc/src/scala/collection/WithFilter.scala index a2255a8cc0c5..0f4a033e0813 100644 --- a/scala2-library-cc/src/scala/collection/WithFilter.scala +++ b/scala2-library-cc/src/scala/collection/WithFilter.scala @@ -45,7 +45,7 @@ abstract class WithFilter[+A, +CC[_]] extends Serializable { * of the filtered outer $coll and * concatenating the results. */ - def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} + def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f*} /** Applies a function `f` to all elements of the `filtered` outer $coll. * diff --git a/scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala b/scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala index 726b011c6929..93fcc26852dd 100644 --- a/scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala +++ b/scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala @@ -592,7 +592,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz */ // optimisations are not for speed, but for functionality // see tickets #153, #498, #2147, and corresponding tests in run/ (as well as run/stream_flatmap_odds.scala) - override def flatMap[B](f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f} = + override def flatMap[B](f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f*} = if (knownIsEmpty) LazyListIterable.empty else LazyListIterable.flatMapImpl(this, f) @@ -600,7 +600,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz * * $preservesLaziness */ - override def flatten[B](implicit asIterable: A -> IterableOnce[B]): LazyListIterable[B]^{this} = flatMap(asIterable) + override def flatten[B](implicit asIterable: A -> IterableOnce[B]): LazyListIterable[B]^{this, asIterable*} = flatMap(asIterable) /** @inheritdoc * @@ -1061,11 +1061,11 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { } } - private def flatMapImpl[A, B](ll: LazyListIterable[A]^, f: A => IterableOnce[B]^): LazyListIterable[B]^{ll, f} = { + private def flatMapImpl[A, B](ll: LazyListIterable[A]^, f: A => IterableOnce[B]^): LazyListIterable[B]^{ll, f*} = { // DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD var restRef: LazyListIterable[A]^{ll} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric newLL { - var it: Iterator[B]^{ll, f} = null + var it: Iterator[B]^{ll, f*} = null var itHasNext = false var rest = restRef // var rest = restRef.elem while (!itHasNext && !rest.isEmpty) { @@ -1185,7 +1185,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { /** Creates a State from an Iterator, with another State appended after the Iterator * is empty. */ - private def stateFromIteratorConcatSuffix[A](it: Iterator[A]^)(suffix: => State[A]^): State[A]^{it, suffix} = + private def stateFromIteratorConcatSuffix[A](it: Iterator[A]^)(suffix: => State[A]^): State[A]^{it, suffix*} = if (it.hasNext) sCons(it.next(), newLL(stateFromIteratorConcatSuffix(it)(suffix))) else suffix @@ -1307,7 +1307,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { extends collection.WithFilter[A, LazyListIterable] { private[this] val filtered = lazyList.filter(p) def map[B](f: A => B): LazyListIterable[B]^{this, f} = filtered.map(f) - def flatMap[B](f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f} = filtered.flatMap(f) + def flatMap[B](f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f*} = filtered.flatMap(f) def foreach[U](f: A => U): Unit = filtered.foreach(f) def withFilter(q: A => Boolean): collection.WithFilter[A, LazyListIterable]^{this, q} = new WithFilter(filtered, q) } @@ -1353,7 +1353,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { final class DeferredState[A] { private[this] var _state: (() => State[A]^) @uncheckedCaptures = _ - def eval(): State[A]^ = { + def eval(): State[A]^{this} = { val state = _state if (state == null) throw new IllegalStateException("uninitialized") state()