diff --git a/core/src/main/scala/kafka/tools/StorageTool.scala b/core/src/main/scala/kafka/tools/StorageTool.scala index 43eb6579765f9..6dc4653961408 100644 --- a/core/src/main/scala/kafka/tools/StorageTool.scala +++ b/core/src/main/scala/kafka/tools/StorageTool.scala @@ -130,6 +130,8 @@ object StorageTool extends Logging { if (namespace.getBoolean("standalone")) { formatter.setInitialVoters(createStandaloneDynamicVoters(config)) } + Option(namespace.getList("add_scram")). + foreach(scramArgs => formatter.setScramArguments(scramArgs.asInstanceOf[util.List[String]])) configToLogDirectories(config).foreach(formatter.addDirectory(_)) formatter.run() } diff --git a/core/src/test/scala/unit/kafka/tools/StorageToolTest.scala b/core/src/test/scala/unit/kafka/tools/StorageToolTest.scala index 3a8f064b1df10..83b21b43fbaba 100644 --- a/core/src/test/scala/unit/kafka/tools/StorageToolTest.scala +++ b/core/src/test/scala/unit/kafka/tools/StorageToolTest.scala @@ -21,12 +21,14 @@ import java.io.{ByteArrayOutputStream, File, PrintStream} import java.nio.charset.StandardCharsets import java.nio.file.Files import java.util -import java.util.Properties +import java.util.{Optional, Properties} import kafka.server.KafkaConfig import kafka.utils.TestUtils import net.sourceforge.argparse4j.inf.ArgumentParserException +import org.apache.kafka.common.metadata.UserScramCredentialRecord import org.apache.kafka.common.utils.Utils import org.apache.kafka.server.common.Features +import org.apache.kafka.metadata.bootstrap.BootstrapDirectory import org.apache.kafka.metadata.properties.{MetaPropertiesEnsemble, PropertiesUtils} import org.apache.kafka.metadata.storage.FormatterException import org.apache.kafka.raft.QuorumConfig @@ -37,6 +39,7 @@ import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ValueSource import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters.IterableHasAsScala @Timeout(value = 40) class StorageToolTest { @@ -433,5 +436,49 @@ Found problem: contains("Formatting dynamic metadata voter directory %s".format(availableDirs.head)), "Failed to find content in output: " + stream.toString()) } -} + @Test + def testBootstrapScramRecords(): Unit = { + val availableDirs = Seq(TestUtils.tempDir()) + val properties = new Properties() + properties.putAll(defaultDynamicQuorumProperties) + properties.setProperty("log.dirs", availableDirs.mkString(",")) + val stream = new ByteArrayOutputStream() + val arguments = ListBuffer[String]( + "--release-version", "3.9-IV0", + "--add-scram", "SCRAM-SHA-512=[name=alice,password=changeit]", + "--add-scram", "SCRAM-SHA-512=[name=bob,password=changeit]" + ) + + assertEquals(0, runFormatCommand(stream, properties, arguments.toSeq)) + + // Not doing full SCRAM record validation since that's covered elsewhere. + // Just checking that we generate the correct number of records + val bootstrapMetadata = new BootstrapDirectory(availableDirs.head.toString, Optional.empty).read + val scramRecords = bootstrapMetadata.records().asScala + .filter(apiMessageAndVersion => apiMessageAndVersion.message().isInstanceOf[UserScramCredentialRecord]) + .map(apiMessageAndVersion => apiMessageAndVersion.message().asInstanceOf[UserScramCredentialRecord]) + .toList + assertEquals(2, scramRecords.size) + assertEquals("alice", scramRecords.head.name()) + assertEquals("bob", scramRecords.last.name()) + } + + @Test + def testScramRecordsOldReleaseVersion(): Unit = { + val availableDirs = Seq(TestUtils.tempDir()) + val properties = new Properties() + properties.putAll(defaultDynamicQuorumProperties) + properties.setProperty("log.dirs", availableDirs.mkString(",")) + val stream = new ByteArrayOutputStream() + val arguments = ListBuffer[String]( + "--release-version", "3.4", + "--add-scram", "SCRAM-SHA-512=[name=alice,password=changeit]", + "--add-scram", "SCRAM-SHA-512=[name=bob,password=changeit]" + ) + + assertEquals( + "SCRAM is only supported in metadata.version 3.5-IV2 or later.", + assertThrows(classOf[FormatterException], () => runFormatCommand(stream, properties, arguments.toSeq)).getMessage) + } +}