diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/ConvertToNewScala3Syntax.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/ConvertToNewScala3Syntax.scala index b3ff3f24bb..ac78ce659c 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/ConvertToNewScala3Syntax.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/ConvertToNewScala3Syntax.scala @@ -32,38 +32,41 @@ private class ConvertToNewScala3Syntax(ftoks: FormatTokens) ft: FormatToken, style: ScalafmtConfig ): Option[Replacement] = Option { + val flag = style.rewrite.scala3.newSyntax + def left = ftoks.prevNonComment(ft).left ft.right match { - case _: Token.LeftParen if dialect.allowSignificantIndentation => + case _: Token.LeftParen + if flag.control && dialect.allowSignificantIndentation => ft.meta.rightOwner match { - case _: Term.If if ftoks.prevNonComment(ft).left.is[Token.KwIf] => + case _: Term.If if left.is[Token.KwIf] => removeToken - case _: Term.While - if ftoks.prevNonComment(ft).left.is[Token.KwWhile] => + case _: Term.While if left.is[Token.KwWhile] => removeToken - case _: Term.For if ftoks.prevNonComment(ft).left.is[Token.KwFor] => + case _: Term.For if left.is[Token.KwFor] => removeToken - case _: Term.ForYield - if ftoks.prevNonComment(ft).left.is[Token.KwFor] => + case _: Term.ForYield if left.is[Token.KwFor] => removeToken case _ => null } - case _: Token.Colon if dialect.allowPostfixStarVarargSplices => + case _: Token.Colon + if flag.deprecated && dialect.allowPostfixStarVarargSplices => ft.meta.rightOwner match { case t: Term.Repeated if isSimpleRepeated(t) => removeToken // trick: to get "*", just remove ":" and "_" case _ => null } - case _: Token.At if dialect.allowPostfixStarVarargSplices => + case _: Token.At + if flag.deprecated && dialect.allowPostfixStarVarargSplices => ft.meta.rightOwner match { case Pat.Bind(_, _: Pat.SeqWildcard) => removeToken // trick: to get "*", just remove "@" and "_" case _ => null } - case _: Token.Underscore => + case _: Token.Underscore if flag.deprecated => ft.meta.rightOwner match { case _: Importee.Wildcard if dialect.allowStarWildcardImport => replaceTokenIdent("*", ft.right) @@ -81,14 +84,15 @@ private class ConvertToNewScala3Syntax(ftoks: FormatTokens) case _ => null } - case _: Token.RightArrow if dialect.allowAsForImportRename => + case _: Token.RightArrow + if flag.deprecated && dialect.allowAsForImportRename => ft.meta.rightOwner match { case _: Importee.Rename | _: Importee.Unimport => replaceTokenIdent("as", ft.right) case _ => null } - case Token.Ident("*") => + case Token.Ident("*") if flag.deprecated => ft.meta.rightOwner match { case _: Type.AnonymousParam if dialect.allowUnderscoreAsTypePlaceholder => @@ -106,18 +110,19 @@ private class ConvertToNewScala3Syntax(ftoks: FormatTokens) ft: FormatToken, style: ScalafmtConfig ): Option[(Replacement, Replacement)] = Option { + def nextRight = ftoks.nextNonComment(ftoks.next(ft)).right ft.right match { case x: Token.RightParen if left.how eq ReplacementType.Remove => ft.meta.rightOwner match { case _: Term.If => - if (!ftoks.nextNonComment(ftoks.next(ft)).right.is[Token.KwThen]) + if (!nextRight.is[Token.KwThen]) replaceToken("then")( new Token.KwThen(x.input, x.dialect, x.start) ) else removeToken case _: Term.While | _: Term.For => - if (!ftoks.nextNonComment(ftoks.next(ft)).right.is[Token.KwDo]) + if (!nextRight.is[Token.KwDo]) replaceToken("do")(new Token.KwDo(x.input, x.dialect, x.start)) else removeToken case _ => null diff --git a/scalafmt-tests/src/test/resources/scala3/OptionalBraces.stat b/scalafmt-tests/src/test/resources/scala3/OptionalBraces.stat index 9eeb602ee0..d6933d6fc7 100644 --- a/scalafmt-tests/src/test/resources/scala3/OptionalBraces.stat +++ b/scalafmt-tests/src/test/resources/scala3/OptionalBraces.stat @@ -1235,8 +1235,8 @@ object a { object a: val a = if a then // scalafmt: { rewrite.scala3.newSyntax.control = false } - if aa then // scalafmt: { rewrite.scala3.newSyntax.control = true } - aaa + if (aa) // scalafmt: { rewrite.scala3.newSyntax.control = true } + aaa // c1 else b val a = @@ -1300,11 +1300,14 @@ object a: for (a <- b) yield for a <- b yield foo while a > 0 do while a > 0 do foo // scalafmt: { rewrite.scala3.newSyntax.control = false } - while a > 0 do while a > 0 do foo + while (a > 0) + while a > 0 do foo // scalafmt: { rewrite.scala3.newSyntax.control = true } for a <- b do for a <- b do foo // scalafmt: { rewrite.scala3.newSyntax.control = false } - for a <- b do for a <- b do foo + for (a <- b) do + for (a <- b) + foo // scalafmt: { rewrite.scala3.newSyntax.control = true } for a <- b do for a <- b do foo end a @@ -1412,7 +1415,7 @@ object a: import java as j import Predef.{augmentString as _} // scalafmt: { rewrite.scala3.newSyntax.deprecated = false } - import Predef.{augmentString as _} + import Predef.{augmentString => _} <<< rewrite to new syntax, imports, scala2-source3 rewrite.scala3.convertToNewSyntax = true rewrite.scala3.removeOptionalBraces = yes diff --git a/scalafmt-tests/src/test/resources/scala3/OptionalBraces_fold.stat b/scalafmt-tests/src/test/resources/scala3/OptionalBraces_fold.stat index 166ad4062a..00e0324847 100644 --- a/scalafmt-tests/src/test/resources/scala3/OptionalBraces_fold.stat +++ b/scalafmt-tests/src/test/resources/scala3/OptionalBraces_fold.stat @@ -1162,8 +1162,8 @@ object a { object a: val a = if a then // scalafmt: { rewrite.scala3.newSyntax.control = false } - if aa then // scalafmt: { rewrite.scala3.newSyntax.control = true } - aaa + if (aa) // scalafmt: { rewrite.scala3.newSyntax.control = true } + aaa // c1 else b val a = @@ -1227,11 +1227,11 @@ object a: for (a <- b) yield for a <- b yield foo while a > 0 do while a > 0 do foo // scalafmt: { rewrite.scala3.newSyntax.control = false } - while a > 0 do while a > 0 do foo + while (a > 0) while a > 0 do foo // scalafmt: { rewrite.scala3.newSyntax.control = true } for a <- b do for a <- b do foo // scalafmt: { rewrite.scala3.newSyntax.control = false } - for a <- b do for a <- b do foo + for (a <- b) do for (a <- b) foo // scalafmt: { rewrite.scala3.newSyntax.control = true } for a <- b do for a <- b do foo end a @@ -1339,7 +1339,7 @@ object a: import java as j import Predef.{augmentString as _} // scalafmt: { rewrite.scala3.newSyntax.deprecated = false } - import Predef.{augmentString as _} + import Predef.{augmentString => _} <<< rewrite to new syntax, imports, scala2-source3 rewrite.scala3.convertToNewSyntax = true rewrite.scala3.removeOptionalBraces = yes diff --git a/scalafmt-tests/src/test/resources/scala3/OptionalBraces_keep.stat b/scalafmt-tests/src/test/resources/scala3/OptionalBraces_keep.stat index 82c85ffe06..5fee9dd819 100644 --- a/scalafmt-tests/src/test/resources/scala3/OptionalBraces_keep.stat +++ b/scalafmt-tests/src/test/resources/scala3/OptionalBraces_keep.stat @@ -1242,8 +1242,8 @@ object a { object a: val a = if a then // scalafmt: { rewrite.scala3.newSyntax.control = false } - if aa then // scalafmt: { rewrite.scala3.newSyntax.control = true } - aaa + if (aa) // scalafmt: { rewrite.scala3.newSyntax.control = true } + aaa // c1 else b @@ -1314,17 +1314,17 @@ object a: while a > 0 do foo // scalafmt: { rewrite.scala3.newSyntax.control = false } - while a > 0 do - while a > 0 do - foo + while (a > 0) + while a > 0 do + foo // scalafmt: { rewrite.scala3.newSyntax.control = true } for a <- b do for a <- b do foo // scalafmt: { rewrite.scala3.newSyntax.control = false } - for a <- b do - for a <- b do - foo + for (a <- b) do + for (a <- b) + foo // scalafmt: { rewrite.scala3.newSyntax.control = true } for a <- b do for a <- b do @@ -1434,7 +1434,7 @@ object a: import java as j import Predef.{augmentString as _} // scalafmt: { rewrite.scala3.newSyntax.deprecated = false } - import Predef.{augmentString as _} + import Predef.{augmentString => _} <<< rewrite to new syntax, imports, scala2-source3 rewrite.scala3.convertToNewSyntax = true rewrite.scala3.removeOptionalBraces = yes diff --git a/scalafmt-tests/src/test/resources/scala3/OptionalBraces_unfold.stat b/scalafmt-tests/src/test/resources/scala3/OptionalBraces_unfold.stat index 154de34fe3..6cf57d131e 100644 --- a/scalafmt-tests/src/test/resources/scala3/OptionalBraces_unfold.stat +++ b/scalafmt-tests/src/test/resources/scala3/OptionalBraces_unfold.stat @@ -1321,8 +1321,8 @@ object a { object a: val a = if a then // scalafmt: { rewrite.scala3.newSyntax.control = false } - if aa then // scalafmt: { rewrite.scala3.newSyntax.control = true } - aaa + if (aa) // scalafmt: { rewrite.scala3.newSyntax.control = true } + aaa // c1 else b @@ -1396,9 +1396,9 @@ object a: foo end while // scalafmt: { rewrite.scala3.newSyntax.control = false } - while a > 0 do - while a > 0 do - foo + while (a > 0) + while a > 0 do + foo end while // scalafmt: { rewrite.scala3.newSyntax.control = true } for a <- b do @@ -1406,9 +1406,9 @@ object a: foo end for // scalafmt: { rewrite.scala3.newSyntax.control = false } - for a <- b do - for a <- b do - foo + for (a <- b) do + for (a <- b) + foo end for // scalafmt: { rewrite.scala3.newSyntax.control = true } for a <- b do @@ -1554,7 +1554,7 @@ object a: import java as j import Predef.{augmentString as _} // scalafmt: { rewrite.scala3.newSyntax.deprecated = false } - import Predef.{augmentString as _} + import Predef.{augmentString => _} <<< rewrite to new syntax, imports, scala2-source3 rewrite.scala3.convertToNewSyntax = true rewrite.scala3.removeOptionalBraces = yes