diff --git a/tests/shared/src/main/scala/munit/BaseFrameworkSuite.scala b/tests/shared/src/main/scala/munit/BaseFrameworkSuite.scala index 308a2301..386bc919 100644 --- a/tests/shared/src/main/scala/munit/BaseFrameworkSuite.scala +++ b/tests/shared/src/main/scala/munit/BaseFrameworkSuite.scala @@ -24,62 +24,65 @@ abstract class BaseFrameworkSuite extends BaseSuite { else ex.getMessage().replace(BuildInfo.sourceDirectory.toString(), "") .replace('\\', '/') - def check(t: FrameworkTest): Unit = test(t.cls.getSimpleName().withTags(t.tags)) { - val baos = new ByteArrayOutputStream() - val out = new PrintStream(baos) - val logger = new Logger { - def ansiCodesSupported(): Boolean = false - def error(x: String): Unit = out.println(x) - def warn(x: String): Unit = out.println(x) - def info(x: String): Unit = out.println(x) - def debug(x: String): Unit = () // ignore debugging output - def trace(x: Throwable): Unit = out.println(x) - } - val framework = new Framework - val runner = framework.runner( - t.arguments ++ Array("+l"), // use sbt loggers - Array(), - PlatformCompat.getThisClassLoader, - ) - val tasks = runner.tasks(Array( - new TaskDef(t.cls.getName(), framework.munitFingerprint, false, Array()) - )) - val events = new StringBuilder() - val eventHandler = new EventHandler { - def handle(event: Event): Unit = - try { - events.append(t.onEvent(event)) - val status = event.status().toString().toLowerCase() - val name = event.fullyQualifiedName() - events.append("==> ").append(status).append(" ").append(name) - if (event.throwable().isDefined()) events.append(" - ") - .append(exceptionMessage(event.throwable().get())) - events.append("\n") - } catch { - case NonFatal(e) => - e.printStackTrace() - events.append(s"unexpected error: $e") + def check(t: FrameworkTest): Unit = { + import t.location + test(t.cls.getSimpleName.withTags(t.tags)) { + val baos = new ByteArrayOutputStream() + val out = new PrintStream(baos) + val logger = new Logger { + def ansiCodesSupported(): Boolean = false + def error(x: String): Unit = out.println(x) + def warn(x: String): Unit = out.println(x) + def info(x: String): Unit = out.println(x) + def debug(x: String): Unit = () // ignore debugging output + def trace(x: Throwable): Unit = out.println(x) + } + val framework = new Framework + val runner = framework.runner( + t.arguments ++ Array("+l"), // use sbt loggers + Array(), + PlatformCompat.getThisClassLoader, + ) + val tasks = runner.tasks(Array( + new TaskDef(t.cls.getName(), framework.munitFingerprint, false, Array()) + )) + val events = new StringBuilder() + val eventHandler = new EventHandler { + def handle(event: Event): Unit = + try { + events.append(t.onEvent(event)) + val status = event.status().toString().toLowerCase() + val name = event.fullyQualifiedName() + events.append("==> ").append(status).append(" ").append(name) + if (event.throwable().isDefined()) events.append(" - ") + .append(exceptionMessage(event.throwable().get())) + events.append("\n") + } catch { + case NonFatal(e) => + e.printStackTrace() + events.append(s"unexpected error: $e") + } + } + implicit val ec = munitExecutionContext + val elapsedTimePattern = Pattern.compile(" ? \\d+\\.\\d+s ?") + TestingConsole.out = out + TestingConsole.err = out + for { + _ <- tasks.foldLeft(Future.successful(())) { case (base, task) => + base.flatMap(_ => + PlatformCompat.executeAsync(task, eventHandler, Array(logger)) + ) } - } - implicit val ec = munitExecutionContext - val elapsedTimePattern = Pattern.compile(" ? \\d+\\.\\d+s ?") - TestingConsole.out = out - TestingConsole.err = out - for { - _ <- tasks.foldLeft(Future.successful(())) { case (base, task) => - base.flatMap(_ => - PlatformCompat.executeAsync(task, eventHandler, Array(logger)) - ) + } yield { + val stdout = AnsiColors + .filterAnsi(baos.toString(StandardCharsets.UTF_8.name())) + val obtained = AnsiColors.filterAnsi(t.format match { + case SbtFormat => events.toString().replace("\"\"\"", "'''") + case StdoutFormat => elapsedTimePattern.matcher(stdout) + .replaceAll(" ") + }) + assertNoDiff(obtained, t.expected, stdout) } - } yield { - val stdout = AnsiColors - .filterAnsi(baos.toString(StandardCharsets.UTF_8.name())) - val obtained = AnsiColors.filterAnsi(t.format match { - case SbtFormat => events.toString().replace("\"\"\"", "'''") - case StdoutFormat => elapsedTimePattern.matcher(stdout) - .replaceAll(" ") - }) - assertNoDiff(obtained, t.expected, stdout)(t.location) } - }(t.location) + } } diff --git a/tests/shared/src/test/scala/munit/TypeCheckSuite.scala b/tests/shared/src/test/scala/munit/TypeCheckSuite.scala index f4a44336..beea01ec 100644 --- a/tests/shared/src/test/scala/munit/TypeCheckSuite.scala +++ b/tests/shared/src/test/scala/munit/TypeCheckSuite.scala @@ -14,7 +14,7 @@ class TypeCheckSuite extends FunSuite { val expected = compat.get(BuildInfo.scalaVersion) .orElse(compat.get(binaryVersion)).orElse(compat.get(majorVersion)) .getOrElse(compat(BuildInfo.scalaVersion)) - assertNoDiff(obtained, expected)(loc) + assertNoDiff(obtained, expected) } val msg = "Hello"