Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Call Subscriber.onSubscribe(StreamSubscription) before returning from StreamPublisher.subscribe(Subscriber) #3360

Merged
merged 3 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions core/shared/src/main/scala/fs2/interop/flow/StreamPublisher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import cats.effect.std.Dispatcher
import cats.effect.unsafe.IORuntime

import java.util.Objects.requireNonNull
import java.util.concurrent.Flow.{Publisher, Subscriber, Subscription}
import java.util.concurrent.Flow.{Publisher, Subscriber}
import java.util.concurrent.RejectedExecutionException
import scala.util.control.NoStackTrace

Expand All @@ -47,23 +47,19 @@ private[flow] sealed abstract class StreamPublisher[F[_], A] private (
)(implicit
F: Async[F]
) extends Publisher[A] {
protected def runSubscription(subscribe: F[Unit]): Unit
protected def runSubscription(run: F[Unit]): Unit

override final def subscribe(subscriber: Subscriber[_ >: A]): Unit = {
requireNonNull(
subscriber,
"The subscriber provided to subscribe must not be null"
)
val subscription = StreamSubscription(stream, subscriber)
subscriber.onSubscribe(subscription)
try
runSubscription(
StreamSubscription.subscribe(stream, subscriber)
)
runSubscription(subscription.run)
catch {
case _: IllegalStateException | _: RejectedExecutionException =>
subscriber.onSubscribe(new Subscription {
override def cancel(): Unit = ()
override def request(x$1: Long): Unit = ()
})
subscriber.onError(StreamPublisher.CanceledStreamPublisherException)
}
}
Expand All @@ -72,30 +68,30 @@ private[flow] sealed abstract class StreamPublisher[F[_], A] private (
private[flow] object StreamPublisher {
private final class DispatcherStreamPublisher[F[_], A](
stream: Stream[F, A],
startDispatcher: Dispatcher[F]
dispatcher: Dispatcher[F]
)(implicit
F: Async[F]
) extends StreamPublisher[F, A](stream) {
override protected final def runSubscription(subscribe: F[Unit]): Unit =
startDispatcher.unsafeRunAndForget(subscribe)
override protected final def runSubscription(run: F[Unit]): Unit =
dispatcher.unsafeRunAndForget(run)
}

private final class IORuntimeStreamPublisher[A](
stream: Stream[IO, A]
)(implicit
runtime: IORuntime
) extends StreamPublisher[IO, A](stream) {
override protected final def runSubscription(subscribe: IO[Unit]): Unit =
subscribe.unsafeRunAndForget()
override protected final def runSubscription(run: IO[Unit]): Unit =
run.unsafeRunAndForget()
}

def apply[F[_], A](
stream: Stream[F, A]
)(implicit
F: Async[F]
): Resource[F, StreamPublisher[F, A]] =
Dispatcher.parallel[F](await = true).map { startDispatcher =>
new DispatcherStreamPublisher(stream, startDispatcher)
Dispatcher.parallel[F](await = true).map { dispatcher =>
new DispatcherStreamPublisher(stream, dispatcher)
}

def unsafe[A](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
*/
private[flow] final class StreamSubscription[F[_], A] private (
stream: Stream[F, A],
sub: Subscriber[A],
subscriber: Subscriber[A],
requests: AtomicLong,
resume: AtomicReference[() => Unit],
canceled: AtomicReference[() => Unit]
Expand All @@ -50,12 +50,12 @@ private[flow] final class StreamSubscription[F[_], A] private (
// Ensure we are on a terminal state; i.e. call `cancel`, before signaling the subscriber.
private def onError(ex: Throwable): Unit = {
cancel()
sub.onError(ex)
subscriber.onError(ex)
}

private def onComplete(): Unit = {
cancel()
sub.onComplete()
subscriber.onComplete()
}

// This is a def rather than a val, because it is only used once.
Expand Down Expand Up @@ -100,7 +100,7 @@ private[flow] final class StreamSubscription[F[_], A] private (
stream
.through(subscriptionPipe)
.chunks
.foreach(chunk => F.delay(chunk.foreach(sub.onNext)))
.foreach(chunk => F.delay(chunk.foreach(subscriber.onNext)))
.compile
.drain

Expand Down Expand Up @@ -173,28 +173,37 @@ private[flow] final class StreamSubscription[F[_], A] private (
private[flow] object StreamSubscription {
private final val Sentinel = () => ()

// Mostly for testing purposes.
def apply[F[_], A](stream: Stream[F, A], subscriber: Subscriber[A])(implicit
// UNSAFE + SIDE-EFFECTING!
// We cannot wrap this constructor in F[_].
// Some consumers expect we call Subscriber.onSubscribe(StreamSubscription)
// before returning from StreamPublisher.subscribe(Subscriber).
// See https://github.com/typelevel/fs2/issues/3359
def apply[F[_], A](
stream: Stream[F, A],
subscriber: Subscriber[A]
)(implicit
F: Async[F]
): F[StreamSubscription[F, A]] =
F.delay {
val requests = new AtomicLong(0L)
val resume = new AtomicReference(Sentinel)
val canceled = new AtomicReference(Sentinel)

new StreamSubscription(
stream,
subscriber,
requests,
resume,
canceled
)
}
): StreamSubscription[F, A] = {
val requests = new AtomicLong(0L)
val resume = new AtomicReference(Sentinel)
val canceled = new AtomicReference(Sentinel)

new StreamSubscription(
stream,
subscriber,
requests,
resume,
canceled
)
}

def subscribe[F[_], A](stream: Stream[F, A], subscriber: Subscriber[A])(implicit
def subscribe[F[_], A](
stream: Stream[F, A],
subscriber: Subscriber[A]
)(implicit
F: Async[F]
): F[Unit] =
apply(stream, subscriber).flatMap { subscription =>
F.delay(apply(stream, subscriber)).flatMap { subscription =>
F.delay(subscriber.onSubscribe(subscription)) >>
subscription.run
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class CancellationSpec extends Fs2Suite {
IO(new AtomicBoolean(false))
.flatMap { flag =>
val subscriber = new DummySubscriber(flag, program)
StreamSubscription(s, subscriber).flatMap { subscription =>
IO(StreamSubscription(s, subscriber)).flatMap { subscription =>
(
subscription.run,
IO(subscriber.onSubscribe(subscription))
Expand Down