Skip to content

Change retains annotation from using term arguments to using type arguments #22909

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2261,9 +2261,9 @@ object desugar {
AppliedTypeTree(ref(defn.SeqType), t),
New(ref(defn.RepeatedAnnot.typeRef), Nil :: Nil))
else if op.name == nme.CC_REACH then
Apply(ref(defn.Caps_reachCapability), t :: Nil)
Annotated(t, New(ref(defn.ReachCapabilityAnnot.typeRef), Nil :: Nil))
else if op.name == nme.CC_READONLY then
Apply(ref(defn.Caps_readOnlyCapability), t :: Nil)
Annotated(t, New(ref(defn.ReadOnlyCapabilityAnnot.typeRef), Nil :: Nil))
else
assert(ctx.mode.isExpr || ctx.reporter.errorsReported || ctx.mode.is(Mode.Interactive), ctx.mode)
Select(t, op.name)
Expand Down
20 changes: 16 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
/** Property key for contextual Apply trees of the form `fn given arg` */
val KindOfApply: Property.StickyKey[ApplyKind] = Property.StickyKey()

val RetainsAnnot: Property.StickyKey[Unit] = Property.StickyKey()

// ------ Creation methods for untyped only -----------------

def Ident(name: Name)(implicit src: SourceFile): Ident = new Ident(name)
Expand Down Expand Up @@ -528,10 +530,17 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
Select(scalaDot(nme.caps), nme.CAPTURE_ROOT)

def makeRetaining(parent: Tree, refs: List[Tree], annotName: TypeName)(using Context): Annotated =
Annotated(parent, New(scalaAnnotationDot(annotName), List(refs)))

def makeCapsOf(tp: RefTree)(using Context): Tree =
TypeApply(capsInternalDot(nme.capsOf), tp :: Nil)
var annot: Tree = scalaAnnotationDot(annotName)
if annotName == tpnme.retainsCap then
annot = New(annot, Nil)
else
val trefs =
if refs.isEmpty then ref(defn.NothingType)
// TODO: choose a reduce direction
else refs.map(SingletonTypeTree).reduce[Tree]((a, b) => makeOrType(a, b))
annot = New(AppliedTypeTree(annot, trefs :: Nil), Nil)
annot.putAttachment(RetainsAnnot, ())
Annotated(parent, annot)

// Capture set variable `[C^]` becomes: `[C >: CapSet <: CapSet^{cap}]`
def makeCapsBound()(using Context): TypeBoundsTree =
Expand Down Expand Up @@ -563,6 +572,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def makeAndType(left: Tree, right: Tree)(using Context): AppliedTypeTree =
AppliedTypeTree(ref(defn.andType.typeRef), left :: right :: Nil)

def makeOrType(left: Tree, right: Tree)(using Context): AppliedTypeTree =
AppliedTypeTree(ref(defn.orType.typeRef), left :: right :: Nil)

def makeParameter(pname: TermName, tpe: Tree, mods: Modifiers, isBackquoted: Boolean = false)(using Context): ValDef = {
val vdef = ValDef(pname, tpe, EmptyTree)
if (isBackquoted) vdef.pushAttachment(Backquoted, ())
Expand Down
17 changes: 8 additions & 9 deletions compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean)(cls: Symbol) exte

/** Reconstitute annotation tree from capture set */
override def tree(using Context) =
val elems = refs.elems.toList.map {
case cr: TermRef => ref(cr)
case cr: TermParamRef => untpd.Ident(cr.paramName).withType(cr)
case cr: ThisType => This(cr.cls)
case root(_) => ref(root.cap)
// TODO: Will crash if the type is an annotated type, for example `cap.rd`
}
val arg = repeated(elems, TypeTree(defn.AnyType))
New(symbol.typeRef, arg :: Nil)
if symbol == defn.RetainsCapAnnot then
New(symbol.typeRef, Nil)
else
val elems = refs.elems.toList
val trefs =
if elems.isEmpty then defn.NothingType
else elems.reduce[Type]((a, b) => OrType(a, b, soft = false))
New(AppliedType(symbol.typeRef, trefs :: Nil), Nil)

