Skip to content

Make object mocking thread safe #312

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 5 commits into from
Nov 10, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,15 @@ lazy val commonSettings =
Nil
}
},
Test / scalacOptions += "-Ywarn-value-discard"
Test / scalacOptions += "-Ywarn-value-discard",
libraryDependencies ++= {
CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, major)) if major <= 12 =>
Seq()
case _ =>
Seq("org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.0-RC1")
}
}
)

lazy val publishSettings = Seq(
Expand Down
40 changes: 28 additions & 12 deletions common/src/main/scala/org/mockito/MockitoAPI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ import org.mockito.ReflectionUtils.InvocationOnMockOps
import org.mockito.internal.configuration.plugins.Plugins.getMockMaker
import org.mockito.internal.creation.MockSettingsImpl
import org.mockito.internal.exceptions.Reporter.notAMockPassedToVerifyNoMoreInteractions
import org.mockito.internal.handler.ScalaMockHandler
import org.mockito.internal.handler.{ ScalaMockHandler, ThreadAwareMockHandler }
import org.mockito.internal.progress.ThreadSafeMockingProgress.mockingProgress
import org.mockito.internal.stubbing.answers.ScalaThrowsException
import org.mockito.internal.util.MockUtil
import org.mockito.internal.util.reflection.LenientCopyTool
import org.mockito.internal.{ ValueClassExtractor, ValueClassWrapper }
import org.mockito.invocation.InvocationOnMock
import org.mockito.invocation.{ Invocation, InvocationContainer, InvocationOnMock, MockHandler }
import org.mockito.mock.MockCreationSettings
import org.mockito.stubbing._
import org.mockito.verification.{ VerificationAfterDelay, VerificationMode, VerificationWithTimeout }
Expand Down Expand Up @@ -472,7 +472,8 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
* <code>verify(aMock).iHaveSomeDefaultArguments("I'm not gonna pass the second argument", "default value")</code>
* as the value for the second parameter would have been null...
*/
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](implicit defaultAnswer: DefaultAnswer, $pt: Prettifier): T = mock(withSettings)
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](implicit defaultAnswer: DefaultAnswer, $pt: Prettifier): T =
createMock(withSettings)

/**
* Delegates to <code>Mockito.mock(type: Class[T], defaultAnswer: Answer[_])</code>
Expand All @@ -489,7 +490,7 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
* as the value for the second parameter would have been null...
*/
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](defaultAnswer: DefaultAnswer)(implicit $pt: Prettifier): T =
mock(withSettings(defaultAnswer))
createMock(withSettings(defaultAnswer))

/**
* Delegates to <code>Mockito.mock(type: Class[T], mockSettings: MockSettings)</code>
Expand All @@ -505,7 +506,13 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
* <code>verify(aMock).iHaveSomeDefaultArguments("I'm not gonna pass the second argument", "default value")</code>
* as the value for the second parameter would have been null...
*/
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](mockSettings: MockSettings)(implicit $pt: Prettifier): T = {
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](mockSettings: MockSettings)(implicit $pt: Prettifier): T =
createMock(mockSettings)

