Skip to content

Commit

Permalink
Tests: use assertNoDiff with implicit location
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Jan 19, 2025
1 parent 9d52271 commit 0278345
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 56 deletions.
113 changes: 58 additions & 55 deletions tests/shared/src/main/scala/munit/BaseFrameworkSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,62 +24,65 @@ abstract class BaseFrameworkSuite extends BaseSuite {
else ex.getMessage().replace(BuildInfo.sourceDirectory.toString(), "")
.replace('\\', '/')

def check(t: FrameworkTest): Unit = test(t.cls.getSimpleName().withTags(t.tags)) {
val baos = new ByteArrayOutputStream()
val out = new PrintStream(baos)
val logger = new Logger {
def ansiCodesSupported(): Boolean = false
def error(x: String): Unit = out.println(x)
def warn(x: String): Unit = out.println(x)
def info(x: String): Unit = out.println(x)
def debug(x: String): Unit = () // ignore debugging output
def trace(x: Throwable): Unit = out.println(x)
}
val framework = new Framework
val runner = framework.runner(
t.arguments ++ Array("+l"), // use sbt loggers
Array(),
PlatformCompat.getThisClassLoader,
)
val tasks = runner.tasks(Array(
new TaskDef(t.cls.getName(), framework.munitFingerprint, false, Array())
))
val events = new StringBuilder()
val eventHandler = new EventHandler {
def handle(event: Event): Unit =
try {
events.append(t.onEvent(event))
val status = event.status().toString().toLowerCase()
val name = event.fullyQualifiedName()
events.append("==> ").append(status).append(" ").append(name)
if (event.throwable().isDefined()) events.append(" - ")
.append(exceptionMessage(event.throwable().get()))
events.append("\n")
} catch {
case NonFatal(e) =>
e.printStackTrace()
events.append(s"unexpected error: $e")
def check(t: FrameworkTest): Unit = {
import t.location
test(t.cls.getSimpleName.withTags(t.tags)) {
val baos = new ByteArrayOutputStream()
val out = new PrintStream(baos)
val logger = new Logger {
def ansiCodesSupported(): Boolean = false
def error(x: String): Unit = out.println(x)
def warn(x: String): Unit = out.println(x)
def info(x: String): Unit = out.println(x)
def debug(x: String): Unit = () // ignore debugging output
def trace(x: Throwable): Unit = out.println(x)
}
val framework = new Framework
val runner = framework.runner(
t.arguments ++ Array("+l"), // use sbt loggers
Array(),
PlatformCompat.getThisClassLoader,
)
val tasks = runner.tasks(Array(
new TaskDef(t.cls.getName(), framework.munitFingerprint, false, Array())
))
val events = new StringBuilder()
val eventHandler = new EventHandler {
def handle(event: Event): Unit =
try {
events.append(t.onEvent(event))
val status = event.status().toString().toLowerCase()
val name = event.fullyQualifiedName()
events.append("==> ").append(status).append(" ").append(name)
if (event.throwable().isDefined()) events.append(" - ")
.append(exceptionMessage(event.throwable().get()))
events.append("\n")
} catch {
case NonFatal(e) =>
e.printStackTrace()
events.append(s"unexpected error: $e")
}
}
implicit val ec = munitExecutionContext
val elapsedTimePattern = Pattern.compile(" ? \\d+\\.\\d+s ?")
TestingConsole.out = out
TestingConsole.err = out
for {
_ <- tasks.foldLeft(Future.successful(())) { case (base, task) =>
base.flatMap(_ =>
PlatformCompat.executeAsync(task, eventHandler, Array(logger))
)
}
}
implicit val ec = munitExecutionContext
val elapsedTimePattern = Pattern.compile(" ? \\d+\\.\\d+s ?")
TestingConsole.out = out
TestingConsole.err = out
for {
_ <- tasks.foldLeft(Future.successful(())) { case (base, task) =>
base.flatMap(_ =>
PlatformCompat.executeAsync(task, eventHandler, Array(logger))
)
} yield {
val stdout = AnsiColors
.filterAnsi(baos.toString(StandardCharsets.UTF_8.name()))
val obtained = AnsiColors.filterAnsi(t.format match {
case SbtFormat => events.toString().replace("\"\"\"", "'''")
case StdoutFormat => elapsedTimePattern.matcher(stdout)
.replaceAll(" <elapsed time>")
})
assertNoDiff(obtained, t.expected, stdout)
}
} yield {
val stdout = AnsiColors
.filterAnsi(baos.toString(StandardCharsets.UTF_8.name()))
val obtained = AnsiColors.filterAnsi(t.format match {
case SbtFormat => events.toString().replace("\"\"\"", "'''")
case StdoutFormat => elapsedTimePattern.matcher(stdout)
.replaceAll(" <elapsed time>")
})
assertNoDiff(obtained, t.expected, stdout)(t.location)
}
}(t.location)
}
}
2 changes: 1 addition & 1 deletion tests/shared/src/test/scala/munit/TypeCheckSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TypeCheckSuite extends FunSuite {
val expected = compat.get(BuildInfo.scalaVersion)
.orElse(compat.get(binaryVersion)).orElse(compat.get(majorVersion))
.getOrElse(compat(BuildInfo.scalaVersion))
assertNoDiff(obtained, expected)(loc)
assertNoDiff(obtained, expected)
}

val msg = "Hello"
Expand Down

0 comments on commit 0278345

Please # to comment.