override def symbol(using Context) = cls

Expand Down
74 changes: 34 additions & 40 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -190,38 +190,25 @@ def ccState(using Context): CCState =

extension (tree: Tree)

/** Map tree with CaptureRef type to its type,
* map CapSet^{refs} to the `refs` references,
* throw IllegalCaptureRef otherwise
*/
def toCaptureRefs(using Context): List[CaptureRef] = tree match
case ReachCapabilityApply(arg) =>
arg.toCaptureRefs.map(_.reach)
case ReadOnlyCapabilityApply(arg) =>
arg.toCaptureRefs.map(_.readOnly)
case CapsOfApply(arg) =>
arg.toCaptureRefs
case _ => tree.tpe.dealiasKeepAnnots match
case ref: CaptureRef if ref.isTrackableRef =>
ref :: Nil
case AnnotatedType(parent, ann)
if ann.symbol.isRetains && parent.derivesFrom(defn.Caps_CapSet) =>
ann.tree.toCaptureSet.elems.toList
case tpe =>
throw IllegalCaptureRef(tpe) // if this was compiled from cc syntax, problem should have been reported at Typer

/** Convert a @retains or @retainsByName annotation tree to the capture set it represents.
* For efficience, the result is cached as an Attachment on the tree.
*/
def toCaptureSet(using Context): CaptureSet =
tree.getAttachment(Captures) match
case Some(refs) => refs
case None =>
val refs = CaptureSet(tree.retainedElems.flatMap(_.toCaptureRefs)*)
//.showing(i"toCaptureSet $tree --> $result", capt)
val refs = CaptureSet(tree.retainedSet.retainedElements*)
tree.putAttachment(Captures, refs)
refs

def retainedSet(using Context): Type =
tree match
case Apply(TypeApply(_, refs :: Nil), _) => refs.tpe
case _ =>
if tree.symbol.maybeOwner == defn.RetainsCapAnnot
then root.cap
else NoType

/** The arguments of a @retains, @retainsCap or @retainsByName annotation */
def retainedElems(using Context): List[Tree] = tree match
case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) =>
Expand All @@ -233,6 +220,21 @@ extension (tree: Tree)

extension (tp: Type)

def retainedElements(using Context): List[CaptureRef] = tp match
case ReachCapability(tp1) =>
tp1.reach :: Nil
case ReadOnlyCapability(tp1) =>
tp1.readOnly :: Nil
case tp: CaptureRef if tp.isTrackableRef =>
tp :: Nil
case tp: TypeRef if tp.symbol.isType && tp.derivesFrom(defn.Caps_CapSet) =>
tp :: Nil
case OrType(tp1, tp2) =>
tp1.retainedElements ++ tp2.retainedElements
case _ =>
if tp.isNothingType then Nil
else throw IllegalCaptureRef(tp)

/** Is this type a CaptureRef that can be tracked?
* This is true for
* - all ThisTypes and all TermParamRef,
Expand Down Expand Up @@ -655,7 +657,7 @@ extension (cls: ClassSymbol)
|| bc.is(CaptureChecked)
&& bc.givenSelfType.dealiasKeepAnnots.match
case CapturingType(_, refs) => refs.isAlwaysEmpty
case RetainingType(_, refs) => refs.isEmpty
case RetainingType(_, refs) => refs.retainedElements.isEmpty
case selfType =>
isCaptureChecking // At Setup we have not processed self types yet, so
// unless a self type is explicitly given, we can't tell
Expand Down Expand Up @@ -773,7 +775,7 @@ class CleanupRetains(using Context) extends TypeMap:
def apply(tp: Type): Type =
tp match
case AnnotatedType(tp, annot) if annot.symbol == defn.RetainsAnnot || annot.symbol == defn.RetainsByNameAnnot =>
RetainingType(tp, Nil, byName = annot.symbol == defn.RetainsByNameAnnot)
RetainingType(tp, defn.NothingType, byName = annot.symbol == defn.RetainsByNameAnnot)
case _ => mapOver(tp)

/** A typemap that follows aliases and keeps their transformed results if
Expand All @@ -792,26 +794,18 @@ trait FollowAliasesMap(using Context) extends TypeMap:
/** An extractor for `caps.reachCapability(ref)`, which is used to express a reach
* capability as a tree in a @retains annotation.
*/
object ReachCapabilityApply:
def unapply(tree: Apply)(using Context): Option[Tree] = tree match
case Apply(reach, arg :: Nil) if reach.symbol == defn.Caps_reachCapability => Some(arg)
case _ => None
// object ReachCapabilityApply:
// def unapply(tree: Apply)(using Context): Option[Tree] = tree match
// case Apply(reach, arg :: Nil) if reach.symbol == defn.Caps_reachCapability => Some(arg)
// case _ => None