private def createMock[T <: AnyRef: ClassTag: WeakTypeTag](
mockSettings: MockSettings,
mockHandler: (MockCreationSettings[T], Prettifier) => MockHandler[T] = (settings: MockCreationSettings[T], pt: Prettifier) => ScalaMockHandler(settings)(pt)
)(implicit $pt: Prettifier): T = {
val interfaces = ReflectionUtils.extraInterfaces

val realClass: Class[T] = mockSettings match {
Expand All @@ -520,7 +527,7 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
else mockSettings

def createMock(settings: MockCreationSettings[T]): T = {
val mock = getMockMaker.createMock(settings, ScalaMockHandler(settings))
val mock = getMockMaker.createMock(settings, mockHandler(settings, $pt))
val spiedInstance = settings.getSpiedInstance
if (spiedInstance != null) new LenientCopyTool().copyToMock(spiedInstance, mock)
mock
Expand Down Expand Up @@ -620,12 +627,21 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
/**
* Mocks the specified object only for the context of the block
*/
def withObjectMocked[O <: AnyRef: ClassTag](block: => Any): Unit = {
val moduleField = clazz[O].getDeclaredField("MODULE$")
val realImpl = moduleField.get(null)
ReflectionUtils.setFinalStatic(moduleField, mock[O])
try block
finally ReflectionUtils.setFinalStatic(moduleField, realImpl)
def withObjectMocked[O <: AnyRef: ClassTag](block: => Any)(implicit defaultAnswer: DefaultAnswer, $pt: Prettifier): Unit = {
val objectClass = clazz[O]
objectClass.synchronized {
val moduleField = objectClass.getDeclaredField("MODULE$")
val realImpl = moduleField.get(null)

val threadAwareMock = createMock(
withSettings(defaultAnswer),
(settings: MockCreationSettings[O], pt: Prettifier) => ThreadAwareMockHandler(settings)(pt)
)

ReflectionUtils.setFinalStatic(moduleField, threadAwareMock)
try block
finally ReflectionUtils.setFinalStatic(moduleField, realImpl)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion common/src/main/scala/org/mockito/ReflectionUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ object ReflectionUtils {
.getOrElse(Seq.empty)
}

def setFinalStatic(field: Field, newValue: Any) = {
def setFinalStatic(field: Field, newValue: Any): Unit = {
field.setAccessible(true)
val modifiersField = classOf[Field].getDeclaredField("modifiers")
modifiersField.setAccessible(true)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.mockito.internal.handler

import org.mockito.invocation.{ Invocation, InvocationContainer, MockHandler }
import org.mockito.mock.MockCreationSettings
import org.scalactic.Prettifier

class ThreadAwareMockHandler[T](settings: MockCreationSettings[T])(implicit $pt: Prettifier) extends MockHandler[T] {
private val currentThread = Thread.currentThread()
private val delegate = ScalaMockHandler(settings)

override def handle(invocation: Invocation): AnyRef =
if (Thread.currentThread() == currentThread) delegate.handle(invocation)
else invocation.callRealMethod()

override def getMockSettings: MockCreationSettings[T] = delegate.getMockSettings

override def getInvocationContainer: InvocationContainer = delegate.getInvocationContainer
}

object ThreadAwareMockHandler {
def apply[T](settings: MockCreationSettings[T])(implicit $pt: Prettifier): ThreadAwareMockHandler[T] =
new ThreadAwareMockHandler(settings)($pt)
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package user.org.mockito

import java.lang.reflect.{ Field, Modifier }
import java.util.concurrent.atomic.AtomicInteger

import org.mockito.invocation.InvocationOnMock
import org.mockito.{ clazz, ArgumentMatchersSugar, IdiomaticStubbing }
import org.mockito.{ ArgumentMatchersSugar, IdiomaticStubbing }
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import user.org.mockito.matchers.{ ValueCaseClassInt, ValueCaseClassString, ValueClass }

import scala.reflect.ClassTag
import scala.collection.parallel.immutable
import scala.concurrent.{ Await, Future }

class IdiomaticStubbingTest extends AnyWordSpec with Matchers with ArgumentMatchersSugar with IdiomaticMockitoTestSetup with IdiomaticStubbing {

Expand Down Expand Up @@ -313,5 +313,25 @@ class IdiomaticStubbingTest extends AnyWordSpec with Matchers with ArgumentMatch

FooObject.simpleMethod shouldBe "not mocked!"
}

"object stubbing should be thread safe" in {
immutable.ParSeq.range(1, 100).foreach { i =>
withObjectMocked[FooObject.type] {
FooObject.simpleMethod returns s"mocked!-$i"
FooObject.simpleMethod shouldBe s"mocked!-$i"
}
}
}

"object stubbing should be thread safe 2" in {
immutable.ParSeq.range(1, 100).foreach { i =>
if (i % 2 != 0)
withObjectMocked[FooObject.type] {
FooObject.simpleMethod returns s"mocked!-$i"
FooObject.simpleMethod shouldBe s"mocked!-$i"
}
else FooObject.simpleMethod shouldBe "not mocked!"
}
}
}
}