Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Backport #1265 - Extract all subtype annotations #1269

Merged
merged 2 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 37 additions & 72 deletions core/src/main/scala/shapeless/annotation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ object Annotations {
def apply() = annotations
}

implicit def materialize[A, T, Out <: HList]: Aux[A, T, Out] = macro AnnotationMacros.materializeVariableAnnotations[A, T, Out]
implicit def materialize[A, T, Out <: HList]: Aux[A, T, Out] =
macro AnnotationMacros.materializeVariableAnnotations[A, T, Out]
}

/**
Expand Down Expand Up @@ -172,7 +173,8 @@ object TypeAnnotations {
def apply() = annotations
}

implicit def materialize[A, T, Out <: HList]: Aux[A, T, Out] = macro AnnotationMacros.materializeTypeAnnotations[A, T, Out]
implicit def materialize[A, T, Out <: HList]: Aux[A, T, Out] =
macro AnnotationMacros.materializeTypeAnnotations[A, T, Out]
}

/**
Expand Down Expand Up @@ -226,7 +228,8 @@ object AllAnnotations {
def apply(): Out = annotations
}

implicit def materialize[T, Out <: HList]: Aux[T, Out] = macro AnnotationMacros.materializeAllVariableAnnotations[T, Out]
implicit def materialize[T, Out <: HList]: Aux[T, Out] =
macro AnnotationMacros.materializeAllVariableAnnotations[T, Out]
}

/**
Expand Down Expand Up @@ -280,7 +283,8 @@ object AllTypeAnnotations {
def apply(): Out = annotations
}

implicit def materialize[T, Out <: HList]: Aux[T, Out] = macro AnnotationMacros.materializeAllTypeAnnotations[T, Out]
implicit def materialize[T, Out <: HList]: Aux[T, Out] =
macro AnnotationMacros.materializeAllTypeAnnotations[T, Out]
}

class AnnotationMacros(val c: whitebox.Context) extends CaseClassMacros {
Expand All @@ -290,62 +294,51 @@ class AnnotationMacros(val c: whitebox.Context) extends CaseClassMacros {
def someTpe: Type = typeOf[Some[_]].typeConstructor
def noneTpe: Type = typeOf[None.type]

private def annotation = objectRef[shapeless.Annotation.type]
private def some = objectRef[Some.type]
private def none = objectRef[None.type]
private def someVal = objectRef[Some.type]
private def noneVal = objectRef[None.type]
private def annotationVal = objectRef[shapeless.Annotation.type]

/**
* FIXME Most of the content of this method is cut-n-pasted from generic.scala
*
* @return The AST of the `tpe` constructor.
*/
def construct(tpe: Type): List[Tree] => Tree = {
// FIXME Cut-n-pasted from generic.scala
val sym = tpe.typeSymbol
val isCaseClass = sym.asClass.isCaseClass
def hasNonGenericCompanionMember(name: String): Boolean = {
val mSym = sym.companion.typeSignature.member(TermName(name))
mSym != NoSymbol && !isNonGeneric(mSym)
def hasCompanionMember(name: String) = {
val member = sym.companion.typeSignature.member(TermName(name))
member != NoSymbol && !isNonGeneric(member)
}

if(isCaseClass || hasNonGenericCompanionMember("apply"))
args => q"${companionRef(tpe)}(..$args)"
else
args => q"new $tpe(..$args)"
val useCompanion = sym.asClass.isCaseClass || hasCompanionMember("apply")
if (useCompanion) args => q"${companionRef(tpe)}(..$args)" else args => q"new $tpe(..$args)"
}

private def getAnnotation[A: WeakTypeTag, T: WeakTypeTag]: Option[Tree] = {
val annTpe = weakTypeOf[A]

if (!isProduct(annTpe))
abort(s"$annTpe is not a case class-like type")

val construct0 = construct(annTpe)

val tpe = weakTypeOf[T]

if (!isProduct(annTpe)) abort(s"$annTpe is not a case class-like type")
val constructor = construct(annTpe)
tpe.typeSymbol.annotations.collectFirst {
case ann if ann.tree.tpe =:= annTpe => construct0(ann.tree.children.tail)
case ann if ann.tree.tpe =:= annTpe => constructor(ann.tree.children.tail)
}
}

def materializeAnnotation[A: WeakTypeTag, T: WeakTypeTag]: Tree = {
val annTpe = weakTypeOf[A]
val tpe = weakTypeOf[T]

getAnnotation[A, T] match {
case Some(annTree) => q"$annotation.mkAnnotation[$annTpe, $tpe]($annTree)"
case Some(annTree) => q"$annotationVal.mkAnnotation[$annTpe, $tpe]($annTree)"
case None => abort(s"No $annTpe annotation found on $tpe")
}
}

def materializeAnnotationOptional[A: WeakTypeTag, T: WeakTypeTag]: Tree = {
val optAnnTpe = appliedType(optionTpe, weakTypeOf[A])
val tpe = weakTypeOf[T]

getAnnotation[A, T] match {
case Some(annTree) => q"$annotation.mkAnnotation[$optAnnTpe, $tpe]($some($annTree))"
case None => q"$annotation.mkAnnotation[$optAnnTpe, $tpe]($none)"
case Some(annTree) => q"$annotationVal.mkAnnotation[$optAnnTpe, $tpe]($someVal($annTree))"
case None => q"$annotationVal.mkAnnotation[$optAnnTpe, $tpe]($noneVal)"
}
}

Expand All @@ -354,7 +347,7 @@ class AnnotationMacros(val c: whitebox.Context) extends CaseClassMacros {

def materializeAllVariableAnnotations[T: WeakTypeTag, Out: WeakTypeTag]: Tree =
materializeAllAnnotations[T, Out](typeAnnotation = false)

def materializeTypeAnnotations[A: WeakTypeTag, T: WeakTypeTag, Out: WeakTypeTag]: Tree =
materializeAnnotations[A, T, Out](typeAnnotation = true)

Expand All @@ -364,54 +357,27 @@ class AnnotationMacros(val c: whitebox.Context) extends CaseClassMacros {
@deprecated("Use materializeVariableAnnotations instead", "2.3.6")
def materializeAnnotations[A: WeakTypeTag, T: WeakTypeTag, Out: WeakTypeTag]: Tree =
materializeVariableAnnotations[A, T, Out]

def materializeAnnotations[A: WeakTypeTag, T: WeakTypeTag, Out: WeakTypeTag](typeAnnotation: Boolean): Tree = {
val annTpe = weakTypeOf[A]

if (!isProduct(annTpe))
abort(s"$annTpe is not a case class-like type")

val tpe = weakTypeOf[T]

val annTreeOpts = getAnnotationTreeOptions(tpe, typeAnnotation).map { list =>
list.find(_._1 =:= annTpe).map(_._2)
}

val wrapTpeTrees = annTreeOpts.map {
case Some(annTree) => appliedType(someTpe, annTpe) -> q"_root_.scala.Some($annTree)"
case None => noneTpe -> q"_root_.scala.None"
}

val outTpe = mkHListTpe(wrapTpeTrees.map { case (aTpe, _) => aTpe })
val outTree = wrapTpeTrees.foldRight(q"_root_.shapeless.HNil": Tree) {
case ((_, bound), acc) => pq"_root_.shapeless.::($bound, $acc)"
}

if (typeAnnotation) q"_root_.shapeless.TypeAnnotations.mkAnnotations[$annTpe, $tpe, $outTpe]($outTree)"
else q"_root_.shapeless.Annotations.mkAnnotations[$annTpe, $tpe, $outTpe]($outTree)"
if (!isProduct(annTpe)) abort(s"$annTpe is not a case class-like type")
val (annTypes, annTrees) = getAnnotationTreeOptions(tpe, typeAnnotation).map(_.find(_._1 <:< annTpe) match {
case Some((annTpe, annTree)) => appliedType(someTpe, annTpe) -> q"$someVal($annTree)"
case None => noneTpe -> noneVal
}).unzip
val tc = if (typeAnnotation) objectRef[TypeAnnotations.type] else objectRef[Annotations.type]
q"$tc.mkAnnotations[$annTpe, $tpe, ${mkHListTpe(annTypes)}](${mkHListValue(annTrees)})"
}

def materializeAllAnnotations[T: WeakTypeTag, Out: WeakTypeTag](typeAnnotation: Boolean): Tree = {
val tpe = weakTypeOf[T]
val annTreeOpts = getAnnotationTreeOptions(tpe, typeAnnotation)

val wrapTpeTrees = annTreeOpts.map {
case Nil =>
mkHListTpe(Nil) -> q"(_root_.shapeless.HNil)"
case list =>
mkHListTpe(list.map(_._1)) -> list.foldRight(q"_root_.shapeless.HNil": Tree) {
case ((_, bound), acc) => pq"_root_.shapeless.::($bound, $acc)"
}
}

val outTpe = mkHListTpe(wrapTpeTrees.map { case (aTpe, _) => aTpe })
val outTree = wrapTpeTrees.foldRight(q"_root_.shapeless.HNil": Tree) {
case ((_, bound), acc) =>
pq"_root_.shapeless.::($bound, $acc)"
}

if (typeAnnotation) q"_root_.shapeless.AllTypeAnnotations.mkAnnotations[$tpe, $outTpe]($outTree)"
else q"_root_.shapeless.AllAnnotations.mkAnnotations[$tpe, $outTpe]($outTree)"
val (annTypes, annTrees) = getAnnotationTreeOptions(tpe, typeAnnotation).map { annotations =>
val (types, trees) = annotations.unzip
mkHListTpe(types) -> mkHListValue(trees)
}.unzip
val tc = if (typeAnnotation) objectRef[AllTypeAnnotations.type] else objectRef[AllAnnotations.type]
q"$tc.mkAnnotations[$tpe, ${mkHListTpe(annTypes)}](${mkHListValue(annTrees)})"
}

private def getAnnotationTreeOptions(tpe: Type, typeAnnotation: Boolean): List[List[(Type, Tree)]] = {
Expand Down Expand Up @@ -456,5 +422,4 @@ class AnnotationMacros(val c: whitebox.Context) extends CaseClassMacros {
if (tpe) fromType(s.typeSignature)
else s.annotations
}

}
42 changes: 28 additions & 14 deletions core/src/test/scala/shapeless/annotation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ object AnnotationTestsDefinitions {
case class First() extends saAnnotation
case class Second(i: Int, s: String) extends saAnnotation
case class Third(c: Char) extends saAnnotation
case class Fourth[T](t: T) extends saAnnotation

case class Other() extends saAnnotation
case class Last(b: Boolean) extends saAnnotation
Expand Down Expand Up @@ -58,15 +59,17 @@ object AnnotationTestsDefinitions {
case class CC3(
@First i: Int,
s: String,
@Second(2, "b") @Third('c') ob: Option[Boolean]
@Second(2, "b") @Third('c') ob: Option[Boolean],
@Fourth(4) c: Char
)

case class CC4(
i: Int @First,
s: String,
ob: Option[Boolean] @Second(2, "b") @Third('c')
ob: Option[Boolean] @Second(2, "b") @Third('c'),
c: Char @Fourth(4)
)

type PosInt = Int @First
type Email = String @Third('c')
case class User(age: PosInt, email: Email)
Expand Down Expand Up @@ -97,18 +100,18 @@ class AnnotationTests {
def optionalAnnotation: Unit = {
{
val other = Annotation[Option[Other], CC].apply()
assert(other == Some(Other()))
assert(other.contains(Other()))

val last = Annotation[Option[Last], Something].apply()
assert(last == Some(Last(true)))
assert(last.contains(Last(true)))
}

{
val other: Option[Other] = Annotation[Option[Other], Something].apply()
assert(other == None)
assert(other.isEmpty)

val last: Option[Last] = Annotation[Option[Last], CC].apply()
assert(last == None)
assert(last.isEmpty)
}
}

Expand All @@ -128,6 +131,9 @@ class AnnotationTests {
val second: None.type :: None.type :: Some[Second] :: HNil = Annotations[Second, CC].apply()
assert(second == None :: None :: Some(Second(2, "b")) :: HNil)

val fourth: None.type :: None.type :: None.type :: Some[Fourth[Int]] :: HNil = Annotations[Fourth[_], CC3].apply()
assert(fourth == None :: None :: None :: Some(Fourth(4)) :: HNil)

val unused: None.type :: None.type :: None.type :: HNil = Annotations[Unused, CC].apply()
assert(unused == None :: None :: None :: HNil)

Expand All @@ -145,6 +151,9 @@ class AnnotationTests {
val second = Annotations[Second, CC].apply()
assert(second == None :: None :: Some(Second(2, "b")) :: HNil)

val fourth = Annotations[Fourth[_], CC3].apply()
assert(fourth == None :: None :: None :: Some(Fourth(4)) :: HNil)

val unused = Annotations[Unused, CC].apply()
assert(unused == None :: None :: None :: HNil)

Expand Down Expand Up @@ -172,6 +181,9 @@ class AnnotationTests {
val second: None.type :: None.type :: Some[Second] :: HNil = TypeAnnotations[Second, CC2].apply()
assert(second == None :: None :: Some(Second(2, "b")) :: HNil)

val fourth: None.type :: None.type :: None.type :: Some[Fourth[Int]] :: HNil = TypeAnnotations[Fourth[_], CC4].apply()
assert(fourth == None :: None :: None :: Some(Fourth(4)) :: HNil)

val unused: None.type :: None.type :: None.type :: HNil = TypeAnnotations[Unused, CC2].apply()
assert(unused == None :: None :: None :: HNil)
}
Expand All @@ -183,6 +195,9 @@ class AnnotationTests {
val second = TypeAnnotations[Second, CC2].apply()
assert(second == None :: None :: Some(Second(2, "b")) :: HNil)

val fourth = TypeAnnotations[Fourth[_], CC4].apply()
assert(fourth == None :: None :: None :: Some(Fourth(4)) :: HNil)

val unused = TypeAnnotations[Unused, CC2].apply()
assert(unused == None :: None :: None :: HNil)
}
Expand All @@ -192,14 +207,14 @@ class AnnotationTests {
def invalidTypeAnnotations: Unit = {
illTyped(" TypeAnnotations[Dummy, CC2] ", "could not find implicit value for parameter annotations: .*")
illTyped(" TypeAnnotations[Dummy, Base] ", "could not find implicit value for parameter annotations: .*")
illTyped(" TypeAnnotations[Second, Dummy] ", "could not find implicit value for parameter annotations: .*")
illTyped(" TypeAnnotations[Second, Dummy] ", "could not find implicit value for parameter annotations: .*")
}

@Test
def allAnnotations: Unit = {
val cc = AllAnnotations[CC3].apply()
typed[(First :: HNil) :: HNil :: (Second :: Third :: HNil) :: HNil](cc)
assert(cc == (First() :: HNil) :: HNil :: (Second(2, "b") :: Third('c') :: HNil) :: HNil)
typed[(First :: HNil) :: HNil :: (Second :: Third :: HNil) :: (Fourth[Int] :: HNil) :: HNil](cc)
assert(cc == (First() :: HNil) :: HNil :: (Second(2, "b") :: Third('c') :: HNil) :: (Fourth(4) :: HNil) :: HNil)

val st = AllAnnotations[Base].apply()
typed[(First :: HNil) :: (Second :: Third :: HNil) :: HNil](st)
Expand All @@ -211,12 +226,11 @@ class AnnotationTests {
typed[(First :: HNil) :: (Second :: Third :: HNil) :: HNil](st)

val cc = AllTypeAnnotations[CC4].apply() // case class
typed[(First :: HNil) :: HNil :: (Second :: Third :: HNil) :: HNil](cc)
assert(cc == (First() :: HNil) :: HNil :: (Second(2, "b") :: Third('c') :: HNil) :: HNil)
typed[(First :: HNil) :: HNil :: (Second :: Third :: HNil) :: (Fourth[Int] :: HNil) :: HNil](cc)
assert(cc == (First() :: HNil) :: HNil :: (Second(2, "b") :: Third('c') :: HNil) :: (Fourth(4) :: HNil) :: HNil)

val user = AllTypeAnnotations[User].apply() // type refs
typed[(First :: HNil) :: (Third :: HNil) :: HNil](user)
assert(user == (First() :: HNil) :: (Third('c') :: HNil) :: HNil)
}

}