/** An extractor for `caps.readOnlyCapability(ref)`, which is used to express a read-only
* capability as a tree in a @retains annotation.
*/
object ReadOnlyCapabilityApply:
def unapply(tree: Apply)(using Context): Option[Tree] = tree match
case Apply(ro, arg :: Nil) if ro.symbol == defn.Caps_readOnlyCapability => Some(arg)
case _ => None

/** An extractor for `caps.capsOf[X]`, which is used to express a generic capture set
* as a tree in a @retains annotation.
*/
object CapsOfApply:
def unapply(tree: TypeApply)(using Context): Option[Tree] = tree match
case TypeApply(capsOf, arg :: Nil) if capsOf.symbol == defn.Caps_capsOf => Some(arg)
case _ => None
// object ReadOnlyCapabilityApply:
// def unapply(tree: Apply)(using Context): Option[Tree] = tree match
// case Apply(ro, arg :: Nil) if ro.symbol == defn.Caps_readOnlyCapability => Some(arg)
// case _ => None

abstract class AnnotatedCapability(annotCls: Context ?=> ClassSymbol):
def apply(tp: Type)(using Context): AnnotatedType =
Expand Down
25 changes: 12 additions & 13 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,24 @@ object CheckCaptures:
* This check is performed at Typer.
*/
def checkWellformed(parent: Tree, ann: Tree)(using Context): Unit =
def check(elem: Tree, pos: SrcPos): Unit = elem.tpe match
def check(elem: Type, pos: SrcPos): Unit = elem match
case ref: CaptureRef =>
if !ref.isTrackableRef then
report.error(em"$elem cannot be tracked since it is not a parameter or local value", pos)
case tpe =>
report.error(em"$elem: $tpe is not a legal element of a capture set", pos)
for elem <- ann.retainedElems do
for elem <- ann.retainedSet.retainedElements do
elem match
case CapsOfApply(arg) =>
def isLegalCapsOfArg =
arg.symbol.isType && arg.symbol.info.derivesFrom(defn.Caps_CapSet)
if !isLegalCapsOfArg then
report.error(
em"""$arg is not a legal prefix for `^` here,
|is must be a type parameter or abstract type with a caps.CapSet upper bound.""",
elem.srcPos)
case ReachCapabilityApply(arg) => check(arg, elem.srcPos)
case ReadOnlyCapabilityApply(arg) => check(arg, elem.srcPos)
case _ => check(elem, elem.srcPos)
case ref: TypeRef =>
val refSym = ref.symbol
if refSym.isType && !refSym.info.derivesFrom(defn.Caps_CapSet) then
report.error(em"$elem is not a legal element of a capture set", ann.srcPos)
case ReachCapability(ref) =>
check(ref, ann.srcPos)
case ReadOnlyCapability(ref) =>
check(ref, ann.srcPos)
case _ =>
check(elem, ann.srcPos)

/** Under the sealed policy, report an error if some part of `tp` contains the
* root capability in its capture set or if it refers to a type parameter that
Expand Down
12 changes: 4 additions & 8 deletions compiler/src/dotty/tools/dotc/cc/RetainingType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,20 @@ import Decorators.i
*/
object RetainingType:

