Skip to content

Commit

Permalink
Merge pull request #1269 from joroKr21/subtype-anns
Browse files Browse the repository at this point in the history
Backport #1265 - Extract all subtype annotations
  • Loading branch information
joroKr21 authored Sep 7, 2022
2 parents 1913400 + 614bc4e commit 8e2a6f2
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 86 deletions.
109 changes: 37 additions & 72 deletions core/src/main/scala/shapeless/annotation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ object Annotations {
def apply() = annotations
}

implicit def materialize[A, T, Out <: HList]: Aux[A, T, Out] = macro AnnotationMacros.materializeVariableAnnotations[A, T, Out]
implicit def materialize[A, T, Out <: HList]: Aux[A, T, Out] =
macro AnnotationMacros.materializeVariableAnnotations[A, T, Out]
}

/**
Expand Down Expand Up @@ -172,7 +173,8 @@ object TypeAnnotations {
def apply() = annotations
}

implicit def materialize[A, T, Out <: HList]: Aux[A, T, Out] = macro AnnotationMacros.materializeTypeAnnotations[A, T, Out]
implicit def materialize[A, T, Out <: HList]: Aux[A, T, Out] =
macro AnnotationMacros.materializeTypeAnnotations[A, T, Out]
}

/**
Expand Down Expand Up @@ -226,7 +228,8 @@ object AllAnnotations {
def apply(): Out = annotations
}

implicit def materialize[T, Out <: HList]: Aux[T, Out] = macro AnnotationMacros.materializeAllVariableAnnotations[T, Out]
implicit def materialize[T, Out <: HList]: Aux[T, Out] =
macro AnnotationMacros.materializeAllVariableAnnotations[T, Out]
}

/**
Expand Down Expand Up @@ -280,7 +283,8 @@ object AllTypeAnnotations {
def apply(): Out = annotations
}

implicit def materialize[T, Out <: HList]: Aux[T, Out] = macro AnnotationMacros.materializeAllTypeAnnotations[T, Out]
implicit def materialize[T, Out <: HList]: Aux[T, Out] =
macro AnnotationMacros.materializeAllTypeAnnotations[T, Out]
}

class AnnotationMacros(val c: whitebox.Context) extends CaseClassMacros {
Expand All @@ -290,62 +294,51 @@ class AnnotationMacros(val c: whitebox.Context) extends CaseClassMacros {
def someTpe: Type = typeOf[Some[_]].typeConstructor
def noneTpe: Type = typeOf[None.type]

private def annotation = objectRef[shapeless.Annotation.type]
private def some = objectRef[Some.type]
private def none = objectRef[None.type]
private def someVal = objectRef[Some.type]
private def noneVal = objectRef[None.type]
private def annotationVal = objectRef[shapeless.Annotation.type]

/**
* FIXME Most of the content of this method is cut-n-pasted from generic.scala
*
* @return The AST of the `tpe` constructor.
*/
def construct(tpe: Type): List[Tree] => Tree = {
// FIXME Cut-n-pasted from generic.scala
val sym = tpe.typeSymbol
val isCaseClass = sym.asClass.isCaseClass
def hasNonGenericCompanionMember(name: String): Boolean = {
val mSym = sym.companion.typeSignature.member(TermName(name))
mSym != NoSymbol && !isNonGeneric(mSym)
def hasCompanionMember(name: String) = {
val member = sym.companion.typeSignature.member(TermName(name))
member != NoSymbol && !isNonGeneric(member)
}

if(isCaseClass || hasNonGenericCompanionMember("apply"))
args => q"${companionRef(tpe)}(..$args)"
else
args => q"new $tpe(..$args)"
val useCompanion = sym.asClass.isCaseClass || hasCompanionMember("apply")
if (useCompanion) args => q"${companionRef(tpe)}(..$args)" else args => q"new $tpe(..$args)"
}

private def getAnnotation[A: WeakTypeTag, T: WeakTypeTag]: Option[Tree] = {
val annTpe = weakTypeOf[A]

if (!isProduct(annTpe))
abort(s"$annTpe is not a case class-like type")

val construct0 = construct(annTpe)

val tpe = weakTypeOf[T]

if (!isProduct(annTpe)) abort(s"$annTpe is not a case class-like type")
val constructor = construct(annTpe)
tpe.typeSymbol.annotations.collectFirst {
case ann if ann.tree.tpe =:= annTpe => construct0(ann.tree.children.tail)
case ann if ann.tree.tpe =:= annTpe => constructor(ann.tree.children.tail)
}
}

def materializeAnnotation[A: WeakTypeTag, T: WeakTypeTag]: Tree = {
val annTpe = weakTypeOf[A]
val tpe = weakTypeOf[T]

getAnnotation[A, T] match {
case Some(annTree) => q"$annotation.mkAnnotation[$annTpe, $tpe]($annTree)"
case Some(annTree) => q"$annotationVal.mkAnnotation[$annTpe, $tpe]($annTree)"
case None => abort(s"No $annTpe annotation found on $tpe")
}
}

def materializeAnnotationOptional[A: WeakTypeTag, T: WeakTypeTag]: Tree = {
val optAnnTpe = appliedType(optionTpe, weakTypeOf[A])
val tpe = weakTypeOf[T]

getAnnotation[A, T] match {
case Some(annTree) => q"$annotation.mkAnnotation[$optAnnTpe, $tpe]($some($annTree))"
case None => q"$annotation.mkAnnotation[$optAnnTpe, $tpe]($none)"
case Some(annTree) => q"$annotationVal.mkAnnotation[$optAnnTpe, $tpe]($someVal($annTree))"
case None => q"$annotationVal.mkAnnotation[$optAnnTpe, $tpe]($noneVal)"
}
}

Expand All @@ -354,7 +347,7 @@ class AnnotationMacros(val c: whitebox.Context) extends CaseClassMacros {

def materializeAllVariableAnnotations[T: WeakTypeTag, Out: WeakTypeTag]: Tree =
materializeAllAnnotations[T, Out](typeAnnotation = false)

def materializeTypeAnnotations[A: WeakTypeTag, T: WeakTypeTag, Out: WeakTypeTag]: Tree =
materializeAnnotations[A, T, Out](typeAnnotation = true)

Expand All @@ -364,54 +357,27 @@ class AnnotationMacros(val c: whitebox.Context) extends CaseClassMacros {
@deprecated("Use materializeVariableAnnotations instead", "2.3.6")
def materializeAnnotations[A: WeakTypeTag, T: WeakTypeTag, Out: WeakTypeTag]: Tree =
materializeVariableAnnotations[A, T, Out]

def materializeAnnotations[A: WeakTypeTag, T: WeakTypeTag, Out: WeakTypeTag](typeAnnotation: Boolean): Tree = {
val annTpe = weakTypeOf[A]

if (!isProduct(annTpe))
abort(s"$annTpe is not a case class-like type")

val tpe = weakTypeOf[T]

val annTreeOpts = getAnnotationTreeOptions(tpe, typeAnnotation).map { list =>
list.find(_._1 =:= annTpe).map(_._2)
}

val wrapTpeTrees = annTreeOpts.map {
case Some(annTree) => appliedType(someTpe, annTpe) -> q"_root_.scala.Some($annTree)"
case None => noneTpe -> q"_root_.scala.None"
}

val outTpe = mkHListTpe(wrapTpeTrees.map { case (aTpe, _) => aTpe })
val outTree = wrapTpeTrees.foldRight(q"_root_.shapeless.HNil": Tree) {
case ((_, bound), acc) => pq"_root_.shapeless.::($bound, $acc)"
}

if (typeAnnotation) q"_root_.shapeless.TypeAnnotations.mkAnnotations[$annTpe, $tpe, $outTpe]($outTree)"
else q"_root_.shapeless.Annotations.mkAnnotations[$annTpe, $tpe, $outTpe]($outTree)"
if (!isProduct(annTpe)) abort(s"$annTpe is not a case class-like type")
val (annTypes, annTrees) = getAnnotationTreeOptions(tpe, typeAnnotation).map(_.find(_._1 <:< annTpe) match {
case Some((annTpe, annTree)) => appliedType(someTpe, annTpe) -> q"$someVal($annTree)"
case None => noneTpe -> noneVal
}).unzip
val tc = if (typeAnnotation) objectRef[TypeAnnotations.type] else objectRef[Annotations.type]
q"$tc.mkAnnotations[$annTpe, $tpe, ${mkHListTpe(annTypes)}](${mkHListValue(annTrees)})"
}

def materializeAllAnnotations[T: WeakTypeTag, Out: WeakTypeTag](typeAnnotation: Boolean): Tree = {
val tpe = weakTypeOf[T]
val annTreeOpts = getAnnotationTreeOptions(tpe, typeAnnotation)

val wrapTpeTrees = annTreeOpts.map {
case Nil =>
mkHListTpe(Nil) -> q"(_root_.shapeless.HNil)"
case list =>
mkHListTpe(list.map(_._1)) -> list.foldRight(q"_root_.shapeless.HNil": Tree) {
case ((_, bound), acc) => pq"_root_.shapeless.::($bound, $acc)"
}
}

val outTpe = mkHListTpe(wrapTpeTrees.map { case (aTpe, _) => aTpe })
val outTree = wrapTpeTrees.foldRight(q"_root_.shapeless.HNil": Tree) {
case ((_, bound), acc) =>
pq"_root_.shapeless.::($bound, $acc)"
}

if (typeAnnotation) q"_root_.shapeless.AllTypeAnnotations.mkAnnotations[$tpe, $outTpe]($outTree)"
else q"_root_.shapeless.AllAnnotations.mkAnnotations[$tpe, $outTpe]($outTree)"
val (annTypes, annTrees) = getAnnotationTreeOptions(tpe, typeAnnotation).map { annotations =>
val (types, trees) = annotations.unzip
mkHListTpe(types) -> mkHListValue(trees)
}.unzip
val tc = if (typeAnnotation) objectRef[AllTypeAnnotations.type] else objectRef[AllAnnotations.type]
q"$tc.mkAnnotations[$tpe, ${mkHListTpe(annTypes)}](${mkHListValue(annTrees)})"
}

private def getAnnotationTreeOptions(tpe: Type, typeAnnotation: Boolean): List[List[(Type, Tree)]] = {
Expand Down Expand Up @@ -456,5 +422,4 @@ class AnnotationMacros(val c: whitebox.Context) extends CaseClassMacros {
if (tpe) fromType(s.typeSignature)
else s.annotations
}

}
42 changes: 28 additions & 14 deletions core/src/test/scala/shapeless/annotation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ object AnnotationTestsDefinitions {
case class First() extends saAnnotation
case class Second(i: Int, s: String) extends saAnnotation
case class Third(c: Char) extends saAnnotation
case class Fourth[T](t: T) extends saAnnotation

case class Other() extends saAnnotation
case class Last(b: Boolean) extends saAnnotation
Expand Down Expand Up @@ -58,15 +59,17 @@ object AnnotationTestsDefinitions {
case class CC3(
@First i: Int,
s: String,
@Second(2, "b") @Third('c') ob: Option[Boolean]
@Second(2, "b") @Third('c') ob: Option[Boolean],
@Fourth(4) c: Char
)

case class CC4(
i: Int @First,
s: String,
ob: Option[Boolean] @Second(2, "b") @Third('c')
ob: Option[Boolean] @Second(2, "b") @Third('c'),
c: Char @Fourth(4)
)

type PosInt = Int @First
type Email = String @Third('c')
case class User(age: PosInt, email: Email)
Expand Down Expand Up @@ -97,18 +100,18 @@ class AnnotationTests {
def optionalAnnotation: Unit = {
{
val other = Annotation[Option[Other], CC].apply()
assert(other == Some(Other()))
assert(other.contains(Other()))

val last = Annotation[Option[Last], Something].apply()
assert(last == Some(Last(true)))
assert(last.contains(Last(true)))
}

{
val other: Option[Other] = Annotation[Option[Other], Something].apply()
assert(other == None)
assert(other.isEmpty)

val last: Option[Last] = Annotation[Option[Last], CC].apply()
assert(last == None)
assert(last.isEmpty)
}
}

Expand All @@ -128,6 +131,9 @@ class AnnotationTests {
val second: None.type :: None.type :: Some[Second] :: HNil = Annotations[Second, CC].apply()
assert(second == None :: None :: Some(Second(2, "b")) :: HNil)

val fourth: None.type :: None.type :: None.type :: Some[Fourth[Int]] :: HNil = Annotations[Fourth[_], CC3].apply()
assert(fourth == None :: None :: None :: Some(Fourth(4)) :: HNil)

val unused: None.type :: None.type :: None.type :: HNil = Annotations[Unused, CC].apply()
assert(unused == None :: None :: None :: HNil)

Expand All @@ -145,6 +151,9 @@ class AnnotationTests {
val second = Annotations[Second, CC].apply()
assert(second == None :: None :: Some(Second(2, "b")) :: HNil)

val fourth = Annotations[Fourth[_], CC3].apply()
assert(fourth == None :: None :: None :: Some(Fourth(4)) :: HNil)

val unused = Annotations[Unused, CC].apply()
assert(unused == None :: None :: None :: HNil)

Expand Down Expand Up @@ -172,6 +181,9 @@ class AnnotationTests {
val second: None.type :: None.type :: Some[Second] :: HNil = TypeAnnotations[Second, CC2].apply()
assert(second == None :: None :: Some(Second(2, "b")) :: HNil)

val fourth: None.type :: None.type :: None.type :: Some[Fourth[Int]] :: HNil = TypeAnnotations[Fourth[_], CC4].apply()
assert(fourth == None :: None :: None :: Some(Fourth(4)) :: HNil)

val unused: None.type :: None.type :: None.type :: HNil = TypeAnnotations[Unused, CC2].apply()
assert(unused == None :: None :: None :: HNil)
}
Expand All @@ -183,6 +195,9 @@ class AnnotationTests {
val second = TypeAnnotations[Second, CC2].apply()
assert(second == None :: None :: Some(Second(2, "b")) :: HNil)

val fourth = TypeAnnotations[Fourth[_], CC4].apply()
assert(fourth == None :: None :: None :: Some(Fourth(4)) :: HNil)

val unused = TypeAnnotations[Unused, CC2].apply()
assert(unused == None :: None :: None :: HNil)
}
Expand All @@ -192,14 +207,14 @@ class AnnotationTests {
def invalidTypeAnnotations: Unit = {
illTyped(" TypeAnnotations[Dummy, CC2] ", "could not find implicit value for parameter annotations: .*")
illTyped(" TypeAnnotations[Dummy, Base] ", "could not find implicit value for parameter annotations: .*")
illTyped(" TypeAnnotations[Second, Dummy] ", "could not find implicit value for parameter annotations: .*")
illTyped(" TypeAnnotations[Second, Dummy] ", "could not find implicit value for parameter annotations: .*")
}

@Test
def allAnnotations: Unit = {
val cc = AllAnnotations[CC3].apply()
typed[(First :: HNil) :: HNil :: (Second :: Third :: HNil) :: HNil](cc)
assert(cc == (First() :: HNil) :: HNil :: (Second(2, "b") :: Third('c') :: HNil) :: HNil)
typed[(First :: HNil) :: HNil :: (Second :: Third :: HNil) :: (Fourth[Int] :: HNil) :: HNil](cc)
assert(cc == (First() :: HNil) :: HNil :: (Second(2, "b") :: Third('c') :: HNil) :: (Fourth(4) :: HNil) :: HNil)

val st = AllAnnotations[Base].apply()
typed[(First :: HNil) :: (Second :: Third :: HNil) :: HNil](st)
Expand All @@ -211,12 +226,11 @@ class AnnotationTests {
typed[(First :: HNil) :: (Second :: Third :: HNil) :: HNil](st)

val cc = AllTypeAnnotations[CC4].apply() // case class
typed[(First :: HNil) :: HNil :: (Second :: Third :: HNil) :: HNil](cc)
assert(cc == (First() :: HNil) :: HNil :: (Second(2, "b") :: Third('c') :: HNil) :: HNil)
typed[(First :: HNil) :: HNil :: (Second :: Third :: HNil) :: (Fourth[Int] :: HNil) :: HNil](cc)
assert(cc == (First() :: HNil) :: HNil :: (Second(2, "b") :: Third('c') :: HNil) :: (Fourth(4) :: HNil) :: HNil)

val user = AllTypeAnnotations[User].apply() // type refs
typed[(First :: HNil) :: (Third :: HNil) :: HNil](user)
assert(user == (First() :: HNil) :: (Third('c') :: HNil) :: HNil)
}

}

0 comments on commit 8e2a6f2

Please # to comment.