diff --git a/core/src/main/scala/shapeless/poly.scala b/core/src/main/scala/shapeless/poly.scala index b5a350617..6dae66e96 100644 --- a/core/src/main/scala/shapeless/poly.scala +++ b/core/src/main/scala/shapeless/poly.scala @@ -113,6 +113,59 @@ object PolyDefns extends Cases { } } + final case class BindFirst[F, Head](head: Head) extends Poly + + object BindFirst { + implicit def bindFirstCase[BF, F, Head, Tail <: HList, Result0]( + implicit + unpack2: BF <:< BindFirst[F, Head], + witnessBF: Witness.Aux[BF], + finalCall: Case.Aux[F, Head :: Tail, Result0] + ): Case.Aux[BF, Tail, Result0] = new Case[BF, Tail] { + type Result = Result0 + val value: Tail => Result = { tail: Tail => + finalCall.value(witnessBF.value.head :: tail) + } + } + } + + final case class Curried[F, ParameterAccumulator <: HList](parameters: ParameterAccumulator) extends Poly1 + + private[PolyDefns] sealed trait LowPriorityCurried { + implicit def partialApplied[ + Self, + F, + ParameterAccumulator <: HList, + CurrentParameter, + AllParameters <: HList, + RestParameters <: HList, + CurrentLength <: Nat + ](implicit + constraint: Self <:< Curried[F, ParameterAccumulator], + witnessSelf: Witness.Aux[Self], + finalCall: Case[F, AllParameters], + length: ops.hlist.Length.Aux[CurrentParameter :: ParameterAccumulator, CurrentLength], + reverseSplit: ops.hlist.ReverseSplit.Aux[AllParameters, CurrentLength, CurrentParameter :: ParameterAccumulator, RestParameters], + hasRestParameters: RestParameters <:< (_ :: _) + ): Case1.Aux[Self, CurrentParameter, Curried[F, CurrentParameter :: ParameterAccumulator]] = Case1 { + nextParameter: CurrentParameter => + Curried[F, CurrentParameter :: ParameterAccumulator](nextParameter :: witnessSelf.value.parameters) + } + } + + object Curried extends LowPriorityCurried { + implicit def lastParameter[Self, F, LastParameter, ParameterAccumulator <: HList, AllParameters <: HList, Result0]( + implicit + constraint: Self <:< Curried[F, ParameterAccumulator], + witnessSelf: Witness.Aux[Self], + reverse: ops.hlist.Reverse.Aux[LastParameter :: ParameterAccumulator, AllParameters], + finalCall: Case.Aux[F, AllParameters, Result0] + ): Case1.Aux[Self, LastParameter, Result0] = Case1 { + lastParameter: LastParameter => + finalCall(reverse(lastParameter :: witnessSelf.value.parameters)) + } + } + /** * Base class for lifting a `Function1` to a `Poly1` */ @@ -246,6 +299,13 @@ trait Poly extends PolyApply with Serializable { */ object Poly extends PolyInst { implicit def inst0(p: Poly)(implicit cse : p.ProductCase[HNil]) : cse.Result = cse() + + import PolyDefns._ + + final def bindFirst[Head](p: Poly, head: Head): BindFirst[p.type, Head] = new BindFirst[p.type, Head](head) + + final def curried(p: Poly): Curried[p.type, HNil] = new Curried[p.type, HNil](HNil) + } /** diff --git a/core/src/test/scala/shapeless/poly.scala b/core/src/test/scala/shapeless/poly.scala index f385414fc..01b853351 100644 --- a/core/src/test/scala/shapeless/poly.scala +++ b/core/src/test/scala/shapeless/poly.scala @@ -439,4 +439,38 @@ class PolyTests { assertTypedEquals[Int](16, r) } + + @Test + def testBindFirst: Unit = { + object p extends Poly3 { + implicit def x = at[Int, String, Double] { (i, s, d) => + s"$i, $d, $s" + } + } + + val bf = Poly.bindFirst(p, 2) + val r = bf("bar", 3.5) + assertTypedEquals[String]("2, 3.5, bar", r) + + val l = 1.0 :: 2.0 :: 3.0 :: HNil + assertTypedEquals[String]("2, 3.0, 2, 2.0, 2, 1.0, x", l.foldLeft("x")(bf)) + } + + @Test + def testCurried: Unit = { + object p extends Poly3 { + implicit def x = at[Int, Double, String] { (i, d, s) => + s"$i, $d, $s" + } + } + + val c = Poly.curried(p) + val c1 = c(1) + val c2 = c1(42.5) + val r = c2("foo") + assertTypedEquals[String]("1, 42.5, foo", r) + + val l = "x" :: "y" :: "z" :: HNil + assertEquals("1, 42.5, x" :: "1, 42.5, y" :: "1, 42.5, z" :: HNil, l.map(c2)) + } }