def apply(tp: Type, refs: List[Tree], byName: Boolean = false)(using Context): Type =
def apply(tp: Type, typeElems: Type, byName: Boolean = false)(using Context): Type =
val annotCls = if byName then defn.RetainsByNameAnnot else defn.RetainsAnnot
val annotTree =
New(annotCls.typeRef,
Typed(
SeqLiteral(refs, TypeTree(defn.AnyType)),
TypeTree(defn.RepeatedParamClass.typeRef.appliedTo(defn.AnyType))) :: Nil)
val annotTree = New(AppliedType(annotCls.typeRef, typeElems :: Nil), Nil)
AnnotatedType(tp, Annotation(annotTree))

def unapply(tp: AnnotatedType)(using Context): Option[(Type, List[Tree])] =
def unapply(tp: AnnotatedType)(using Context): Option[(Type, Type)] =
val sym = tp.annot.symbol
if sym.isRetainsLike then
tp.annot match
case _: CaptureAnnotation =>
assert(ctx.mode.is(Mode.IgnoreCaptures), s"bad retains $tp at ${ctx.phase}")
None
case ann =>
Some((tp.parent, ann.tree.retainedElems))
Some((tp.parent, ann.tree.retainedSet))
else
None
end RetainingType
29 changes: 12 additions & 17 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
case CapturingType(_, refs) =>
!refs.isAlwaysEmpty
case RetainingType(parent, refs) =>
!refs.isEmpty
!refs.retainedElements.isEmpty
case tp: (TypeRef | AppliedType) =>
val sym = tp.typeSymbol
if sym.isClass
Expand Down Expand Up @@ -856,7 +856,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
&& !refs.isUniversal // if refs is {cap}, an added variable would not change anything
case RetainingType(parent, refs) =>
needsVariable(parent)
&& !refs.tpes.exists:
&& !refs.retainedElements.exists:
case ref: TermRef => ref.isCap
case _ => false
case AnnotatedType(parent, _) =>
Expand Down Expand Up @@ -951,19 +951,13 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
* @param tpt the tree for which an error or warning should be reported
*/
private def checkWellformed(parent: Type, ann: Tree, tpt: Tree)(using Context): Unit =
capt.println(i"checkWF post $parent ${ann.retainedElems} in $tpt")
var retained = ann.retainedElems.toArray
for i <- 0 until retained.length do
val refTree = retained(i)
val refs =
try refTree.toCaptureRefs
catch case ex: IllegalCaptureRef =>
report.error(em"Illegal capture reference: ${ex.getMessage.nn}", refTree.srcPos)
Nil
for ref <- refs do
capt.println(i"checkWF post $parent ${ann.retainedSet} in $tpt")
try
val retainedRefs = ann.retainedSet.retainedElements.toArray
for i <- 0 until retainedRefs.length do
val ref = retainedRefs(i)
def pos =
if refTree.span.exists then refTree.srcPos
else if ann.span.exists then ann.srcPos
if ann.span.exists then ann.srcPos
else tpt.srcPos

def check(others: CaptureSet, dom: Type | CaptureSet): Unit =
Expand All @@ -979,14 +973,15 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:

val others =
for
j <- 0 until retained.length if j != i
r <- retained(j).toCaptureRefs
j <- 0 until retainedRefs.length if j != i
r = retainedRefs(j)
if !r.isRootCapability
yield r
val remaining = CaptureSet(others*)
check(remaining, remaining)
end for
end for
catch case ex: IllegalCaptureRef =>
report.error(em"Illegal capture reference: ${ex.getMessage.nn}", tpt.srcPos)
end checkWellformed

/** Check well formed at post check time. We need to wait until after
Expand Down
6 changes: 1 addition & 5 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class Definitions {
HKTypeLambda(argParamNames :+ "R".toTypeName, argVariances :+ Covariant)(
tl => List.fill(arity + 1)(TypeBounds.empty),
tl => RetainingType(underlyingClass.typeRef.appliedTo(tl.paramRefs),
ref(captureRoot.termRef) :: Nil)
captureRoot.termRef)
))
else
val cls = denot.asClass.classSymbol
Expand Down Expand Up @@ -998,9 +998,6 @@ class Definitions {
@tu lazy val Caps_Capability: ClassSymbol = requiredClass("scala.caps.Capability")
@tu lazy val Caps_CapSet: ClassSymbol = requiredClass("scala.caps.CapSet")
@tu lazy val CapsInternalModule: Symbol = requiredModule("scala.caps.internal")
@tu lazy val Caps_reachCapability: TermSymbol = CapsInternalModule.requiredMethod("reachCapability")
@tu lazy val Caps_readOnlyCapability: TermSymbol = CapsInternalModule.requiredMethod("readOnlyCapability")
@tu lazy val Caps_capsOf: TermSymbol = CapsInternalModule.requiredMethod("capsOf")
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
@tu lazy val Caps_unsafeAssumeSeparate: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumeSeparate")
Expand Down Expand Up @@ -1093,7 +1090,6 @@ class Definitions {
@tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.annotation.retains")
@tu lazy val RetainsCapAnnot: ClassSymbol = requiredClass("scala.annotation.retainsCap")
@tu lazy val RetainsByNameAnnot: ClassSymbol = requiredClass("scala.annotation.retainsByName")
@tu lazy val RetainsArgAnnot: ClassSymbol = requiredClass("scala.annotation.retainsArg")
@tu lazy val PublicInBinaryAnnot: ClassSymbol = requiredClass("scala.annotation.publicInBinary")
@tu lazy val WitnessNamesAnnot: ClassSymbol = requiredClass("scala.annotation.internal.WitnessNames")

Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ object Flags {
/** Tracked modifier for class parameter / a class with some tracked parameters */
val (Tracked @ _, _, Dependent @ _) = newFlags(46, "tracked")

val (CaptureParam @ _, _, _) = newFlags(47, "capture-param")

// ------------ Flags following this one are not pickled ----------------------------------

/** Symbol is not a member of its owner */
Expand Down Expand Up @@ -449,7 +451,7 @@ object Flags {

/** Flags representing source modifiers */
private val CommonSourceModifierFlags: FlagSet =
commonFlags(Private, Protected, Final, Case, Implicit, Given, Override, JavaStatic, Transparent, Erased)
commonFlags(Private, Protected, Final, Case, Implicit, Given, Override, JavaStatic, Transparent, Erased, CaptureParam)

val TypeSourceModifierFlags: FlagSet =
CommonSourceModifierFlags.toTypeFlags | Abstract | Sealed | Opaque | Open
Expand All @@ -469,7 +471,7 @@ object Flags {
val FromStartFlags: FlagSet = commonFlags(
Module, Package, Deferred, Method, Case, Enum, Param, ParamAccessor,
Scala2SpecialFlags, MutableOrOpen, Opaque, Touched, JavaStatic,
OuterOrCovariant, LabelOrContravariant, CaseAccessor, Tracked,
OuterOrCovariant, LabelOrContravariant, CaseAccessor, Tracked, CaptureParam,
Extension, NonMember, Implicit, Given, Permanent, Synthetic, Exported,
SuperParamAliasOrScala2x, Inline, Macro, ConstructorProxy, Invisible)

Expand Down
8 changes: 8 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Mode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ object Mode {
*/
val ImplicitExploration: Mode = newMode(12, "ImplicitExploration")

/** We are currently inside a capture set.
* A term name could be a capture variable, so we need to
* check that it is valid to use as type name.
* Since this mode is only used during annotation typing,
* we can reuse the value of `ImplicitExploration` to save bits.
*/
val InCaptureSet: Mode = ImplicitExploration

/** We are currently unpickling Scala2 info */
val Scala2Unpickling: Mode = newMode(13, "Scala2Unpickling")

Expand Down
Loading
Loading