From c53d130ac3cd07e55a1b71678af9dcdffa1a38ff Mon Sep 17 00:00:00 2001 From: 1amageek Date: Thu, 3 Oct 2024 11:22:24 +0900 Subject: [PATCH 01/27] Update dependencies and fix language detection typo --- Package.resolved | 17 +++++++++++++---- Package.swift | 4 ++-- .../WhisperKit/Core/Audio/AudioChunker.swift | 2 +- .../WhisperKit/Core/Audio/AudioProcessor.swift | 2 +- Sources/WhisperKit/Core/WhisperKit.swift | 4 ++-- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/Package.resolved b/Package.resolved index 6cccf25..87fb996 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,12 +1,21 @@ { "pins" : [ + { + "identity" : "jinja", + "kind" : "remoteSourceControl", + "location" : "https://github.com/maiqingqiang/Jinja", + "state" : { + "revision" : "4ffa95ce02e013c992287e19e3bbd620b6cc233a", + "version" : "1.0.4" + } + }, { "identity" : "swift-argument-parser", "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-argument-parser.git", "state" : { - "revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41", - "version" : "1.3.0" + "revision" : "41982a3656a71c768319979febd796c6fd111d5c", + "version" : "1.5.0" } }, { @@ -14,8 +23,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers.git", "state" : { - "revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe", - "version" : "0.1.7" + "revision" : "0f2306713d48a75b862026ebb291926793773f52", + "version" : "0.1.12" } } ], diff --git a/Package.swift b/Package.swift index f3f111e..3515d89 100644 --- a/Package.swift +++ b/Package.swift @@ -20,8 +20,8 @@ let package = Package( ), ], dependencies: [ - .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"), - .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), + .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.12"), + .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.5.0"), ], targets: [ .target( diff --git a/Sources/WhisperKit/Core/Audio/AudioChunker.swift b/Sources/WhisperKit/Core/Audio/AudioChunker.swift index 467bfd6..7f56f8c 100644 --- a/Sources/WhisperKit/Core/Audio/AudioChunker.swift +++ b/Sources/WhisperKit/Core/Audio/AudioChunker.swift @@ -82,7 +82,7 @@ open class VADAudioChunker: AudioChunking { var startIndex = seekClipStart while startIndex < seekClipEnd - windowPadding { let currentFrameLength = startIndex - seekClipStart - if startIndex >= currentFrameLength, startIndex < 0 { + if startIndex >= currentFrameLength || startIndex < 0 { throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size") } diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index c3958cb..bdd54ce 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -95,7 +95,7 @@ public extension AudioProcessing { static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? 
{ let currentFrameLength = audioArray.count - if startIndex >= currentFrameLength, startIndex < 0 { + if startIndex >= currentFrameLength || startIndex < 0 { Logging.error("startIndex is outside the buffer size") return nil } diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index c1b66d5..f5a2cd2 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -417,14 +417,14 @@ open class WhisperKit { ) async throws -> (language: String, langProbs: [String: Float]) { let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath) let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer) - return try await detectLangauge(audioArray: audioArray) + return try await detectLanguage(audioArray: audioArray) } /// Detects the language of the audio samples in the provided array. /// /// - Parameter audioArray: An array of audio samples. /// - Returns: A tuple containing the detected language and the language log probabilities. - open func detectLangauge( + open func detectLanguage( audioArray: [Float] ) async throws -> (language: String, langProbs: [String: Float]) { if modelState != .loaded { From 26083402ec3b6b0cab984af5e1b10b3e3443561d Mon Sep 17 00:00:00 2001 From: 1amageek Date: Thu, 3 Oct 2024 12:12:00 +0900 Subject: [PATCH 02/27] Update AudioEncoder shape access and add tokenizer methods --- Sources/WhisperKit/Core/AudioEncoder.swift | 6 ++---- Sources/WhisperKit/Core/Models.swift | 13 +++++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/Sources/WhisperKit/Core/AudioEncoder.swift b/Sources/WhisperKit/Core/AudioEncoder.swift index 06337cd..c9c9358 100644 --- a/Sources/WhisperKit/Core/AudioEncoder.swift +++ b/Sources/WhisperKit/Core/AudioEncoder.swift @@ -22,16 +22,14 @@ public class AudioEncoder: AudioEncoding, WhisperMLModel { guard let inputDescription = model?.modelDescription.outputDescriptionsByName["encoder_output_embeds"] else { return nil } guard inputDescription.type == .multiArray else { return nil } guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil } - let shape = shapeConstraint.shape.map { $0.intValue } - return shape[1] + return shapeConstraint.shape[0].intValue } public var sequenceLength: Int? { guard let inputDescription = model?.modelDescription.outputDescriptionsByName["encoder_output_embeds"] else { return nil } guard inputDescription.type == .multiArray else { return nil } guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil } - let shape = shapeConstraint.shape.map { $0.intValue } - return shape[3] + return shapeConstraint.shape[1].intValue } public init() {} diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index 3e05132..4ca732b 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -1155,6 +1155,15 @@ struct WhisperTokenizerWrapper: WhisperTokenizer { } extension WhisperTokenizerWrapper: Tokenizer { + + func applyChatTemplate(messages: [[String : String]]) throws -> [Int] { + try tokenizer.applyChatTemplate(messages: messages) + } + + func applyChatTemplate(messages: [[String : String]], chatTemplate: String?, addGenerationPrompt: Bool, truncation: Bool, maxLength: Int?) 
throws -> [Int] { + try tokenizer.applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength) + } + func tokenize(text: String) -> [String] { tokenizer.tokenize(text: text) } @@ -1166,6 +1175,10 @@ extension WhisperTokenizerWrapper: Tokenizer { func decode(tokens: [Int]) -> String { tokenizer.decode(tokens: tokens) } + + func encode(text: String, addSpecialTokens: Bool) -> [Int] { + tokenizer.encode(text: text, addSpecialTokens: addSpecialTokens) + } func convertTokenToId(_ token: String) -> Int? { tokenizer.convertTokenToId(token) From b8029fb0929bbe0cbf7dee1835603538d81f03ef Mon Sep 17 00:00:00 2001 From: 1amageek Date: Fri, 4 Oct 2024 14:08:50 +0900 Subject: [PATCH 03/27] Add `Sendable` conformance to several structs and enums --- Sources/WhisperKit/Core/Models.swift | 14 +++++++------- Sources/WhisperKit/Core/Text/TokenSampler.swift | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index 4ca732b..67220cc 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -167,14 +167,14 @@ public struct ModelComputeOptions { // MARK: - Chunking -public struct AudioChunk { +public struct AudioChunk: Sendable { public var seekOffsetIndex: Int public var audioSamples: [Float] } // MARK: - Decoding -public enum DecodingTask: CustomStringConvertible, CaseIterable { +public enum DecodingTask: CustomStringConvertible, CaseIterable, Sendable { case transcribe case translate @@ -390,7 +390,7 @@ public enum WhisperError: Error, LocalizedError, Equatable { // Structs -public struct TranscriptionResult: Codable { +public struct TranscriptionResult: Codable, Sendable { public var text: String public var segments: [TranscriptionSegment] public var language: String @@ -478,7 +478,7 @@ public extension TranscriptionResult { } } -public struct TranscriptionSegment: Hashable, Codable { +public struct TranscriptionSegment: Hashable, Codable, Sendable { public var id: Int = 0 public var seek: Int = 0 public var start: Float = 0.0 @@ -493,7 +493,7 @@ public struct TranscriptionSegment: Hashable, Codable { public var words: [WordTiming]? = nil } -public struct WordTiming: Hashable, Codable { +public struct WordTiming: Hashable, Codable, Sendable { public var word: String public var tokens: [Int] public var start: Float @@ -501,7 +501,7 @@ public struct WordTiming: Hashable, Codable { public var probability: Float } -public struct TranscriptionProgress { +public struct TranscriptionProgress: Sendable { public var timings: TranscriptionTimings public var text: String public var tokens: [Int] @@ -533,7 +533,7 @@ public struct TranscriptionProgress { /// - Note: This callback should be lightweight and return as quickly as possible to avoid extra decoding loops public typealias TranscriptionCallback = ((TranscriptionProgress) -> Bool?)? 
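
// --- Illustrative sketch, not part of the patch series ---
// With the Sendable conformances this patch adds, result types such as
// TranscriptionResult and TranscriptionProgress can be captured by a detached task
// under strict concurrency checking without warnings. The helper below is a
// hypothetical consumer of the library, not WhisperKit API.
import WhisperKit

func forwardResults(_ results: [TranscriptionResult]) {
    Task.detached {
        // Safe: [TranscriptionResult] is Sendable because its element type is Sendable.
        for result in results {
            print(result.text, result.language)
        }
    }
}
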
-public struct TranscriptionTimings: Codable { +public struct TranscriptionTimings: Codable, Sendable { public var pipelineStart: CFAbsoluteTime public var firstTokenTime: CFAbsoluteTime public var inputAudioSeconds: TimeInterval diff --git a/Sources/WhisperKit/Core/Text/TokenSampler.swift b/Sources/WhisperKit/Core/Text/TokenSampler.swift index ce15cd5..4e6e1b7 100644 --- a/Sources/WhisperKit/Core/Text/TokenSampler.swift +++ b/Sources/WhisperKit/Core/Text/TokenSampler.swift @@ -10,7 +10,7 @@ public protocol TokenSampling { func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult } -public struct SamplingResult { +public struct SamplingResult: Sendable { public var tokens: [Int] public var logProbs: [Float] public var completed: Bool From 2af7d50f73699a89569722075e937220b70d05a5 Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 10:31:51 +0900 Subject: [PATCH 04/27] Refactor AudioProcessor to use actor model and async --- .../WhisperKit/Core/Audio/AudioProcessor.swift | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index bdd54ce..e5d8d38 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -169,7 +169,7 @@ public extension AudioProcessing { } @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -public class AudioProcessor: NSObject, AudioProcessing { +public actor AudioProcessor: @preconcurrency AudioProcessing { private var lastInputDevice: DeviceID? public var audioEngine: AVAudioEngine? public var audioSamples: ContiguousArray = [] @@ -672,9 +672,11 @@ public class AudioProcessor: NSObject, AudioProcessing { return devices } #endif - + deinit { - stopRecording() + Task { + await self.stopRecording() + } } } @@ -789,9 +791,11 @@ public extension AudioProcessor { return } } - - let newBufferArray = Self.convertBufferToArray(buffer: buffer) - self.processBuffer(newBufferArray) + let targetBuffer = buffer + Task { + let newBufferArray = Self.convertBufferToArray(buffer: targetBuffer) + await self.processBuffer(newBufferArray) + } } audioEngine.prepare() From dd1c4d5e24e685e6c68c9debe766af843517355c Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 10:38:51 +0900 Subject: [PATCH 05/27] Add Sendable conformance to various types and protocols --- Sources/WhisperKit/Core/Audio/AudioProcessor.swift | 2 +- Sources/WhisperKit/Core/Models.swift | 4 ++-- Sources/WhisperKit/Core/ResultWriter.swift | 8 ++++---- Sources/WhisperKit/Core/Utils.swift | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index e5d8d38..1336163 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -13,7 +13,7 @@ public typealias DeviceID = AudioDeviceID public typealias DeviceID = String #endif -public struct AudioDevice: Identifiable, Hashable { +public struct AudioDevice: Identifiable, Hashable, Sendable { public let id: DeviceID public let name: String } diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index 67220cc..7bf1a50 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -247,13 +247,13 @@ public struct DecodingCache { } } -public enum ChunkingStrategy: String, CaseIterable { +public enum 
ChunkingStrategy: String, CaseIterable, Sendable { case none case vad } @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -public struct DecodingFallback { +public struct DecodingFallback: Sendable { public var needsFallback: Bool public var fallbackReason: String diff --git a/Sources/WhisperKit/Core/ResultWriter.swift b/Sources/WhisperKit/Core/ResultWriter.swift index 00d694c..cfea6c9 100644 --- a/Sources/WhisperKit/Core/ResultWriter.swift +++ b/Sources/WhisperKit/Core/ResultWriter.swift @@ -3,7 +3,7 @@ import Foundation -public protocol ResultWriting { +public protocol ResultWriting: Sendable { var outputDir: String { get } func write(result: TranscriptionResult, to file: String, options: [String: Any]?) -> Result func formatTime(seconds: Float, alwaysIncludeHours: Bool, decimalMarker: String) -> String @@ -37,7 +37,7 @@ public extension ResultWriting { } } -open class WriteJSON: ResultWriting { +public struct WriteJSON: ResultWriting { public let outputDir: String public init(outputDir: String) { @@ -66,7 +66,7 @@ open class WriteJSON: ResultWriting { } } -open class WriteSRT: ResultWriting { +public struct WriteSRT: ResultWriting { public let outputDir: String public init(outputDir: String) { @@ -101,7 +101,7 @@ open class WriteSRT: ResultWriting { } } -open class WriteVTT: ResultWriting { +public struct WriteVTT: ResultWriting { public let outputDir: String public init(outputDir: String) { diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index 8713510..b91e069 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -836,7 +836,7 @@ public class Logging { } extension Logging { - enum AudioEncoding { + enum AudioEncoding: Sendable { static let logger = Logger( subsystem: Constants.Logging.subsystem, category: "AudioEncoding" @@ -846,7 +846,7 @@ extension Logging { } extension Logging { - enum FeatureExtractor { + enum FeatureExtractor: Sendable { static let logger = Logger( subsystem: Constants.Logging.subsystem, category: "FeatureExtractor" @@ -856,7 +856,7 @@ extension Logging { } extension Logging { - enum TranscribeTask { + enum TranscribeTask: Sendable { static let logger = Logger( subsystem: Constants.Logging.subsystem, category: "TranscribeTask" From 208893d56ff55fb5769e24c5733b166a74c6aa11 Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 12:33:32 +0900 Subject: [PATCH 06/27] Update development team and package dependencies --- .../WhisperAX.xcodeproj/project.pbxproj | 2 +- .../xcshareddata/swiftpm/Package.resolved | 19 ++++++++++++++----- .../Core/Audio/AudioProcessor.swift | 6 +++++- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj index bfb9069..4c92d93 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj @@ -869,7 +869,7 @@ CURRENT_PROJECT_VERSION = 1; DEAD_CODE_STRIPPING = YES; DEVELOPMENT_ASSET_PATHS = "\"WhisperAX/Preview Content\""; - DEVELOPMENT_TEAM = PP83DTRKSA; + DEVELOPMENT_TEAM = 88ACA86N96; ENABLE_HARDENED_RUNTIME = YES; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 41e3727..a506738 100644 --- 
a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,6 +1,15 @@ { - "originHash" : "cd17206b47bb810af9459722192530e3838d8e6629a970988e32a432aaa05f6e", + "originHash" : "420a1723357da21f9e31b01403fd3d66df6e400a752d242d05b2c3d5667e3c33", "pins" : [ + { + "identity" : "jinja", + "kind" : "remoteSourceControl", + "location" : "https://github.com/maiqingqiang/Jinja", + "state" : { + "revision" : "b435eb62b0d3d5f34167ec70a128355486981712", + "version" : "1.0.5" + } + }, { "identity" : "networkimage", "kind" : "remoteSourceControl", @@ -15,8 +24,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-argument-parser.git", "state" : { - "revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41", - "version" : "1.3.0" + "revision" : "41982a3656a71c768319979febd796c6fd111d5c", + "version" : "1.5.0" } }, { @@ -33,8 +42,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers.git", "state" : { - "revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe", - "version" : "0.1.7" + "revision" : "0f2306713d48a75b862026ebb291926793773f52", + "version" : "0.1.12" } } ], diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index 1336163..ede8f0f 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -182,6 +182,10 @@ public actor AudioProcessor: @preconcurrency AudioProcessing { public var audioBufferCallback: (([Float]) -> Void)? public var maxBufferLength = WhisperKit.sampleRate * WhisperKit.chunkLength // 30 seconds of audio at 16,000 Hz public var minBufferLength = Int(Double(WhisperKit.sampleRate) * 0.1) // 0.1 second of audio at 16,000 Hz + + public init() { + + } // MARK: - Loading and conversion @@ -792,8 +796,8 @@ public extension AudioProcessor { } } let targetBuffer = buffer + let newBufferArray = Self.convertBufferToArray(buffer: targetBuffer) Task { - let newBufferArray = Self.convertBufferToArray(buffer: targetBuffer) await self.processBuffer(newBufferArray) } } From 9539e8be48e11526fd0af59e91b719c6ee31463d Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 15:48:23 +0900 Subject: [PATCH 07/27] Update package version and clean up code formatting --- .../xcshareddata/swiftpm/Package.resolved | 4 ++-- .../WhisperKit/Core/Audio/AudioProcessor.swift | 15 ++++++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index a506738..bc58b75 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -33,8 +33,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/gonzalezreal/swift-markdown-ui.git", "state" : { - "revision" : "ae799d015a5374708f7b4c85f3294c05f2a564e2", - "version" : "2.3.0" + "revision" : "55441810c0f678c78ed7e2ebd46dde89228e02fc", + "version" : "2.4.0" } }, { diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index ede8f0f..201399b 100644 --- 
a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -701,7 +701,6 @@ public extension AudioProcessor { let signalEnergy = Self.calculateEnergy(of: buffer) let newEnergy = (relativeEnergy, signalEnergy.avg, signalEnergy.max, signalEnergy.min) self.audioEnergy.append(newEnergy) - // Call the callback with the new buffer audioBufferCallback?(buffer) @@ -804,13 +803,19 @@ public extension AudioProcessor { audioEngine.prepare() try audioEngine.start() - + return audioEngine } - + func purgeAudioSamples(keepingLast keep: Int) { - if audioSamples.count > keep { - audioSamples.removeFirst(audioSamples.count - keep) + let samplesToRemove = audioSamples.count - keep + if samplesToRemove > 0 { + audioSamples.removeFirst(samplesToRemove) + } + + let energiesToRemove = samplesToRemove / minBufferLength + if energiesToRemove > 0 { + audioEnergy.removeFirst(min(energiesToRemove, audioEnergy.count)) } } From 769dc29fa03b2669d44b3ca4544a0da606e9f735 Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 16:30:26 +0900 Subject: [PATCH 08/27] Refactor audio energy calculations and buffer conversion --- .../Core/Audio/AudioProcessor.swift | 70 +++++++------------ 1 file changed, 26 insertions(+), 44 deletions(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index 201399b..06beec0 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -490,19 +490,21 @@ public actor AudioProcessor: @preconcurrency AudioProcessing { var rmsEnergy: Float = 0.0 var minEnergy: Float = 0.0 var maxEnergy: Float = 0.0 - - // Calculate the root mean square of the signal vDSP_rmsqv(signal, 1, &rmsEnergy, vDSP_Length(signal.count)) - - // Calculate the maximum sample value of the signal - vDSP_maxmgv(signal, 1, &maxEnergy, vDSP_Length(signal.count)) - - // Calculate the minimum sample value of the signal - vDSP_minmgv(signal, 1, &minEnergy, vDSP_Length(signal.count)) - + vDSP_maxv(signal, 1, &maxEnergy, vDSP_Length(signal.count)) + vDSP_minv(signal, 1, &minEnergy, vDSP_Length(signal.count)) return (rmsEnergy, maxEnergy, minEnergy) } + public static func calculateRelativeEnergy(of signal: [Float], relativeTo reference: Float) -> Float { + let signalEnergy = calculateAverageEnergy(of: signal) + let referenceEnergy = max(1e-8, reference) + let dbEnergy = 20 * log10(signalEnergy) + let refEnergy = 20 * log10(referenceEnergy) + let normalizedEnergy = rescale(value: dbEnergy, min: refEnergy, max: 0) + return max(0, min(normalizedEnergy, 1)) + } + public static func calculateRelativeEnergy(of signal: [Float], relativeTo reference: Float?) 
-> Float { let signalEnergy = calculateAverageEnergy(of: signal) @@ -522,41 +524,13 @@ public actor AudioProcessor: @preconcurrency AudioProcessing { return max(0, min(normalizedEnergy, 1)) } - public static func convertBufferToArray(buffer: AVAudioPCMBuffer, chunkSize: Int = 1024) -> [Float] { + public static func convertBufferToArray(buffer: AVAudioPCMBuffer) -> [Float] { guard let channelData = buffer.floatChannelData else { return [] } - let frameLength = Int(buffer.frameLength) let startPointer = channelData[0] - - var result: [Float] = [] - result.reserveCapacity(frameLength) // Reserve the capacity to avoid multiple allocations - - var currentFrame = 0 - while currentFrame < frameLength { - let remainingFrames = frameLength - currentFrame - let currentChunkSize = min(chunkSize, remainingFrames) - - var chunk = [Float](repeating: 0, count: currentChunkSize) - - chunk.withUnsafeMutableBufferPointer { bufferPointer in - vDSP_mmov( - startPointer.advanced(by: currentFrame), - bufferPointer.baseAddress!, - vDSP_Length(currentChunkSize), - 1, - vDSP_Length(currentChunkSize), - 1 - ) - } - - result.append(contentsOf: chunk) - currentFrame += currentChunkSize - - memset(startPointer.advanced(by: currentFrame - currentChunkSize), 0, currentChunkSize * MemoryLayout.size) - } - + let result = Array(UnsafeBufferPointer(start: startPointer, count: frameLength)) return result } @@ -691,15 +665,23 @@ public extension AudioProcessor { /// We have a new buffer, process and store it. /// NOTE: Assumes audio is 16khz mono func processBuffer(_ buffer: [Float]) { + let bufferCount = buffer.count + let previousCount = audioSamples.count + audioSamples.reserveCapacity(previousCount + bufferCount) audioSamples.append(contentsOf: buffer) - // Find the lowest average energy of the last 20 buffers ~2 seconds - let minAvgEnergy = self.audioEnergy.suffix(20).reduce(Float.infinity) { min($0, $1.avg) } - let relativeEnergy = Self.calculateRelativeEnergy(of: buffer, relativeTo: minAvgEnergy) + // エネルギー計算 + let recentAudioEnergy = self.audioEnergy.suffix(relativeEnergyWindow) + let minAvgEnergy: Float + if recentAudioEnergy.isEmpty { + minAvgEnergy = 1e-8 // デフォルトの最小エネルギー値 + } else { + minAvgEnergy = recentAudioEnergy.reduce(Float.infinity) { min($0, $1.avg) } + } - // Update energy for buffers with valid data + let relativeEnergy = Self.calculateRelativeEnergy(of: buffer, relativeTo: minAvgEnergy) let signalEnergy = Self.calculateEnergy(of: buffer) - let newEnergy = (relativeEnergy, signalEnergy.avg, signalEnergy.max, signalEnergy.min) + let newEnergy = (rel: relativeEnergy, avg: signalEnergy.avg, max: signalEnergy.max, min: signalEnergy.min) self.audioEnergy.append(newEnergy) // Call the callback with the new buffer audioBufferCallback?(buffer) From c8219d3cb0d9c4f95fd6b1f52b71ed23570f50da Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 16:33:23 +0900 Subject: [PATCH 09/27] Refactor calculateRelativeEnergy method for clarity --- .../Core/Audio/AudioProcessor.swift | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index 06beec0..7afbf11 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -496,31 +496,12 @@ public actor AudioProcessor: @preconcurrency AudioProcessing { return (rmsEnergy, maxEnergy, minEnergy) } - public static func calculateRelativeEnergy(of signal: [Float], relativeTo reference: Float) -> 
Float { - let signalEnergy = calculateAverageEnergy(of: signal) - let referenceEnergy = max(1e-8, reference) - let dbEnergy = 20 * log10(signalEnergy) - let refEnergy = 20 * log10(referenceEnergy) - let normalizedEnergy = rescale(value: dbEnergy, min: refEnergy, max: 0) - return max(0, min(normalizedEnergy, 1)) - } - public static func calculateRelativeEnergy(of signal: [Float], relativeTo reference: Float?) -> Float { let signalEnergy = calculateAverageEnergy(of: signal) - - // Make sure reference is greater than 0 - // Default 1e-3 measured empirically in a silent room let referenceEnergy = max(1e-8, reference ?? 1e-3) - - // Convert to dB let dbEnergy = 20 * log10(signalEnergy) let refEnergy = 20 * log10(referenceEnergy) - - // Normalize based on reference - // NOTE: since signalEnergy elements are floats from 0 to 1, max (full volume) is always 0dB let normalizedEnergy = rescale(value: dbEnergy, min: refEnergy, max: 0) - - // Clamp from 0 to 1 return max(0, min(normalizedEnergy, 1)) } From 2c0549c458206e922de58a7fe633e08312a286af Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 17:35:59 +0900 Subject: [PATCH 10/27] Optimize audio buffer processing with vDSP_mmov --- Sources/WhisperKit/Core/Audio/AudioProcessor.swift | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index 7afbf11..0dc203d 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -511,7 +511,17 @@ public actor AudioProcessor: @preconcurrency AudioProcessing { } let frameLength = Int(buffer.frameLength) let startPointer = channelData[0] - let result = Array(UnsafeBufferPointer(start: startPointer, count: frameLength)) + var result = [Float](unsafeUninitializedCapacity: frameLength) { bufferPointer, initializedCount in + vDSP_mmov( + startPointer, + bufferPointer.baseAddress!, + vDSP_Length(frameLength), + 1, + vDSP_Length(frameLength), + 1 + ) + initializedCount = frameLength + } return result } From 22aaa70066e917d8754145859bec4ac330e0fc5c Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 19:00:12 +0900 Subject: [PATCH 11/27] Refactor audio sample access methods in AudioProcessor --- .../Core/Audio/AudioProcessor.swift | 25 +++++++++++++------ .../Core/Audio/AudioStreamTranscriber.swift | 7 +++--- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index 0dc203d..7e4a2f1 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -47,13 +47,13 @@ public protocol AudioProcessing { ) -> MLMultiArray? 
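
// --- Illustrative sketch, not part of the patch series ---
// Because AudioProcessor is now an actor, the snapshot accessors declared just below
// (getAudioSamples/getRelativeEnergy) are reached with `await` from outside the actor.
// The monitoring helper here is hypothetical, not WhisperKit API.
import WhisperKit

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
func logBufferStatus(of processor: AudioProcessor) async {
    let samples = await processor.getAudioSamples()
    let energy = await processor.getRelativeEnergy()
    let seconds = Double(samples.count) / Double(WhisperKit.sampleRate)
    print("Buffered \(seconds)s of audio, latest relative energy: \(energy.last ?? 0)")
}
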
/// Stores the audio samples to be transcribed - var audioSamples: ContiguousArray { get } + func getAudioSamples() -> ContiguousArray /// Empties the audio samples array, keeping the last `keep` samples func purgeAudioSamples(keepingLast keep: Int) /// A measure of current buffer's energy in dB normalized from 0 - 1 based on the quietest buffer's energy in a specified window - var relativeEnergy: [Float] { get } + func getRelativeEnergy() -> [Float] /// How many past buffers of audio to use to calculate relative energy /// The lowest average energy value in the buffer within this amount of previous buffers will used as the silence baseline @@ -172,12 +172,7 @@ public extension AudioProcessing { public actor AudioProcessor: @preconcurrency AudioProcessing { private var lastInputDevice: DeviceID? public var audioEngine: AVAudioEngine? - public var audioSamples: ContiguousArray = [] - public var audioEnergy: [(rel: Float, avg: Float, max: Float, min: Float)] = [] public var relativeEnergyWindow: Int = 20 - public var relativeEnergy: [Float] { - return self.audioEnergy.map { $0.rel } - } public var audioBufferCallback: (([Float]) -> Void)? public var maxBufferLength = WhisperKit.sampleRate * WhisperKit.chunkLength // 30 seconds of audio at 16,000 Hz @@ -186,6 +181,22 @@ public actor AudioProcessor: @preconcurrency AudioProcessing { public init() { } + + private var audioSamples: ContiguousArray = [] + + public func getAudioSamples() -> ContiguousArray { + self.audioSamples + } + + private var audioEnergy: [(rel: Float, avg: Float, max: Float, min: Float)] = [] + + public func getAudioEnergy() -> [(rel: Float, avg: Float, max: Float, min: Float)] { + self.audioEnergy + } + + public func getRelativeEnergy() -> [Float] { + self.audioEnergy.map(\.rel) + } // MARK: - Loading and conversion diff --git a/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift b/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift index f91ba53..4ec21f4 100644 --- a/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift +++ b/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift @@ -106,7 +106,7 @@ public actor AudioStreamTranscriber { } private func onAudioBufferCallback() { - state.bufferEnergy = audioProcessor.relativeEnergy + state.bufferEnergy = audioProcessor.getRelativeEnergy() } private func onProgressCallback(_ progress: TranscriptionProgress) { @@ -124,7 +124,7 @@ public actor AudioStreamTranscriber { private func transcribeCurrentBuffer() async throws { // Retrieve the current audio buffer from the audio processor - let currentBuffer = audioProcessor.audioSamples + let currentBuffer = await audioProcessor.getAudioSamples() // Calculate the size and duration of the next buffer segment let nextBufferSize = currentBuffer.count - state.lastBufferSize @@ -139,8 +139,9 @@ public actor AudioStreamTranscriber { } if useVAD { + let relativeEnergy = await audioProcessor.getRelativeEnergy() let voiceDetected = AudioProcessor.isVoiceDetected( - in: audioProcessor.relativeEnergy, + in: relativeEnergy, nextBufferInSeconds: nextBufferSeconds, silenceThreshold: silenceThreshold ) From 8ceaa0adce06e7e780386a74997a12910e5fe1cc Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 19:32:52 +0900 Subject: [PATCH 12/27] Remove unnecessary weak self references in closure --- Sources/WhisperKit/Core/Audio/AudioProcessor.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift 
index 7e4a2f1..60d3fcf 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -767,8 +767,7 @@ public extension AudioProcessor { } let bufferSize = AVAudioFrameCount(minBufferLength) // 100ms - 400ms supported - inputNode.installTap(onBus: 0, bufferSize: bufferSize, format: nodeFormat) { [weak self] (buffer: AVAudioPCMBuffer, _: AVAudioTime) in - guard let self = self else { return } + inputNode.installTap(onBus: 0, bufferSize: bufferSize, format: nodeFormat) { (buffer: AVAudioPCMBuffer, _: AVAudioTime) in var buffer = buffer if !buffer.format.sampleRate.isEqual(to: Double(WhisperKit.sampleRate)) { do { @@ -780,7 +779,8 @@ public extension AudioProcessor { } let targetBuffer = buffer let newBufferArray = Self.convertBufferToArray(buffer: targetBuffer) - Task { + Task { [weak self] in + guard let self = self else { return } await self.processBuffer(newBufferArray) } } From 4d4233e13790055de9ab0b61fd11ce5d54b99e0e Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 19:35:21 +0900 Subject: [PATCH 13/27] Refactor audio processing to use async/await methods --- Examples/WhisperAX/WhisperAX/Views/ContentView.swift | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift index 2a182fb..b164111 100644 --- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift +++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift @@ -1206,9 +1206,10 @@ struct ContentView: View { #endif try? audioProcessor.startRecordingLive(inputDeviceID: deviceId) { _ in - DispatchQueue.main.async { - bufferEnergy = whisperKit?.audioProcessor.relativeEnergy ?? [] - bufferSeconds = Double(whisperKit?.audioProcessor.audioSamples.count ?? 0) / Double(WhisperKit.sampleRate) + Task { @MainActor in + bufferEnergy = await whisperKit?.audioProcessor.getRelativeEnergy() ?? [] + let audioSamples = await whisperKit?.audioProcessor.getAudioSamples() ?? 
[] + bufferSeconds = Double(audioSamples.count) / Double(WhisperKit.sampleRate) } } @@ -1406,7 +1407,7 @@ struct ContentView: View { guard let whisperKit = whisperKit else { return } // Retrieve the current audio buffer from the audio processor - let currentBuffer = whisperKit.audioProcessor.audioSamples + let currentBuffer = whisperKit.audioProcessor.getAudioSamples() // Calculate the size and duration of the next buffer segment let nextBufferSize = currentBuffer.count - lastBufferSize @@ -1424,8 +1425,9 @@ struct ContentView: View { } if useVAD { + let relativeEnergy = whisperKit.audioProcessor.getRelativeEnergy() let voiceDetected = AudioProcessor.isVoiceDetected( - in: whisperKit.audioProcessor.relativeEnergy, + in: relativeEnergy, nextBufferInSeconds: nextBufferSeconds, silenceThreshold: Float(silenceThreshold) ) From f9bcd1d81ad4cf93f0510af34b74a8145255e560 Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 20:04:10 +0900 Subject: [PATCH 14/27] Use weak self in audio tap closure to prevent retain cycle --- Sources/WhisperKit/Core/Audio/AudioProcessor.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index 60d3fcf..c422787 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -767,7 +767,7 @@ public extension AudioProcessor { } let bufferSize = AVAudioFrameCount(minBufferLength) // 100ms - 400ms supported - inputNode.installTap(onBus: 0, bufferSize: bufferSize, format: nodeFormat) { (buffer: AVAudioPCMBuffer, _: AVAudioTime) in + inputNode.installTap(onBus: 0, bufferSize: bufferSize, format: nodeFormat) { [weak self] (buffer: AVAudioPCMBuffer, _: AVAudioTime) in var buffer = buffer if !buffer.format.sampleRate.isEqual(to: Double(WhisperKit.sampleRate)) { do { From ea5d8537cc5b198061700526712181def415afbe Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 20:51:03 +0900 Subject: [PATCH 15/27] Log file name in error message for transcriber --- Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift b/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift index 4ec21f4..c01229e 100644 --- a/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift +++ b/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift @@ -99,7 +99,7 @@ public actor AudioStreamTranscriber { do { try await transcribeCurrentBuffer() } catch { - Logging.error("Error: \(error.localizedDescription)") + Logging.error("Error: \(#file) \(error.localizedDescription)") break } } From 5909d11e11720949ebe5b510ce6e7aaf91ee9eb9 Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 21:12:07 +0900 Subject: [PATCH 16/27] Refactor VADAudioChunker to a struct from a class --- .../WhisperKit/Core/Audio/AudioChunker.swift | 11 +- Sources/WhisperKit/Core/WhisperKit.swift | 300 +++++++++--------- 2 files changed, 157 insertions(+), 154 deletions(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioChunker.swift b/Sources/WhisperKit/Core/Audio/AudioChunker.swift index 7f56f8c..2ee4efb 100644 --- a/Sources/WhisperKit/Core/Audio/AudioChunker.swift +++ b/Sources/WhisperKit/Core/Audio/AudioChunker.swift @@ -43,7 +43,7 @@ public extension AudioChunking { /// A audio chunker that splits audio into smaller pieces based on voice activity detection @available(macOS 13, iOS 16, watchOS 10, 
visionOS 1, *) -open class VADAudioChunker: AudioChunking { +public struct VADAudioChunker: AudioChunking { /// prevent hallucinations at the end of the clip by stopping up to 1.0s early private let windowPadding: Int private let vad: VoiceActivityDetector @@ -81,12 +81,12 @@ open class VADAudioChunker: AudioChunking { // Typically this will be the full audio file, unless seek points are explicitly provided var startIndex = seekClipStart while startIndex < seekClipEnd - windowPadding { - let currentFrameLength = startIndex - seekClipStart - if startIndex >= currentFrameLength || startIndex < 0 { + // 配列範囲内にあるかチェック + if startIndex >= audioArray.count || startIndex < 0 { throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size") } - // Make sure we still need chunking for this seek clip, otherwise use the original seek clip end + // Adjust the end index based on VAD or maxChunkLength var endIndex = seekClipEnd if startIndex + maxChunkLength < endIndex { // Adjust the end index based on VAD @@ -97,6 +97,8 @@ open class VADAudioChunker: AudioChunking { ) } + // Ensure endIndex is within the array bounds + endIndex = min(endIndex, audioArray.count) guard endIndex > startIndex else { break } @@ -108,4 +110,5 @@ open class VADAudioChunker: AudioChunking { } return chunkedAudio } + } diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index f5a2cd2..a32c024 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -16,7 +16,7 @@ open class WhisperKit { public private(set) var modelState: ModelState = .unloaded public var modelCompute: ModelComputeOptions public var tokenizer: WhisperTokenizer? - + /// Protocols public var audioProcessor: any AudioProcessing public var featureExtractor: any FeatureExtracting @@ -24,23 +24,23 @@ open class WhisperKit { public var textDecoder: any TextDecoding public var logitsFilters: [any LogitsFiltering] public var segmentSeeker: any SegmentSeeking - + /// Shapes public static let sampleRate: Int = 16000 public static let hopLength: Int = 160 public static let chunkLength: Int = 30 // seconds public static let windowSamples: Int = 480_000 // sampleRate * chunkLength public static let secondsPerTimeToken = Float(0.02) - + /// Progress public private(set) var currentTimings: TranscriptionTimings public private(set) var progress = Progress() - + /// Configuration public var modelFolder: URL? public var tokenizerFolder: URL? public private(set) var useBackgroundDownloadSession: Bool - + public init(_ config: WhisperKitConfig = WhisperKitConfig()) async throws { modelCompute = config.computeOptions ?? ModelComputeOptions() audioProcessor = config.audioProcessor ?? AudioProcessor() @@ -53,7 +53,7 @@ open class WhisperKit { useBackgroundDownloadSession = config.useBackgroundDownloadSession currentTimings = TranscriptionTimings() Logging.shared.logLevel = config.verbose ? config.logLevel : .none - + try await setupModels( model: config.model, downloadBase: config.downloadBase, @@ -61,19 +61,19 @@ open class WhisperKit { modelFolder: config.modelFolder, download: config.download ) - + if let prewarm = config.prewarm, prewarm { Logging.info("Prewarming models...") try await prewarmModels() } - + // If load is not passed in, load based on whether a modelFolder is passed if config.load ?? (config.modelFolder != nil) { Logging.info("Loading models...") try await loadModels() } } - + public convenience init( model: String? = nil, downloadBase: URL? 
= nil, @@ -113,21 +113,21 @@ open class WhisperKit { load: load, download: download, useBackgroundDownloadSession: useBackgroundDownloadSession - ) + ) try await self.init(config) } - + // MARK: - Model Loading - + public static func recommendedModels() -> (default: String, disabled: [String]) { let deviceName = Self.deviceName() Logging.debug("Running on \(deviceName)") - + let defaultModel = modelSupport(for: deviceName).default let disabledModels = modelSupport(for: deviceName).disabled return (defaultModel, disabledModels) } - + public static func deviceName() -> String { var utsname = utsname() uname(&utsname) @@ -138,14 +138,14 @@ open class WhisperKit { } return deviceName } - + public static func fetchAvailableModels(from repo: String = "argmaxinc/whisperkit-coreml", matching: [String] = ["openai_*", "distil-whisper_*"]) async throws -> [String] { let hubApi = HubApi() let modelFiles = try await hubApi.getFilenames(from: repo, matching: matching) - + return formatModelFiles(modelFiles) } - + public static func formatModelFiles(_ modelFiles: [String]) -> [String] { let modelFilters = ModelVariant.allCases.map { "\($0.description)\($0.description.contains("large") ? "" : "/")" } // Include quantized models for large let modelVariants = modelFiles.map { $0.components(separatedBy: "/")[0] + "/" } @@ -156,32 +156,32 @@ open class WhisperKit { } return count > 0 }) - + let availableModels = filteredVariants.map { variant -> String in variant.trimmingFromEnd(character: "/", upto: 1) } - + // Sorting order based on enum let sizeOrder = ModelVariant.allCases.map { $0.description } - + let sortedModels = availableModels.sorted { firstModel, secondModel in // Extract the base size without any additional qualifiers let firstModelBase = sizeOrder.first(where: { firstModel.contains($0) }) ?? "" let secondModelBase = sizeOrder.first(where: { secondModel.contains($0) }) ?? "" - + if firstModelBase == secondModelBase { // If base sizes are the same, sort alphabetically return firstModel < secondModel } else { // Sort based on the size order return sizeOrder.firstIndex(of: firstModelBase) ?? sizeOrder.count - < sizeOrder.firstIndex(of: secondModelBase) ?? sizeOrder.count + < sizeOrder.firstIndex(of: secondModelBase) ?? sizeOrder.count } } - + return sortedModels } - + public static func download( variant: String, downloadBase: URL? = nil, @@ -196,9 +196,9 @@ open class WhisperKit { Logging.debug("Searching for models matching \"\(modelSearchPath)\" in \(repo)") let modelFiles = try await hubApi.getFilenames(from: repo, matching: [modelSearchPath]) var uniquePaths = Set(modelFiles.map { $0.components(separatedBy: "/").first! }) - + var variantPath: String? = nil - + if uniquePaths.count == 1 { variantPath = uniquePaths.first } else { @@ -208,17 +208,17 @@ open class WhisperKit { Logging.debug("Searching for models matching \"\(adjustedModelSearchPath)\" in \(repo)") let adjustedModelFiles = try await hubApi.getFilenames(from: repo, matching: [adjustedModelSearchPath]) uniquePaths = Set(adjustedModelFiles.map { $0.components(separatedBy: "/").first! 
}) - + if uniquePaths.count == 1 { variantPath = uniquePaths.first } } - + guard let variantPath else { // If there is still ambiguity, throw an error throw WhisperError.modelsUnavailable("Multiple models found matching \"\(modelSearchPath)\"") } - + Logging.debug("Downloading model \(variantPath)...") let modelFolder = try await hubApi.snapshot(from: repo, matching: [modelSearchPath]) { progress in Logging.debug(progress) @@ -226,7 +226,7 @@ open class WhisperKit { callback(progress) } } - + let modelFolderName = modelFolder.appending(path: variantPath) return modelFolderName } catch { @@ -234,7 +234,7 @@ open class WhisperKit { throw error } } - + /// Sets up the model folder either from a local path or by downloading from a repository. open func setupModels( model: String?, @@ -245,7 +245,7 @@ open class WhisperKit { ) async throws { // Determine the model variant to use let modelVariant = model ?? WhisperKit.recommendedModels().default - + // If a local model folder is provided, use it; otherwise, download the model if let folder = modelFolder { self.modelFolder = URL(fileURLWithPath: folder) @@ -267,36 +267,36 @@ open class WhisperKit { } } } - + open func prewarmModels() async throws { try await loadModels(prewarmMode: true) } - + open func loadModels( prewarmMode: Bool = false ) async throws { modelState = prewarmMode ? .prewarming : .loading - + let modelLoadStart = CFAbsoluteTimeGetCurrent() - + guard let path = modelFolder else { throw WhisperError.modelsUnavailable("Model folder is not set.") } - + Logging.debug("Loading models from \(path.path) with prewarmMode: \(prewarmMode)") - + // Find either mlmodelc or mlpackage models let logmelUrl = detectModelURL(inFolder: path, named: "MelSpectrogram") let encoderUrl = detectModelURL(inFolder: path, named: "AudioEncoder") let decoderUrl = detectModelURL(inFolder: path, named: "TextDecoder") let decoderPrefillUrl = detectModelURL(inFolder: path, named: "TextDecoderContextPrefill") - + for item in [logmelUrl, encoderUrl, decoderUrl] { if !FileManager.default.fileExists(atPath: item.path) { throw WhisperError.modelsUnavailable("Model file not found at \(item.path)") } } - + if let featureExtractor = featureExtractor as? WhisperMLModel { Logging.debug("Loading feature extractor") try await featureExtractor.loadModel( @@ -306,7 +306,7 @@ open class WhisperKit { ) Logging.debug("Loaded feature extractor") } - + if FileManager.default.fileExists(atPath: decoderPrefillUrl.path) { Logging.debug("Loading text decoder prefill data") textDecoder.prefillData = TextDecoderContextPrefill() @@ -317,7 +317,7 @@ open class WhisperKit { ) Logging.debug("Loaded text decoder prefill data") } - + if let textDecoder = textDecoder as? WhisperMLModel { Logging.debug("Loading text decoder") let decoderLoadStart = CFAbsoluteTimeGetCurrent() @@ -327,30 +327,30 @@ open class WhisperKit { prewarmMode: prewarmMode ) currentTimings.decoderLoadTime = CFAbsoluteTimeGetCurrent() - decoderLoadStart - + Logging.debug("Loaded text decoder in \(String(format: "%.2f", currentTimings.decoderLoadTime))s") } - + if let audioEncoder = audioEncoder as? 
WhisperMLModel { Logging.debug("Loading audio encoder") let encoderLoadStart = CFAbsoluteTimeGetCurrent() - + try await audioEncoder.loadModel( at: encoderUrl, computeUnits: modelCompute.audioEncoderCompute, prewarmMode: prewarmMode ) currentTimings.encoderLoadTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart - + Logging.debug("Loaded audio encoder in \(String(format: "%.2f", currentTimings.encoderLoadTime))s") } - + if prewarmMode { modelState = .prewarmed currentTimings.prewarmLoadTime = CFAbsoluteTimeGetCurrent() - modelLoadStart return } - + // Check model dimensions to assign appropriate tokenizer guard let logitsDim = textDecoder.logitsSize, let encoderDim = audioEncoder.embedSize else { throw WhisperError.tokenizerUnavailable() @@ -359,55 +359,55 @@ open class WhisperKit { modelVariant = detectVariant(logitsDim: logitsDim, encoderDim: encoderDim) Logging.debug("Loading tokenizer for \(modelVariant)") let tokenizerLoadStart = CFAbsoluteTimeGetCurrent() - + let tokenizer = try await loadTokenizer( for: modelVariant, tokenizerFolder: tokenizerFolder, useBackgroundSession: useBackgroundDownloadSession ) currentTimings.tokenizerLoadTime = CFAbsoluteTimeGetCurrent() - tokenizerLoadStart - + self.tokenizer = tokenizer textDecoder.tokenizer = tokenizer Logging.debug("Loaded tokenizer in \(String(format: "%.2f", currentTimings.tokenizerLoadTime))s") - + modelState = .loaded - + currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart + currentTimings.prewarmLoadTime - + Logging.info("Loaded models for whisper size: \(modelVariant) in \(String(format: "%.2f", currentTimings.modelLoading))s") } - + open func unloadModels() async { modelState = .unloading - + for model in [featureExtractor, audioEncoder, textDecoder] { if let model = model as? WhisperMLModel { model.unloadModel() } } - + modelState = .unloaded - + Logging.info("Unloaded all models") } - + open func clearState() { audioProcessor.stopRecording() currentTimings = TranscriptionTimings() } - + deinit { audioProcessor.stopRecording() } - + /// Pass in your own logging callback here open func loggingCallback(_ callback: Logging.LoggingCallback?) { Logging.shared.loggingCallback = callback } - + // MARK: - Detect language - + /// Detects the language of the audio file at the specified path. /// /// - Parameter audioPath: The file path of the audio file. @@ -419,7 +419,7 @@ open class WhisperKit { let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer) return try await detectLanguage(audioArray: audioArray) } - + /// Detects the language of the audio samples in the provided array. /// /// - Parameter audioArray: An array of audio samples. 
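
// --- Illustrative usage sketch, not part of the patch series ---
// After the detectLangauge -> detectLanguage rename earlier in this series, callers
// use the corrected spelling. The model name and audio path below are placeholders.
import WhisperKit

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
func printDetectedLanguage() async throws {
    let whisperKit = try await WhisperKit(model: "base")
    let (language, langProbs) = try await whisperKit.detectLanguage(audioPath: "path/to/audio.wav")
    print("Detected \(language); probabilities: \(langProbs)")
}
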
@@ -430,22 +430,22 @@ open class WhisperKit { if modelState != .loaded { try await loadModels() } - + // Ensure the model is multilingual, as language detection is only supported for these models guard textDecoder.isModelMultilingual else { throw WhisperError.decodingFailed("Language detection not supported for this model") } - + // Tokenizer required for decoding guard let tokenizer else { throw WhisperError.tokenizerUnavailable() } - + let options = DecodingOptions(verbose: Logging.shared.logLevel != .none) let decoderInputs = try textDecoder.prepareDecoderInputs(withPrompt: [tokenizer.specialTokens.startOfTranscriptToken]) decoderInputs.kvCacheUpdateMask[0] = 1.0 decoderInputs.decoderKeyPaddingMask[0] = 0.0 - + // Detect language using up to the first 30 seconds guard let audioSamples = AudioProcessor.padOrTrimAudio(fromArray: audioArray, startAt: 0, toLength: WhisperKit.windowSamples) else { throw WhisperError.transcriptionFailed("Audio samples are nil") @@ -456,7 +456,7 @@ open class WhisperKit { guard let encoderOutput = try await audioEncoder.encodeFeatures(melOutput) else { throw WhisperError.transcriptionFailed("Encoder output is nil") } - + let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: tokenizer.specialTokens.endToken, decodingOptions: options) guard let languageDecodingResult: DecodingResult = try? await textDecoder.detectLanguage( from: encoderOutput, @@ -467,12 +467,12 @@ open class WhisperKit { ) else { throw WhisperError.decodingFailed("Language detection failed") } - + return (language: languageDecodingResult.language, langProbs: languageDecodingResult.languageProbs) } - + // MARK: - Transcribe multiple audio files - + /// Convenience method to transcribe multiple audio files asynchronously and return the results as an array of optional arrays of `TranscriptionResult`. /// - Returns: An array of optional arrays containing `TranscriptionResult`. open func transcribe( @@ -488,7 +488,7 @@ open class WhisperKit { let results = transcribeResults.toOptionalArrays() return results } - + /// Transcribes multiple audio files asynchronously and returns the results as an array of tuples containing the file path and the `Result` object. /// /// This method processes the provided audio file paths by loading the audio data and then transcribing the audio arrays. @@ -507,45 +507,45 @@ open class WhisperKit { ) async -> [Result<[TranscriptionResult], Swift.Error>] { // Start timing the audio loading and conversion process let loadAudioStart = Date() - + // Load and extract audio data from the provided file paths let loadedAudioResult = await AudioProcessor.loadAudio(at: audioPaths) let audioArrays = loadedAudioResult.compactMap { try? 
$0.get() } - + // Calculate the time taken to load and convert audio let loadAndConvertTime = Date().timeIntervalSince(loadAudioStart) currentTimings.audioLoading = loadAndConvertTime Logging.debug("Total Audio Loading and Converting Time: \(loadAndConvertTime)") - + // Transcribe the loaded audio arrays let transcribeResults = await transcribeWithResults( audioArrays: audioArrays, decodeOptions: decodeOptions, callback: callback ) - + // Initialize the result array to hold final transcription results var result = [Result<[TranscriptionResult], Swift.Error>]() var transcribeResultIndex = 0 - + // Iterate over loadedAudioResult and map each to the corresponding transcription result for audioResult in loadedAudioResult { switch audioResult { - case .success: - // Append transcription result if audio loading was successful (may still contain failure) - result.append(transcribeResults[transcribeResultIndex]) - transcribeResultIndex += 1 - case let .failure(error): - // Append failure result if audio loading failed - result.append(.failure(error)) + case .success: + // Append transcription result if audio loading was successful (may still contain failure) + result.append(transcribeResults[transcribeResultIndex]) + transcribeResultIndex += 1 + case let .failure(error): + // Append failure result if audio loading failed + result.append(.failure(error)) } } - + return result } - + // MARK: - Transcribe multiple audio arrays - + /// Convenience method to transcribe multiple audio arrays asynchronously and return the results as an array of optional arrays of `TranscriptionResult`. /// - Returns: An array of optional arrays containing `TranscriptionResult`. open func transcribe( @@ -558,10 +558,10 @@ open class WhisperKit { decodeOptions: decodeOptions, callback: callback ) - + return transcribeResults.toOptionalArrays() } - + /// Transcribes multiple audio arrays asynchronously and returns the results as an array of `Result` objects. /// /// This method processes the provided audio arrays by dividing them into batches based on the concurrent worker count @@ -587,7 +587,7 @@ open class WhisperKit { callback: callback ) } - + /// Method to transcribe multiple audio arrays asynchronously with optional associated decoding options and return the results as an array of `Result` objects. /// - Parameters: /// - audioArrays: An array of arrays, each containing audio @@ -601,18 +601,18 @@ open class WhisperKit { callback: TranscriptionCallback = nil ) async -> [Result<[TranscriptionResult], Swift.Error>] { var result = [Result<[TranscriptionResult], Swift.Error>]() - + guard audioArrays.count == decodeOptionsArray.count else { return [.failure(WhisperError.transcriptionFailed("The number of audio arrays and decoding options must be balanced."))] } - + // Determine the number of concurrent workers from decodeOptions based on the maximum value or default to 0 let concurrentWorkerCount = decodeOptionsArray.map { $0?.concurrentWorkerCount ?? 0 }.max() ?? 0 - + // Chunk the audio arrays based on the number of concurrent workers // If concurrentWorkerCount is 0, all audio arrays are processed in one batch let batchedAudioArrays = concurrentWorkerCount == 0 ? 
[audioArrays] : audioArrays.batched(into: concurrentWorkerCount) - + for (batchIndex, audioArrayBatch) in batchedAudioArrays.enumerated() { // Use withTaskGroup to manage concurrent transcription tasks let partialResult = await withTaskGroup(of: [(index: Int, result: Result<[TranscriptionResult], Swift.Error>)].self) { taskGroup -> [Result<[TranscriptionResult], Swift.Error>] in @@ -623,10 +623,10 @@ open class WhisperKit { batchedProgress.windowId = audioIndex + batchIndex * audioArrayBatch.count return callback?(batchedProgress) } - + // Setup decoding options for the current audio array let batchedDecodeOptions = decodeOptionsArray[audioIndex] - + // Add a new task to the task group for each audio array taskGroup.addTask { do { @@ -643,29 +643,29 @@ open class WhisperKit { } } } - + // Collect results from all completed tasks in the task group var batchResult = [(index: Int, result: Result<[TranscriptionResult], Swift.Error>)]() for await result in taskGroup { batchResult.append(contentsOf: result) } - + // Sort the results by index to maintain the original order (they may not be in order due to concurrency) batchResult.sort(by: { $0.index < $1.index }) - + // Map the sorted batch results to a simple array of results return batchResult.map { $0.result } } - + // Append the results of each batch to the final result array result.append(contentsOf: partialResult) } - + return result } - + // MARK: - Transcribe single audio file - + @available(*, deprecated, message: "Subject to removal in a future version. Use `transcribe(audioPath:decodeOptions:callback:) async throws -> [TranscriptionResult]` instead.") @_disfavoredOverload open func transcribe( @@ -676,7 +676,7 @@ open class WhisperKit { let result: [TranscriptionResult] = try await transcribe(audioPath: audioPath, decodeOptions: decodeOptions, callback: callback) return result.first } - + /// Transcribes an audio file from the given path asynchronously. /// - Parameters: /// - audioPath: The file path to the audio file to be transcribed. @@ -693,24 +693,24 @@ open class WhisperKit { let loadAudioStart = Date() let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath) let loadTime = Date().timeIntervalSince(loadAudioStart) - + let convertAudioStart = Date() let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer) let convertTime = Date().timeIntervalSince(convertAudioStart) currentTimings.audioLoading = loadTime + convertTime Logging.debug("Audio loading time: \(loadTime), Audio convert time: \(convertTime)") - + let transcribeResults: [TranscriptionResult] = try await transcribe( audioArray: audioArray, decodeOptions: decodeOptions, callback: callback ) - + return transcribeResults } - + // MARK: - Transcribe single audio sample array - + /// Deprecated @available(*, deprecated, message: "Subject to removal in a future version. 
Use `transcribe(audioArray:decodeOptions:callback:) async throws -> [TranscriptionResult]` instead.") @_disfavoredOverload @@ -722,7 +722,7 @@ open class WhisperKit { let result: [TranscriptionResult] = try await transcribe(audioArray: audioArray, decodeOptions: decodeOptions, callback: callback) return result.first } - + /// Main entry point for transcribing audio /// - Parameters: /// - audioArray: Array of 16khz raw float audio samples @@ -736,11 +736,11 @@ open class WhisperKit { callback: TranscriptionCallback = nil ) async throws -> [TranscriptionResult] { var transcribeResults = [TranscriptionResult]() - + // Determine if the audio array requires chunking let isChunkable = audioArray.count > WhisperKit.windowSamples switch (isChunkable, decodeOptions?.chunkingStrategy) { - case (true, .vad): + case (true, .vad): // We have some audio that will require multiple windows and a strategy to chunk them let vad = decodeOptions?.voiceActivityDetector ?? EnergyVAD() let chunker = VADAudioChunker(vad: vad) @@ -749,35 +749,35 @@ open class WhisperKit { maxChunkLength: WhisperKit.windowSamples, decodeOptions: decodeOptions ) - - // Reset the seek times since we've already chunked the audio - var chunkedOptions = decodeOptions - chunkedOptions?.clipTimestamps = [] - let chunkedDecodeOptions = Array(repeating: chunkedOptions, count: audioChunks.count) - - // Send chunked samples to transcribe (note: this is recursive) - let chunkedResults: [Result<[TranscriptionResult], Swift.Error>] = await transcribeWithOptions( - audioArrays: audioChunks.map { $0.audioSamples }, - decodeOptionsArray: chunkedDecodeOptions, - callback: callback - ) - - // Update the seek offsets based on the audio chunks - let updatedTranscriptionResults = chunker.updateSeekOffsetsForResults( - chunkedResults: chunkedResults, - audioChunks: audioChunks - ) - - transcribeResults = updatedTranscriptionResults - default: - // Audio is short enough to transcribe in a single window and doesn't require chunking - transcribeResults = try await runTranscribeTask( - audioArray: audioArray, - decodeOptions: decodeOptions, - callback: callback - ) + + // Reset the seek times since we've already chunked the audio + var chunkedOptions = decodeOptions + chunkedOptions?.clipTimestamps = [] + let chunkedDecodeOptions = Array(repeating: chunkedOptions, count: audioChunks.count) + + // Send chunked samples to transcribe (note: this is recursive) + let chunkedResults: [Result<[TranscriptionResult], Swift.Error>] = await transcribeWithOptions( + audioArrays: audioChunks.map { $0.audioSamples }, + decodeOptionsArray: chunkedDecodeOptions, + callback: callback + ) + + // Update the seek offsets based on the audio chunks + let updatedTranscriptionResults = chunker.updateSeekOffsetsForResults( + chunkedResults: chunkedResults, + audioChunks: audioChunks + ) + + transcribeResults = updatedTranscriptionResults + default: + // Audio is short enough to transcribe in a single window and doesn't require chunking + transcribeResults = try await runTranscribeTask( + audioArray: audioArray, + decodeOptions: decodeOptions, + callback: callback + ) } - + if let decodeOptions, decodeOptions.verbose { Logging.info("Total Transcription Results: \(transcribeResults.count)") for (i, transcribeTaskResult) in transcribeResults.enumerated() { @@ -785,10 +785,10 @@ open class WhisperKit { transcribeTaskResult.logSegments() } } - + return transcribeResults } - + /// Runs the transcription task on a single audio sample array asynchronously. 
/// - Returns: An array of `TranscriptionResult`. /// - Throws: An error if the transcription fails or if the tokenizer is unavailable. @@ -800,16 +800,16 @@ open class WhisperKit { if modelState != .loaded { try await loadModels() } - + guard let tokenizer else { // Tokenizer required for decoding throw WhisperError.tokenizerUnavailable() } - + let childProgress = Progress() progress.totalUnitCount += 1 progress.addChild(childProgress, withPendingUnitCount: 1) - + let transcribeTask = TranscribeTask( currentTimings: currentTimings, progress: childProgress, @@ -819,25 +819,25 @@ open class WhisperKit { textDecoder: textDecoder, tokenizer: tokenizer ) - + do { try Task.checkCancellation() - + let transcribeTaskResult = try await transcribeTask.run( audioArray: audioArray, decodeOptions: decodeOptions, callback: callback ) - + if let decodeOptions, decodeOptions.verbose { transcribeTaskResult.logTimings() } - + if progress.isFinished { // Reset progress if it is completed progress = Progress() } - + return [transcribeTaskResult] } catch { // Handle cancellation From 184b990379a044ad9290af4579273cb6d08787bc Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sat, 5 Oct 2024 22:29:32 +0900 Subject: [PATCH 17/27] Refactor voice activity detection to use protocols --- .../WhisperKit/Core/Audio/AudioChunker.swift | 4 +- .../Core/Audio/AudioStreamTranscriber.swift | 2 +- Sources/WhisperKit/Core/Audio/EnergyVAD.swift | 42 ++--- .../Core/Audio/VoiceActivityDetectable.swift | 124 ++++++++++++ .../Core/Audio/VoiceActivityDetector.swift | 139 +------------- Sources/WhisperKit/Core/Configurations.swift | 4 +- .../WhisperKit/Core/Text/TokenSampler.swift | 176 +++++++++++++++--- 7 files changed, 290 insertions(+), 201 deletions(-) create mode 100644 Sources/WhisperKit/Core/Audio/VoiceActivityDetectable.swift diff --git a/Sources/WhisperKit/Core/Audio/AudioChunker.swift b/Sources/WhisperKit/Core/Audio/AudioChunker.swift index 2ee4efb..3b5e091 100644 --- a/Sources/WhisperKit/Core/Audio/AudioChunker.swift +++ b/Sources/WhisperKit/Core/Audio/AudioChunker.swift @@ -46,9 +46,9 @@ public extension AudioChunking { public struct VADAudioChunker: AudioChunking { /// prevent hallucinations at the end of the clip by stopping up to 1.0s early private let windowPadding: Int - private let vad: VoiceActivityDetector + private let vad: any VoiceActivityDetectable - public init(windowPadding: Int = 16000, vad: VoiceActivityDetector? = nil) { + public init(windowPadding: Int = 16000, vad: (any VoiceActivityDetectable)? = nil) { self.windowPadding = windowPadding self.vad = vad ?? 
EnergyVAD() } diff --git a/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift b/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift index c01229e..7481f98 100644 --- a/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift +++ b/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift @@ -5,7 +5,7 @@ import Foundation @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public extension AudioStreamTranscriber { - struct State { + struct State: Sendable { public var isRecording: Bool = false public var currentFallbacks: Int = 0 public var lastBufferSize: Int = 0 diff --git a/Sources/WhisperKit/Core/Audio/EnergyVAD.swift b/Sources/WhisperKit/Core/Audio/EnergyVAD.swift index 3c8f0e7..53ece40 100644 --- a/Sources/WhisperKit/Core/Audio/EnergyVAD.swift +++ b/Sources/WhisperKit/Core/Audio/EnergyVAD.swift @@ -5,46 +5,27 @@ import Foundation /// Voice activity detection based on energy threshold @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -final class EnergyVAD: VoiceActivityDetector { - var energyThreshold: Float +public struct EnergyVAD: VoiceActivityDetectable { + public let sampleRate: Int + public let frameLengthSamples: Int + public let frameOverlapSamples: Int + public var energyThreshold: Float - /// Initialize a new EnergyVAD instance - /// - Parameters: - /// - sampleRate: Audio sample rate - /// - frameLength: Frame length in seconds - /// - frameOverlap: frame overlap in seconds, this will include `frameOverlap` length audio into the `frameLength` and is helpful to catch audio that starts exactly at chunk boundaries - /// - energyThreshold: minimal energy threshold - convenience init( + public init( sampleRate: Int = WhisperKit.sampleRate, frameLength: Float = 0.1, frameOverlap: Float = 0.0, energyThreshold: Float = 0.02 ) { - self.init( - sampleRate: sampleRate, - // Compute frame length and overlap in number of samples - frameLengthSamples: Int(frameLength * Float(sampleRate)), - frameOverlapSamples: Int(frameOverlap * Float(sampleRate)), - energyThreshold: energyThreshold - ) - } - - required init( - sampleRate: Int = 16000, - frameLengthSamples: Int, - frameOverlapSamples: Int = 0, - energyThreshold: Float = 0.02 - ) { + self.sampleRate = sampleRate + self.frameLengthSamples = Int(frameLength * Float(sampleRate)) + self.frameOverlapSamples = Int(frameOverlap * Float(sampleRate)) self.energyThreshold = energyThreshold - super.init(sampleRate: sampleRate, frameLengthSamples: frameLengthSamples, frameOverlapSamples: frameOverlapSamples) } - - override func voiceActivity(in waveform: [Float]) -> [Bool] { + + public func voiceActivity(in waveform: [Float]) -> [Bool] { let chunkRatio = Double(waveform.count) / Double(frameLengthSamples) - - // Round up if uneven, the final chunk will not be a full `frameLengthSamples` long let count = Int(chunkRatio.rounded(.up)) - let chunkedVoiceActivity = AudioProcessor.calculateVoiceActivityInChunks( of: waveform, chunkCount: count, @@ -52,7 +33,6 @@ final class EnergyVAD: VoiceActivityDetector { frameOverlapSamples: frameOverlapSamples, energyThreshold: energyThreshold ) - return chunkedVoiceActivity } } diff --git a/Sources/WhisperKit/Core/Audio/VoiceActivityDetectable.swift b/Sources/WhisperKit/Core/Audio/VoiceActivityDetectable.swift new file mode 100644 index 0000000..3f2d772 --- /dev/null +++ b/Sources/WhisperKit/Core/Audio/VoiceActivityDetectable.swift @@ -0,0 +1,124 @@ +// +// VoiceActivityDetectable.swift +// whisperkit +// +// Created by Norikazu Muramoto on 2024/10/03. 
+// + +/// Protocol defining the interface for Voice Activity Detection (VAD) +@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) +public protocol VoiceActivityDetectable: Sendable { + var sampleRate: Int { get } + var frameLengthSamples: Int { get } + var frameOverlapSamples: Int { get } + + func voiceActivity(in waveform: [Float]) -> [Bool] + func calculateActiveChunks(in waveform: [Float]) -> [(startIndex: Int, endIndex: Int)] + func voiceActivityIndexToAudioSampleIndex(_ index: Int) -> Int + func voiceActivityIndexToSeconds(_ index: Int) -> Float + func findLongestSilence(in vadResult: [Bool]) -> (startIndex: Int, endIndex: Int)? + func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] + func calculateNonSilentSeekClips(in waveform: [Float]) -> [(start: Int, end: Int)] + func calculateSeekTimestamps(in waveform: [Float]) -> [(startTime: Float, endTime: Float)] +} + +extension VoiceActivityDetectable { + + public func calculateActiveChunks(in waveform: [Float]) -> [(startIndex: Int, endIndex: Int)] { + let vad = voiceActivity(in: waveform) + var result = [(startIndex: Int, endIndex: Int)]() + var currentStartIndex: Int? + + for (index, vadChunk) in vad.enumerated() { + if vadChunk { + let chunkStart = index * frameLengthSamples + let chunkEnd = min(chunkStart + frameLengthSamples, waveform.count) + + if currentStartIndex != nil { + result[result.count - 1].endIndex = chunkEnd + } else { + currentStartIndex = chunkStart + result.append((startIndex: chunkStart, endIndex: chunkEnd)) + } + } else { + currentStartIndex = nil + } + } + + return result + } + + public func voiceActivityIndexToAudioSampleIndex(_ index: Int) -> Int { + return index * frameLengthSamples + } + + public func voiceActivityIndexToSeconds(_ index: Int) -> Float { + return Float(voiceActivityIndexToAudioSampleIndex(index)) / Float(sampleRate) + } + + public func findLongestSilence(in vadResult: [Bool]) -> (startIndex: Int, endIndex: Int)? { + var longestStartIndex: Int? + var longestEndIndex: Int? 
+ var longestCount = 0 + var index = 0 + while index < vadResult.count { + if vadResult[index] { + index += 1 + } else { + var endIndex = index + while endIndex < vadResult.count, !vadResult[endIndex] { + endIndex += 1 + } + let count = endIndex - index + if count > longestCount { + longestCount = count + longestStartIndex = index + longestEndIndex = endIndex + } + index = endIndex + } + } + if let longestStartIndex, let longestEndIndex { + return (startIndex: longestStartIndex, endIndex: longestEndIndex) + } else { + return nil + } + } + + // MARK - Utility + + public func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] { + let nonSilentChunks = calculateActiveChunks(in: waveform) + var clipTimestamps = [Float]() + + for chunk in nonSilentChunks { + let startTimestamp = Float(chunk.startIndex) / Float(sampleRate) + let endTimestamp = Float(chunk.endIndex) / Float(sampleRate) + + clipTimestamps.append(contentsOf: [startTimestamp, endTimestamp]) + } + + return clipTimestamps + } + + public func calculateNonSilentSeekClips(in waveform: [Float]) -> [(start: Int, end: Int)] { + let clipTimestamps = voiceActivityClipTimestamps(in: waveform) + let options = DecodingOptions(clipTimestamps: clipTimestamps) + let seekClips = prepareSeekClips(contentFrames: waveform.count, decodeOptions: options) + return seekClips + } + + public func calculateSeekTimestamps(in waveform: [Float]) -> [(startTime: Float, endTime: Float)] { + let nonSilentChunks = calculateActiveChunks(in: waveform) + var seekTimestamps = [(startTime: Float, endTime: Float)]() + + for chunk in nonSilentChunks { + let startTimestamp = Float(chunk.startIndex) / Float(sampleRate) + let endTimestamp = Float(chunk.endIndex) / Float(sampleRate) + + seekTimestamps.append(contentsOf: [(startTime: startTimestamp, endTime: endTimestamp)]) + } + + return seekTimestamps + } +} diff --git a/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift b/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift index bb7ef62..dd4c529 100644 --- a/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift +++ b/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift @@ -6,22 +6,11 @@ import Foundation /// A base class for Voice Activity Detection (VAD), used to identify and separate segments of audio that contain human speech from those that do not. /// Subclasses must implement the `voiceActivity(in:)` method to provide specific voice activity detection functionality. @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -open class VoiceActivityDetector { - /// The sample rate of the audio signal, in samples per second. +public struct VoiceActivityDetector: VoiceActivityDetectable { public let sampleRate: Int - - /// The length of each frame in samples. public let frameLengthSamples: Int - - /// The number of samples overlapping between consecutive frames. public let frameOverlapSamples: Int - - /// Initializes a new `VoiceActivityDetector` instance with the specified parameters. - /// - Parameters: - /// - sampleRate: The sample rate of the audio signal in samples per second. Defaults to 16000. - /// - frameLengthSamples: The length of each frame in samples. - /// - frameOverlapSamples: The number of samples overlapping between consecutive frames. Defaults to 0. - /// - Note: Subclasses should override the `voiceActivity(in:)` method to provide specific VAD functionality. 
+ public init( sampleRate: Int = 16000, frameLengthSamples: Int, @@ -31,126 +20,8 @@ open class VoiceActivityDetector { self.frameLengthSamples = frameLengthSamples self.frameOverlapSamples = frameOverlapSamples } - - /// Analyzes the provided audio waveform to determine which segments contain voice activity. - /// - Parameter waveform: An array of `Float` values representing the audio waveform. - /// - Returns: An array of `Bool` values where `true` indicates the presence of voice activity and `false` indicates silence. - open func voiceActivity(in waveform: [Float]) -> [Bool] { - fatalError("`voiceActivity` must be implemented by subclass") - } - - /// Calculates and returns a list of active audio chunks, each represented by a start and end index. - /// - Parameter waveform: An array of `Float` values representing the audio waveform. - /// - Returns: An array of tuples where each tuple contains the start and end indices of an active audio chunk. - public func calculateActiveChunks(in waveform: [Float]) -> [(startIndex: Int, endIndex: Int)] { - let vad: [Bool] = voiceActivity(in: waveform) - var result = [(startIndex: Int, endIndex: Int)]() - - // Temporary variables to hold the start of the current non-silent segment - var currentStartIndex: Int? - - for (index, vadChunk) in vad.enumerated() { - if vadChunk { - let chunkStart = index * frameLengthSamples - let chunkEnd = min(chunkStart + frameLengthSamples, waveform.count) - - if currentStartIndex != nil { - // If we already have a starting point, just update the end point in the last added segment - result[result.count - 1].endIndex = chunkEnd - } else { - // If there is no current start, this is a new segment - currentStartIndex = chunkStart - result.append((startIndex: chunkStart, endIndex: chunkEnd)) - } - } else { - // Reset currentStartIndex when encountering a silent chunk - currentStartIndex = nil - } - } - - return result - } - - /// Converts a voice activity index to the corresponding audio sample index. - /// - Parameter index: The voice activity index to convert. - /// - Returns: The corresponding audio sample index. - public func voiceActivityIndexToAudioSampleIndex(_ index: Int) -> Int { - return index * frameLengthSamples - } - - public func voiceActivityIndexToSeconds(_ index: Int) -> Float { - return Float(voiceActivityIndexToAudioSampleIndex(index)) / Float(sampleRate) - } - - /// Identifies the longest continuous period of silence within the provided voice activity detection results. - /// - Parameter vadResult: An array of `Bool` values representing voice activity detection results. - /// - Returns: A tuple containing the start and end indices of the longest silence period, or `nil` if no silence is found. - public func findLongestSilence(in vadResult: [Bool]) -> (startIndex: Int, endIndex: Int)? { - var longestStartIndex: Int? - var longestEndIndex: Int? 
- var longestCount = 0 - var index = 0 - while index < vadResult.count { - let value = vadResult[index] - if value { - // found non-silence, skip - index += 1 - } else { - // found beginning of silence, find the end - var endIndex = index - while endIndex < vadResult.count, !vadResult[endIndex] { - endIndex += 1 - } - let count = endIndex - index - if count > longestCount { - longestCount = count - longestStartIndex = index - longestEndIndex = endIndex - } - index = endIndex - } - } - if let longestStartIndex, let longestEndIndex { - return (startIndex: longestStartIndex, endIndex: longestEndIndex) - } else { - return nil - } - } - - // MARK - Utility - - func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] { - let nonSilentChunks = calculateActiveChunks(in: waveform) - var clipTimestamps = [Float]() - - for chunk in nonSilentChunks { - let startTimestamp = Float(chunk.startIndex) / Float(sampleRate) - let endTimestamp = Float(chunk.endIndex) / Float(sampleRate) - - clipTimestamps.append(contentsOf: [startTimestamp, endTimestamp]) - } - - return clipTimestamps - } - - func calculateNonSilentSeekClips(in waveform: [Float]) -> [(start: Int, end: Int)] { - let clipTimestamps = voiceActivityClipTimestamps(in: waveform) - let options = DecodingOptions(clipTimestamps: clipTimestamps) - let seekClips = prepareSeekClips(contentFrames: waveform.count, decodeOptions: options) - return seekClips - } - - func calculateSeekTimestamps(in waveform: [Float]) -> [(startTime: Float, endTime: Float)] { - let nonSilentChunks = calculateActiveChunks(in: waveform) - var seekTimestamps = [(startTime: Float, endTime: Float)]() - - for chunk in nonSilentChunks { - let startTimestamp = Float(chunk.startIndex) / Float(sampleRate) - let endTimestamp = Float(chunk.endIndex) / Float(sampleRate) - - seekTimestamps.append(contentsOf: [(startTime: startTimestamp, endTime: endTimestamp)]) - } - - return seekTimestamps + + public func voiceActivity(in waveform: [Float]) -> [Bool] { + fatalError("voiceActivity(in:) must be implemented by conforming types") } } diff --git a/Sources/WhisperKit/Core/Configurations.swift b/Sources/WhisperKit/Core/Configurations.swift index c7a38b3..77d47ce 100644 --- a/Sources/WhisperKit/Core/Configurations.swift +++ b/Sources/WhisperKit/Core/Configurations.swift @@ -143,7 +143,7 @@ public struct DecodingOptions { public var noSpeechThreshold: Float? public var concurrentWorkerCount: Int public var chunkingStrategy: ChunkingStrategy? - public var voiceActivityDetector: VoiceActivityDetector? + public var voiceActivityDetector: (any VoiceActivityDetectable)? public init( verbose: Bool = false, @@ -172,7 +172,7 @@ public struct DecodingOptions { noSpeechThreshold: Float? = 0.6, concurrentWorkerCount: Int = 16, chunkingStrategy: ChunkingStrategy? = nil, - voiceActivityDetector: VoiceActivityDetector? = nil + voiceActivityDetector: (any VoiceActivityDetectable)? 
= nil ) { self.verbose = verbose self.task = task diff --git a/Sources/WhisperKit/Core/Text/TokenSampler.swift b/Sources/WhisperKit/Core/Text/TokenSampler.swift index 4e6e1b7..1dd9b6e 100644 --- a/Sources/WhisperKit/Core/Text/TokenSampler.swift +++ b/Sources/WhisperKit/Core/Text/TokenSampler.swift @@ -21,73 +21,73 @@ open class GreedyTokenSampler: TokenSampling { public var temperature: FloatType public var eotToken: Int public var decodingOptions: DecodingOptions - + public init(temperature: FloatType, eotToken: Int, decodingOptions: DecodingOptions) { self.temperature = temperature self.eotToken = eotToken self.decodingOptions = decodingOptions } - + public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult { var softmaxOutput: BNNSNDArrayDescriptor? var argmaxOutput: BNNSNDArrayDescriptor? var softmaxInput: BNNSNDArrayDescriptor? var softmaxInputNeedsDeallocate = false - + var nextToken: Int? - + do { let logitsRawPointer = UnsafeMutableRawBufferPointer( start: logits.dataPointer, count: logits.count * MemoryLayout.stride ) - + let logitsDescriptor = BNNSNDArrayDescriptor( data: logitsRawPointer, scalarType: FloatType.self, shape: .vector(logits.count, stride: 1) )! - + softmaxInput = logitsDescriptor - + // Scale logits by temperature if > 0 if temperature != 0.0 { let scaledLogits = BNNSNDArrayDescriptor.allocateUninitialized( scalarType: FloatType.self, shape: .vector(logits.count, stride: 1) ) - + try! BNNS.applyActivation( activation: BNNS.ActivationFunction.linear(alpha: Float(1 / temperature)), input: logitsDescriptor, output: scaledLogits, batchSize: 1 ) - + softmaxInput = scaledLogits softmaxInputNeedsDeallocate = true } - + // Always softmax once softmaxOutput = BNNSNDArrayDescriptor.allocateUninitialized( scalarType: Float.self, shape: .vector(logits.count, stride: 1) ) - + try BNNS.applyActivation( activation: BNNS.ActivationFunction.softmax, input: softmaxInput!, output: softmaxOutput!, batchSize: 1 ) - + if temperature != 0.0 { // top-k multinomial sampling let k = decodingOptions.topK - + let bestValues = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float.self, shape: .vector(k, stride: 1)) let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int32.self, shape: .vector(k, stride: 1)) - + try! BNNS.applyTopK( k: k, input: softmaxOutput!, @@ -96,13 +96,13 @@ open class GreedyTokenSampler: TokenSampling { axis: 0, batchSize: 1 ) - + let bestValuesResult = bestValues.makeArray(of: Float.self)! let bestIndicesResult = bestIndices.makeArray(of: Int32.self)! - + bestValues.deallocate() bestIndices.deallocate() - + // multinomial sample from top-k let sumOfbestIndicesResult = bestValuesResult.reduce(0, +) let rnd = Float.random(in: 0.. 
SamplingResult {
         var finalTokens = tokens
         var finalLogProbs = logProbs
@@ -164,7 +164,7 @@ open class GreedyTokenSampler: TokenSampling {
             finalTokens.append(eotToken)
             finalLogProbs.append(0)
         }
-        
+
         return SamplingResult(tokens: finalTokens, logProbs: finalLogProbs, completed: true)
     }
 }
@@ -175,7 +175,7 @@ open class BeamSearchTokenSampler: TokenSampling {
     public var patience: Float
     var maxCandidates: Int
     var finishedSequences: [Float]
-    
+
     public init(
         beamSize: Int,
         eotToken: Int,
@@ -191,18 +191,132 @@ open class BeamSearchTokenSampler: TokenSampling {
             fatalError("Invalid beam size \(beamSize) or patience \(patience)")
         }
     }
-    
+
     public func reset() {
         finishedSequences = []
     }
-    
+
     public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult {
         // TODO: Implement
         fatalError("Not implemented: \(#function)")
     }
-    
+
     public func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult {
         // TODO: Implement
         fatalError("Not implemented: \(#function)")
     }
 }
+
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, *)
+open class NTokenSampler: TokenSampling {
+    public var temperature: Float
+    public var eotToken: Int
+    public var decodingOptions: DecodingOptions
+
+    public init(temperature: Float, eotToken: Int, decodingOptions: DecodingOptions) {
+        self.temperature = temperature
+        self.eotToken = eotToken
+        self.decodingOptions = decodingOptions
+    }
+
+    public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult {
+        // Ensure the MLMultiArray holds Float32 values
+        guard logits.dataType == .float32 else {
+            fatalError("Logits MLMultiArray must be of type Float32")
+        }
+
+        let logitsCount = logits.count
+
+        // Access the raw logits data
+        let logitsPointer = logits.dataPointer.bindMemory(to: Float.self, capacity: logitsCount)
+        let logitsBuffer = UnsafeBufferPointer(start: logitsPointer, count: logitsCount)
+        var logitsArray = [Float](logitsBuffer)
+
+        // Scale the logits when the temperature is greater than 0
+        if temperature != 0.0 {
+            let tempReciprocal = 1.0 / temperature
+            vDSP_vsmul(logitsArray, 1, [tempReciprocal], &logitsArray, 1, vDSP_Length(logitsCount))
+        }
+
+        // Compute the softmax
+        var softmaxOutput = [Float](repeating: 0, count: logitsCount)
+        computeSoftmax(logitsArray, result: &softmaxOutput)
+
+        var nextToken: Int = 0
+        var nextLogprob: Float = 0.0
+
+        if temperature != 0.0 {
+            // Top-k sampling
+            let k = min(decodingOptions.topK, logitsCount)
+
+            // Pair values with their indices and sort
+            let indices = Array(0..
$1.0 }
+            let topKPairs = sortedPairs.prefix(k)
+
+            let topKValues = topKPairs.map { $0.0 }
+            let topKIndices = topKPairs.map { $0.1 }
+
+            // Normalize the top-k probabilities
+            let sumTopK = topKValues.reduce(0, +)
+            let normalizedTopKValues = topKValues.map { $0 / sumTopK }
+
+            // Sample from the top-k distribution
+            let randomValue = Float.random(in: 0..<1)
+            var cumulativeProbability: Float = 0.0
+            for (i, probability) in normalizedTopKValues.enumerated() {
+                cumulativeProbability += probability
+                if randomValue < cumulativeProbability {
+                    nextToken = topKIndices[i]
+                    nextLogprob = log(probability)
+                    break
+                }
+            }
+        } else {
+            // Argmax sampling
+            var maxValue: Float = 0
+            var maxIndex: vDSP_Length = 0
+            vDSP_maxvi(softmaxOutput, 1, &maxValue, &maxIndex, vDSP_Length(logitsCount))
+            nextToken = Int(maxIndex)
+            nextLogprob = log(maxValue)
+        }
+
+        let nextTokens = tokens + [nextToken]
+        let nextLogprobs = logProbs + [nextLogprob]
+        let completed = nextToken == eotToken
+
+        return SamplingResult(tokens: nextTokens, logProbs: nextLogprobs, completed: completed)
+    }
+
+    public func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult {
+        var finalTokens = tokens
+        var finalLogProbs = logProbs
+        if tokens.last != eotToken {
+            finalTokens.append(eotToken)
+            finalLogProbs.append(0)
+        }
+
+        return SamplingResult(tokens: finalTokens, logProbs: finalLogProbs, completed: true)
+    }
+
+    // Helper function that computes the softmax efficiently
+    func computeSoftmax(_ input: [Float], result: inout [Float]) {
+        var input = input
+
+        // Subtract the maximum value to prevent overflow
+        var maxValue: Float = 0
+        vDSP_maxv(input, 1, &maxValue, vDSP_Length(input.count))
+        var negativeMax = -maxValue
+        vDSP_vsadd(input, 1, &negativeMax, &input, 1, vDSP_Length(input.count))
+
+        // Apply the exponential function
+        vvexpf(&result, input, [Int32(input.count)])
+
+        // Sum the exponentials
+        var sumOfExponents: Float = 0
+        vDSP_sve(result, 1, &sumOfExponents, vDSP_Length(input.count))
+
+        // Divide by the sum to obtain probabilities
+        vDSP_vsdiv(result, 1, &sumOfExponents, &result, 1, vDSP_Length(input.count))
+    }
+}

From 933b71b468eccaf5678fcf1781e325cde59ac100 Mon Sep 17 00:00:00 2001
From: 1amageek
Date: Sat, 5 Oct 2024 22:46:53 +0900
Subject: [PATCH 18/27] Add audio converter initialization in resampling process

---
 Sources/WhisperKit/Core/Audio/AudioProcessor.swift | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
index c422787..6895bab 100644
--- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
+++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -770,6 +770,10 @@ public extension AudioProcessor {
         inputNode.installTap(onBus: 0, bufferSize: bufferSize, format: nodeFormat) { [weak self] (buffer: AVAudioPCMBuffer, _: AVAudioTime) in
             var buffer = buffer
             if !buffer.format.sampleRate.isEqual(to: Double(WhisperKit.sampleRate)) {
+                guard let converter = AVAudioConverter(from: nodeFormat, to: desiredFormat) else {
+                    Logging.error("Failed to create audio converter")
+                    return
+                }
                 do {
                     buffer = try Self.resampleBuffer(buffer, with: converter)
                 } catch {

From 2dbb87f4b4e50cb880cb43b15294b7dabf3dffa1 Mon Sep 17 00:00:00 2001
From: 1amageek
Date: Sun, 6 Oct 2024 16:54:08 +0900
Subject: [PATCH 19/27] Refactor AudioProcessor to use SampleRange type

---
 .../Core/Audio/AudioProcessor.swift           |  2 +-
 Sources/WhisperKit/Core/Audio/EnergyVAD.swift | 14 +++++++-
 .../WhisperKit/Core/Audio/SampleRange.swift   | 10 ++++++
 .../Core/Audio/VoiceActivityDetectable.swift  | 20 ++++++------
 Sources/WhisperKit/Core/Utils.swift           |  4 +--
 Tests/WhisperKitTests/UnitTests.swift         | 32
+++++++++---------- 6 files changed, 52 insertions(+), 30 deletions(-) create mode 100644 Sources/WhisperKit/Core/Audio/SampleRange.swift diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index 6895bab..6732ac6 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -452,7 +452,7 @@ public actor AudioProcessor: @preconcurrency AudioProcessing { /// - Returns: an array of tuples indicating the start and end indices of non-silent chunks public static func calculateNonSilentChunks( in signal: [Float] - ) -> [(startIndex: Int, endIndex: Int)] { + ) -> [SampleRange] { EnergyVAD().calculateActiveChunks(in: signal) } diff --git a/Sources/WhisperKit/Core/Audio/EnergyVAD.swift b/Sources/WhisperKit/Core/Audio/EnergyVAD.swift index 53ece40..d8bc97a 100644 --- a/Sources/WhisperKit/Core/Audio/EnergyVAD.swift +++ b/Sources/WhisperKit/Core/Audio/EnergyVAD.swift @@ -10,7 +10,7 @@ public struct EnergyVAD: VoiceActivityDetectable { public let frameLengthSamples: Int public let frameOverlapSamples: Int public var energyThreshold: Float - + public init( sampleRate: Int = WhisperKit.sampleRate, frameLength: Float = 0.1, @@ -22,6 +22,18 @@ public struct EnergyVAD: VoiceActivityDetectable { self.frameOverlapSamples = Int(frameOverlap * Float(sampleRate)) self.energyThreshold = energyThreshold } + + init( + sampleRate: Int = 16000, + frameLengthSamples: Int, + frameOverlapSamples: Int = 0, + energyThreshold: Float = 0.02 + ) { + self.sampleRate = sampleRate + self.frameLengthSamples = frameLengthSamples + self.frameOverlapSamples = frameOverlapSamples + self.energyThreshold = energyThreshold + } public func voiceActivity(in waveform: [Float]) -> [Bool] { let chunkRatio = Double(waveform.count) / Double(frameLengthSamples) diff --git a/Sources/WhisperKit/Core/Audio/SampleRange.swift b/Sources/WhisperKit/Core/Audio/SampleRange.swift new file mode 100644 index 0000000..87469a3 --- /dev/null +++ b/Sources/WhisperKit/Core/Audio/SampleRange.swift @@ -0,0 +1,10 @@ +// +// SampleRange.swift +// whisperkit +// +// Created by Norikazu Muramoto on 2024/10/06. +// + +public typealias FrameRange = (start: Int, end: Int) +public typealias SampleRange = (startIndex: Int, endIndex: Int) +public typealias TimestampRange = (startTime: Float, endTime: Float) diff --git a/Sources/WhisperKit/Core/Audio/VoiceActivityDetectable.swift b/Sources/WhisperKit/Core/Audio/VoiceActivityDetectable.swift index 3f2d772..d48e2c4 100644 --- a/Sources/WhisperKit/Core/Audio/VoiceActivityDetectable.swift +++ b/Sources/WhisperKit/Core/Audio/VoiceActivityDetectable.swift @@ -13,20 +13,20 @@ public protocol VoiceActivityDetectable: Sendable { var frameOverlapSamples: Int { get } func voiceActivity(in waveform: [Float]) -> [Bool] - func calculateActiveChunks(in waveform: [Float]) -> [(startIndex: Int, endIndex: Int)] + func calculateActiveChunks(in waveform: [Float]) -> [SampleRange] func voiceActivityIndexToAudioSampleIndex(_ index: Int) -> Int func voiceActivityIndexToSeconds(_ index: Int) -> Float - func findLongestSilence(in vadResult: [Bool]) -> (startIndex: Int, endIndex: Int)? + func findLongestSilence(in vadResult: [Bool]) -> SampleRange? 
func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] - func calculateNonSilentSeekClips(in waveform: [Float]) -> [(start: Int, end: Int)] - func calculateSeekTimestamps(in waveform: [Float]) -> [(startTime: Float, endTime: Float)] + func calculateNonSilentSeekClips(in waveform: [Float]) -> [FrameRange] + func calculateSeekTimestamps(in waveform: [Float]) -> [TimestampRange] } extension VoiceActivityDetectable { - public func calculateActiveChunks(in waveform: [Float]) -> [(startIndex: Int, endIndex: Int)] { + public func calculateActiveChunks(in waveform: [Float]) -> [SampleRange] { let vad = voiceActivity(in: waveform) - var result = [(startIndex: Int, endIndex: Int)]() + var result = [SampleRange]() var currentStartIndex: Int? for (index, vadChunk) in vad.enumerated() { @@ -56,7 +56,7 @@ extension VoiceActivityDetectable { return Float(voiceActivityIndexToAudioSampleIndex(index)) / Float(sampleRate) } - public func findLongestSilence(in vadResult: [Bool]) -> (startIndex: Int, endIndex: Int)? { + public func findLongestSilence(in vadResult: [Bool]) -> SampleRange? { var longestStartIndex: Int? var longestEndIndex: Int? var longestCount = 0 @@ -101,16 +101,16 @@ extension VoiceActivityDetectable { return clipTimestamps } - public func calculateNonSilentSeekClips(in waveform: [Float]) -> [(start: Int, end: Int)] { + public func calculateNonSilentSeekClips(in waveform: [Float]) -> [FrameRange] { let clipTimestamps = voiceActivityClipTimestamps(in: waveform) let options = DecodingOptions(clipTimestamps: clipTimestamps) let seekClips = prepareSeekClips(contentFrames: waveform.count, decodeOptions: options) return seekClips } - public func calculateSeekTimestamps(in waveform: [Float]) -> [(startTime: Float, endTime: Float)] { + public func calculateSeekTimestamps(in waveform: [Float]) -> [TimestampRange] { let nonSilentChunks = calculateActiveChunks(in: waveform) - var seekTimestamps = [(startTime: Float, endTime: Float)]() + var seekTimestamps = [TimestampRange]() for chunk in nonSilentChunks { let startTimestamp = Float(chunk.startIndex) / Float(sampleRate) diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index b91e069..1b5276f 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -259,7 +259,7 @@ extension AVAudioPCMBuffer { // MARK: - Helpers @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -func prepareSeekClips(contentFrames: Int, decodeOptions: DecodingOptions?) -> [(start: Int, end: Int)] { +func prepareSeekClips(contentFrames: Int, decodeOptions: DecodingOptions?) -> [FrameRange] { let options = decodeOptions ?? DecodingOptions() var seekPoints: [Int] = options.clipTimestamps.map { Int(round($0 * Float(WhisperKit.sampleRate))) } if seekPoints.count == 0 { @@ -270,7 +270,7 @@ func prepareSeekClips(contentFrames: Int, decodeOptions: DecodingOptions?) -> [( seekPoints.append(contentFrames) } - var seekClips: [(start: Int, end: Int)] = [] + var seekClips: [FrameRange] = [] for i in stride(from: 0, to: seekPoints.count, by: 2) { let start = seekPoints[i] let end = i + 1 < seekPoints.count ? 
seekPoints[i + 1] : contentFrames diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index 709e4ec..c92f222 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -1132,38 +1132,38 @@ final class UnitTests: XCTestCase { // When looking for silence boundaries, a smaller frame length is preferred let vadForSilence = EnergyVAD(frameLengthSamples: 320) let nonSilentChunks1 = vadForSilence.calculateActiveChunks(in: []) - XCTAssertEqual(nonSilentChunks1.map(\.startIndex), []) - XCTAssertEqual(nonSilentChunks1.map(\.endIndex), []) + XCTAssertEqual(nonSilentChunks1.map(\SampleRange.startIndex), []) + XCTAssertEqual(nonSilentChunks1.map(\SampleRange.endIndex), []) let nonSilentChunks2 = vadForSilence.calculateActiveChunks(in: Array(repeating: 0, count: 1600)) - XCTAssertEqual(nonSilentChunks2.map(\.startIndex), []) - XCTAssertEqual(nonSilentChunks2.map(\.endIndex), []) + XCTAssertEqual(nonSilentChunks2.map(\SampleRange.startIndex), []) + XCTAssertEqual(nonSilentChunks2.map(\SampleRange.endIndex), []) let nonSilentChunks3 = vadForSilence.calculateActiveChunks(in: Array(repeating: 1, count: 1600)) - XCTAssertEqual(nonSilentChunks3.map(\.startIndex), [0]) - XCTAssertEqual(nonSilentChunks3.map(\.endIndex), [1600]) + XCTAssertEqual(nonSilentChunks3.map(\SampleRange.startIndex), [0]) + XCTAssertEqual(nonSilentChunks3.map(\SampleRange.endIndex), [1600]) let nonSilentChunks4 = vadForSilence.calculateActiveChunks(in: Array(repeating: 0, count: 1600) + Array(repeating: 1, count: 1600)) - XCTAssertEqual(nonSilentChunks4.map(\.startIndex), [1600]) - XCTAssertEqual(nonSilentChunks4.map(\.endIndex), [3200]) + XCTAssertEqual(nonSilentChunks4.map(\SampleRange.startIndex), [1600]) + XCTAssertEqual(nonSilentChunks4.map(\SampleRange.endIndex), [3200]) let nonSilentChunksWithUnevenFrameLength1 = vadForSilence.calculateActiveChunks(in: Array(repeating: 1, count: 1601)) - XCTAssertEqual(nonSilentChunksWithUnevenFrameLength1.map(\.startIndex), [0]) - XCTAssertEqual(nonSilentChunksWithUnevenFrameLength1.map(\.endIndex), [1601]) + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength1.map(\SampleRange.startIndex), [0]) + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength1.map(\SampleRange.endIndex), [1601]) let nonSilentChunksWithUnevenFrameLength2 = vadForSilence.calculateActiveChunks(in: Array(repeating: 1, count: 1599)) - XCTAssertEqual(nonSilentChunksWithUnevenFrameLength2.map(\.startIndex), [0]) - XCTAssertEqual(nonSilentChunksWithUnevenFrameLength2.map(\.endIndex), [1599]) + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength2.map(\SampleRange.startIndex), [0]) + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength2.map(\SampleRange.endIndex), [1599]) let nonSilentChunksWithUnevenFrameLength3 = vadForSilence.calculateActiveChunks(in: Array(repeating: 1, count: 1599) + Array(repeating: 0, count: 1600)) - XCTAssertEqual(nonSilentChunksWithUnevenFrameLength3.map(\.startIndex), [0]) - XCTAssertEqual(nonSilentChunksWithUnevenFrameLength3.map(\.endIndex), [1600]) // frame length + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength3.map(\SampleRange.startIndex), [0]) + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength3.map(\SampleRange.endIndex), [1600]) // frame length // Even with a smaller frame lenth, sometimes we need an overlap to detect them when they are very close to the boundary let vadWithOverlap = EnergyVAD(frameLengthSamples: 320, frameOverlapSamples: 80) let nonSilentChunksWithOverlap = vadWithOverlap.calculateActiveChunks(in: 
Array(repeating: 0, count: 1600) + Array(repeating: 1, count: 1600)) - XCTAssertEqual(nonSilentChunksWithOverlap.map(\.startIndex), [1280]) - XCTAssertEqual(nonSilentChunksWithOverlap.map(\.endIndex), [3200]) + XCTAssertEqual(nonSilentChunksWithOverlap.map(\SampleRange.startIndex), [1280]) + XCTAssertEqual(nonSilentChunksWithOverlap.map(\SampleRange.endIndex), [3200]) // When specifically looking for speech instead of silence, a larger window is preferred let vadWithLargeWindow = EnergyVAD(frameLength: 0.2, frameOverlap: 0.1) From 368333f624736f2ac8cefda3a71483e1c1f4bedb Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sun, 6 Oct 2024 18:59:04 +0900 Subject: [PATCH 20/27] Make AudioProcessing conform to Actor protocol --- Sources/WhisperKit/Core/Audio/AudioProcessor.swift | 2 +- .../WhisperKit/Core/Audio/AudioStreamTranscriber.swift | 10 +++++++--- Sources/WhisperKit/Core/WhisperKit.swift | 8 ++++++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index 6732ac6..731cc0a 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -18,7 +18,7 @@ public struct AudioDevice: Identifiable, Hashable, Sendable { public let name: String } -public protocol AudioProcessing { +public protocol AudioProcessing: Actor { /// Loads audio data from a specified file path. /// - Parameters: /// - audioFilePath: The file path of the audio file. diff --git a/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift b/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift index 7481f98..2c06957 100644 --- a/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift +++ b/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift @@ -79,7 +79,7 @@ public actor AudioStreamTranscriber { return } state.isRecording = true - try audioProcessor.startRecordingLive { [weak self] _ in + try await audioProcessor.startRecordingLive { [weak self] _ in Task { [weak self] in await self?.onAudioBufferCallback() } @@ -90,7 +90,9 @@ public actor AudioStreamTranscriber { public func stopStreamTranscription() { state.isRecording = false - audioProcessor.stopRecording() + Task { + await audioProcessor.stopRecording() + } Logging.info("Realtime transcription has ended") } @@ -106,7 +108,9 @@ public actor AudioStreamTranscriber { } private func onAudioBufferCallback() { - state.bufferEnergy = audioProcessor.getRelativeEnergy() + Task { + state.bufferEnergy = await audioProcessor.getRelativeEnergy() + } } private func onProgressCallback(_ progress: TranscriptionProgress) { diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index a32c024..15e4f84 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -393,12 +393,16 @@ open class WhisperKit { } open func clearState() { - audioProcessor.stopRecording() + Task { + await audioProcessor.stopRecording() + } currentTimings = TranscriptionTimings() } deinit { - audioProcessor.stopRecording() + Task { + await audioProcessor.stopRecording() + } } /// Pass in your own logging callback here From c41fb22fec4a0a1fdca39704ca27417a895b14a2 Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sun, 6 Oct 2024 19:25:04 +0900 Subject: [PATCH 21/27] Refactor SegmentSeeker to improve readability and performance --- .../WhisperKit/Core/Text/SegmentSeeker.swift | 500 +++++++----------- 1 file changed, 199 insertions(+), 301 deletions(-) 
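Background for the refactor below: SegmentSeeker's word-timestamp step aligns text tokens to audio frames by running dynamic time warping over an alignment-weight matrix, and this patch rewrites that step around flat, row-major cost and direction buffers. The following standalone Swift sketch only illustrates that flattened-DTW idea; `dtwPath`, its parameter names, and the plain [Double] input are illustrative assumptions, not part of WhisperKit's API or of this patch.

import Foundation

// Minimal sketch of monotonic DTW over a rows x columns alignment-weight
// matrix stored as a flat, row-major [Double] buffer (hypothetical helper,
// not WhisperKit API).
func dtwPath(weights: [Double], rows: Int, columns: Int) -> (textIndices: [Int], timeIndices: [Int]) {
    guard rows > 0, columns > 0, weights.count == rows * columns else { return ([], []) }

    let width = columns + 1
    var cost = [Double](repeating: .infinity, count: (rows + 1) * width)
    var trace = [Int](repeating: -1, count: (rows + 1) * width)

    cost[0] = 0
    for j in 1...columns { trace[j] = 2 }        // top edge: only leftward moves
    for i in 1...rows { trace[i * width] = 1 }   // left edge: only upward moves

    for i in 1...rows {
        for j in 1...columns {
            // Higher alignment weight means lower cost, hence the negation
            let value = -weights[(i - 1) * columns + (j - 1)]
            let diagonal = cost[(i - 1) * width + (j - 1)]
            let up = cost[(i - 1) * width + j]
            let left = cost[i * width + (j - 1)]
            let best = min(diagonal, up, left)
            cost[i * width + j] = best + value
            trace[i * width + j] = (best == diagonal) ? 0 : ((best == up) ? 1 : 2)
        }
    }

    // Walk the trace buffer back from the bottom-right corner to the origin
    var i = rows
    var j = columns
    var textIndices = [Int]()
    var timeIndices = [Int]()
    while i > 0 || j > 0 {
        textIndices.append(i - 1)
        timeIndices.append(j - 1)
        switch trace[i * width + j] {
        case 0: i -= 1; j -= 1
        case 1: i -= 1
        default: j -= 1
        }
    }
    return (Array(textIndices.reversed()), Array(timeIndices.reversed()))
}

The diff below takes the same approach for its `costMatrix` and `directionMatrix`, indexing with a (numberOfColumns + 1) stride instead of nested arrays so the inner loop avoids repeated nested-array lookups.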
diff --git a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift index 33a45dc..9c7a433 100644 --- a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift +++ b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift @@ -19,7 +19,7 @@ public protocol SegmentSeeking { specialToken: Int, tokenizer: WhisperTokenizer ) -> (Int, [TranscriptionSegment]?) - + func addWordTimestamps( segments: [TranscriptionSegment], alignmentWeights: MLMultiArray, @@ -35,12 +35,11 @@ public protocol SegmentSeeking { } @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -open class SegmentSeeker: SegmentSeeking { +public struct SegmentSeeker: SegmentSeeking { public init() {} - + // MARK: - Seek & Segments - - // TODO: simplify this interface + public func findSeekPointAndSegments( decodingResult: DecodingResult, options: DecodingOptions, @@ -52,77 +51,66 @@ open class SegmentSeeker: SegmentSeeking { specialToken: Int, tokenizer: WhisperTokenizer ) -> (Int, [TranscriptionSegment]?) { - // check if we need to skip this segment entirely - // if so, reset currentSegments, continue to next window, otherwise: + // Check if we need to skip this segment entirely var seek = currentSeek let timeOffset = Float(seek) / Float(sampleRate) let secondsPerTimeToken = WhisperKit.secondsPerTimeToken + if let threshold = options.noSpeechThreshold { - // check no speech threshold for segment var shouldSkip = decodingResult.noSpeechProb > threshold - - // check avg logprob threshold for segment + if let logProbThreshold = options.logProbThreshold, - decodingResult.avgLogProb > logProbThreshold - { + decodingResult.avgLogProb > logProbThreshold { // Confidence in overall segment overrides no speech threshold shouldSkip = false } - + if shouldSkip { - // skip one full segment, this one is silent + // Skip one full segment, this one is silent seek += segmentSize return (seek, nil) } } - + var currentSegments: [TranscriptionSegment] = [] - - // loop through all consecutive timestamps and turn them into `TranscriptionSegments` + + // Process tokens to identify timestamps and create segments let currentTokens = decodingResult.tokens let currentLogProbs = decodingResult.tokenLogProbs let isTimestampToken = currentTokens.map { $0 >= timeToken } - - // check if single or double timestamp ending - let lastThreeTokens = isTimestampToken.suffix(3) - let singleTimestampEnding = lastThreeTokens == [false, true, false] - let noTimestampEnding = lastThreeTokens == [false, false, false] - - // find all end indexes of time token pairs + + // Find all end indexes of time token pairs var sliceIndexes = [Int]() - var previousTokenIsTimestamp = false - for (currentTokenIsTimestampIndex, currentTokenIsTimestamp) in isTimestampToken.enumerated() { + for (currentIndex, currentTokenIsTimestamp) in isTimestampToken.enumerated() { if previousTokenIsTimestamp && currentTokenIsTimestamp { - sliceIndexes.append(currentTokenIsTimestampIndex) + sliceIndexes.append(currentIndex) } previousTokenIsTimestamp = currentTokenIsTimestamp } - - // Window contains multiple consecutive timestamps, split into sub-segments + + // Optimize handling of timestamp endings if !sliceIndexes.isEmpty { - // If the last timestamp is not consecutive, we need to add it as the final slice manually - if singleTimestampEnding { - let singleTimestampEndingIndex = isTimestampToken.lastIndex(where: { $0 })! 
- sliceIndexes.append(singleTimestampEndingIndex + 1) - } else if noTimestampEnding { - sliceIndexes.append(currentTokens.count) - } - + let lastTimestampIndex = isTimestampToken.lastIndex(of: true) ?? currentTokens.count - 1 + sliceIndexes.append(lastTimestampIndex + 1) + var lastSliceStart = 0 for currentSliceEnd in sliceIndexes { let slicedTokens = Array(currentTokens[lastSliceStart..= timeToken } - - let startTimestampSeconds = Float(timestampTokens.first! - timeToken) * secondsPerTimeToken - let endTimestampSeconds = Float(timestampTokens.last! - timeToken) * secondsPerTimeToken - + + guard let firstTimestamp = timestampTokens.first, + let lastTimestamp = timestampTokens.last else { continue } + + let startTimestampSeconds = Float(firstTimestamp - timeToken) * secondsPerTimeToken + let endTimestampSeconds = Float(lastTimestamp - timeToken) * secondsPerTimeToken + // Decode segment text let wordTokens = slicedTokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } let slicedTextTokens = options.skipSpecialTokens ? wordTokens : slicedTokens let sliceText = tokenizer.decode(tokens: slicedTextTokens) - + let newSegment = TranscriptionSegment( id: allSegmentsCount + currentSegments.count, seek: seek, @@ -139,10 +127,9 @@ open class SegmentSeeker: SegmentSeeking { currentSegments.append(newSegment) lastSliceStart = currentSliceEnd } - + // Seek to the last timestamp in the segment - if !noTimestampEnding { - let lastTimestampToken = currentTokens[lastSliceStart - (singleTimestampEnding ? 1 : 0)] - timeToken + if let lastTimestampToken = currentTokens[lastSliceStart - 1] - timeToken as Int? { let lastTimestampSeconds = Float(lastTimestampToken) * secondsPerTimeToken let lastTimestampSamples = Int(lastTimestampSeconds * Float(sampleRate)) seek += lastTimestampSamples @@ -150,23 +137,19 @@ open class SegmentSeeker: SegmentSeeking { seek += segmentSize } } else { - // Model is not giving any consecutive timestamps, so lump all the current tokens together - var durationSeconds = Float(segmentSize) / Float(sampleRate) - - // Find any timestamp that is not 0.00 - let timestampTokens = currentTokens.filter { $0 > timeToken } - - // If there are no consecutive timestamps at all, check if there is at least one timestamp at the end - // If there is at least one, use that to record a more accurate end time - if !timestampTokens.isEmpty, let lastTimestamp = timestampTokens.last { + // Handle case with no consecutive timestamps + let durationSeconds: Float + if let lastTimestamp = currentTokens.last(where: { $0 > timeToken }) { durationSeconds = Float(lastTimestamp - timeToken) * secondsPerTimeToken + } else { + durationSeconds = Float(segmentSize) / Float(sampleRate) } - + // Decode segment text let wordTokens = decodingResult.tokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } let segmentTextTokens = options.skipSpecialTokens ? 
wordTokens : decodingResult.tokens let segmentText = tokenizer.decode(tokens: segmentTextTokens) - + let newSegment = TranscriptionSegment( id: allSegmentsCount + currentSegments.count, seek: seek, @@ -181,154 +164,131 @@ open class SegmentSeeker: SegmentSeeking { noSpeechProb: decodingResult.noSpeechProb ) currentSegments.append(newSegment) - - // Model has told us there is no more speech in this segment, move on to next + seek += segmentSize - // TODO: use this logic instead once we handle no speech - // seek += Int(durationSeconds * Float(sampleRate)) } - + return (seek, currentSegments) } - + // MARK: - Word Timestamps - - /// Matrix is a 2D array of alignment weights of shape (n, m) where n is the number of rows representing text tokens - /// and m is the number of columns representing audio tokens + func dynamicTimeWarping(withMatrix matrix: MLMultiArray) throws -> (textIndices: [Int], timeIndices: [Int]) { guard matrix.shape.count == 2, let numberOfRows = matrix.shape[0] as? Int, - let numberOfColumns = matrix.shape[1] as? Int - else { + let numberOfColumns = matrix.shape[1] as? Int else { throw WhisperError.segmentingFailed("Invalid alignment matrix shape") } - - // Initialize cost matrix and trace matrix - var costMatrix = Array(repeating: Array(repeating: Double.infinity, count: numberOfColumns + 1), count: numberOfRows + 1) - var traceMatrix = Array(repeating: Array(repeating: -1, count: numberOfColumns + 1), count: numberOfRows + 1) - - costMatrix[0][0] = 0 + + // Flatten the MLMultiArray to a 1D array for Accelerate functions + let matrixData = Array(UnsafeBufferPointer(start: matrix.dataPointer.assumingMemoryBound(to: Double.self), count: numberOfRows * numberOfColumns)) + + // Prepare cost matrix and direction matrix + var costMatrix = [Double](repeating: .infinity, count: (numberOfRows + 1) * (numberOfColumns + 1)) + var directionMatrix = [Int](repeating: -1, count: (numberOfRows + 1) * (numberOfColumns + 1)) + + costMatrix[0] = 0 for i in 1...numberOfColumns { - traceMatrix[0][i] = 2 + directionMatrix[i] = 2 } for i in 1...numberOfRows { - traceMatrix[i][0] = 1 + directionMatrix[i * (numberOfColumns + 1)] = 1 } - + + // Perform DTW using optimized loops for row in 1...numberOfRows { for column in 1...numberOfColumns { - let matrixValue = -matrix[(row - 1) * numberOfColumns + (column - 1)].doubleValue - let costDiagonal = costMatrix[row - 1][column - 1] - let costUp = costMatrix[row - 1][column] - let costLeft = costMatrix[row][column - 1] - + let matrixValue = -matrixData[(row - 1) * numberOfColumns + (column - 1)] + let index = row * (numberOfColumns + 1) + column + let costDiagonal = costMatrix[(row - 1) * (numberOfColumns + 1) + (column - 1)] + let costUp = costMatrix[(row - 1) * (numberOfColumns + 1) + column] + let costLeft = costMatrix[row * (numberOfColumns + 1) + (column - 1)] + let (computedCost, traceValue) = minCostAndTrace( costDiagonal: costDiagonal, costUp: costUp, costLeft: costLeft, matrixValue: matrixValue ) - - costMatrix[row][column] = computedCost - traceMatrix[row][column] = traceValue + + costMatrix[index] = computedCost + directionMatrix[index] = traceValue } } - - let dtw = backtrace(fromTraceMatrix: traceMatrix) - + + let dtw = backtrace(fromDirectionMatrix: directionMatrix, numberOfRows: numberOfRows, numberOfColumns: numberOfColumns) + return dtw } - + func minCostAndTrace(costDiagonal: Double, costUp: Double, costLeft: Double, matrixValue: Double) -> (Double, Int) { let c0 = costDiagonal + matrixValue let c1 = costUp + matrixValue let c2 = 
costLeft + matrixValue - - if c0 < c1 && c0 < c2 { + + if c0 <= c1 && c0 <= c2 { return (c0, 0) - } else if c1 < c0 && c1 < c2 { + } else if c1 <= c0 && c1 <= c2 { return (c1, 1) } else { return (c2, 2) } } - - func backtrace(fromTraceMatrix traceMatrix: [[Int]]) -> (textIndices: [Int], timeIndices: [Int]) { - var i = traceMatrix.count - 1 - var j = traceMatrix[0].count - 1 - + + func backtrace(fromDirectionMatrix directionMatrix: [Int], numberOfRows: Int, numberOfColumns: Int) -> (textIndices: [Int], timeIndices: [Int]) { + var i = numberOfRows + var j = numberOfColumns + var textIndices = [Int]() var timeIndices = [Int]() - + + let width = numberOfColumns + 1 + while i > 0 || j > 0 { textIndices.append(i - 1) timeIndices.append(j - 1) - - switch traceMatrix[i][j] { - case 0: - i -= 1 - j -= 1 - case 1: - i -= 1 - case 2: - j -= 1 - default: - break + + let dir = directionMatrix[i * width + j] + switch dir { + case 0: + i -= 1 + j -= 1 + case 1: + i -= 1 + case 2: + j -= 1 + default: + break } } - + return (textIndices.reversed(), timeIndices.reversed()) } - + func mergePunctuations(alignment: [WordTiming], prepended: String, appended: String) -> [WordTiming] { - var prependedAlignment = [WordTiming]() - var appendedAlignment = [WordTiming]() - - // Include the first word if it's not a prepended punctuation - if !alignment.isEmpty && !prepended.contains(alignment[0].word.trimmingCharacters(in: .whitespaces)) { - prependedAlignment.append(alignment[0]) - } - - // Merge prepended punctuations - for i in 1.. [WordTiming] { - // TODO: Use accelerate framework for these two, they take roughly the same time let (textIndices, timeIndices) = try dynamicTimeWarping(withMatrix: alignmentWeights) let (words, wordTokens) = tokenizer.splitToWordTokens(tokenIds: wordTokenIds) - + if wordTokens.count <= 1 { return [] } - + // Calculate start times and end times - var startTimes: [Float] = [0.0] - var endTimes = [Float]() - var currentTokenIndex = textIndices.first ?? 0 - for index in 0.. [TranscriptionSegment]? 
{ - // Initialize arrays to hold the extracted and filtered data + // Prepare data for alignment var wordTokenIds = [Int]() var filteredLogProbs = [Float]() var filteredIndices = [Int]() - var lastSpeechTimestamp = lastSpeechTimestamp - - // Iterate through each segment + var indexOffset = 0 for segment in segments { for (index, token) in segment.tokens.enumerated() { wordTokenIds.append(token) - filteredIndices.append(index + indexOffset) // Add the index to filteredIndices - - // Assuming tokenLogProbs is structured as [[Int: Float]] + filteredIndices.append(index + indexOffset) if let logProb = segment.tokenLogProbs[index][token] { filteredLogProbs.append(logProb) } } - - // Update the indexOffset as we start a new segment indexOffset += segment.tokens.count } - - // Filter alignmentWeights using filteredIndices + + // Efficiently filter alignmentWeights using filteredIndices let shape = alignmentWeights.shape guard let columnCount = shape.last?.intValue else { throw WhisperError.segmentingFailed("Invalid shape in alignmentWeights") } - - let filteredAlignmentWeights = initMLMultiArray(shape: [filteredIndices.count, columnCount] as [NSNumber], dataType: alignmentWeights.dataType, initialValue: FloatType(0)) - - alignmentWeights.withUnsafeMutableBytes { weightsPointer, weightsStride in - filteredAlignmentWeights.withUnsafeMutableBytes { filteredWeightsPointer, filteredWeightsStride in - for (newIndex, originalIndex) in filteredIndices.enumerated() { - let sourcePointer = weightsPointer.baseAddress!.advanced(by: Int(originalIndex * columnCount * MemoryLayout.stride)) - let destinationPointer = filteredWeightsPointer.baseAddress!.advanced(by: Int(newIndex * columnCount * MemoryLayout.stride)) - - memcpy(destinationPointer, sourcePointer, columnCount * MemoryLayout.stride) - } - } - } - - Logging.debug("Alignment weights shape: \(filteredAlignmentWeights.shape)") - - var alignment = try findAlignment( + + let filteredAlignmentWeights = try filterAlignmentWeights( + alignmentWeights: alignmentWeights, + filteredIndices: filteredIndices, + rowCount: filteredIndices.count, + columnCount: columnCount + ) + + let alignment = try findAlignment( wordTokenIds: wordTokenIds, alignmentWeights: filteredAlignmentWeights, tokenLogProbs: filteredLogProbs, tokenizer: tokenizer, timings: timings ) - - // TODO: This section is considered a "hack" in the source repo - // Reference: https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/timing.py#L305 - var wordDurations = alignment.map { $0.end - $0.start } - wordDurations = wordDurations.filter { $0 > 0 } - - let medianDuration: Float = wordDurations.isEmpty ? 0.0 : wordDurations.sorted(by: <)[wordDurations.count / 2] - let constrainedMedianDuration = min(0.7, medianDuration) - let maxDuration = constrainedMedianDuration * 2 - - // Truncate long words at sentence boundaries - let sentenceEndMarks = [".", "。", "!", "!", "?", "?"] - if !wordDurations.isEmpty { - for i in 1.. 
maxDuration { - if sentenceEndMarks.contains(alignment[i].word) { - alignment[i].end = alignment[i].start + maxDuration - } else if i > 0, sentenceEndMarks.contains(alignment[i - 1].word) { - alignment[i].start = alignment[i].end - maxDuration - } - } - } - } - + // Process alignment for punctuations let mergedAlignment = mergePunctuations(alignment: alignment, prepended: prependPunctuations, appended: appendPunctuations) - + var wordIndex = 0 let timeOffset = Float(seek) / Float(WhisperKit.sampleRate) var updatedSegments = [TranscriptionSegment]() - + for segment in segments { var savedTokens = 0 let textTokens = segment.tokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } var wordsInSegment = [WordTiming]() - - for timing in mergedAlignment[wordIndex...] where savedTokens < textTokens.count { + + while wordIndex < mergedAlignment.count && savedTokens < textTokens.count { + let timing = mergedAlignment[wordIndex] wordIndex += 1 - - // Remove special tokens and retokenize if needed + let timingTokens = timing.tokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } if timingTokens.isEmpty { continue } - - let start = (timeOffset + timing.start).rounded(2) - let end = (timeOffset + timing.end).rounded(2) - let probability = timing.probability.rounded(2) + + let start = (timeOffset + timing.start).rounded(toPlaces: 2) + let end = (timeOffset + timing.end).rounded(toPlaces: 2) + let probability = timing.probability.rounded(toPlaces: 2) let wordTiming = WordTiming(word: timing.word, tokens: timingTokens, start: start, end: end, probability: probability) wordsInSegment.append(wordTiming) - + savedTokens += timingTokens.count } - + // Create an updated segment with the word timings var updatedSegment = segment - - // TODO: This section is considered a "hack" in the source repo - // Reference: https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/timing.py#L342 - // Truncate long words at segment boundaries - if let firstWord = wordsInSegment.first, let lastWord = wordsInSegment.last { - // Logic for the first word - if firstWord.end - lastSpeechTimestamp > constrainedMedianDuration * 4 && - (firstWord.end - firstWord.start > maxDuration || - (wordsInSegment.count > 1 && wordsInSegment[1].end - firstWord.start > maxDuration * 2)) - { - if wordsInSegment.count > 1 && wordsInSegment[1].end - wordsInSegment[1].start > maxDuration { - let boundary = max(wordsInSegment[1].end / 2, wordsInSegment[1].end - maxDuration) - wordsInSegment[0].end = boundary - wordsInSegment[1].start = boundary - } - wordsInSegment[0].start = max(lastSpeechTimestamp, firstWord.end - maxDuration) - } - - // Prefer segment-level start timestamp if the first word is too long. - if segment.start < firstWord.end && segment.start - 0.5 > firstWord.start { - wordsInSegment[0].start = max(0, min(firstWord.end - constrainedMedianDuration, segment.start)) - } else { - updatedSegment.start = firstWord.start - } - - // Prefer segment-level end timestamp if the last word is too long. 
- if updatedSegment.end > lastWord.start && segment.end + 0.5 < lastWord.end { - wordsInSegment[wordsInSegment.count - 1].end = max(lastWord.start + constrainedMedianDuration, segment.end) - } else { - updatedSegment.end = lastWord.end - } - - lastSpeechTimestamp = updatedSegment.end - } - updatedSegment.words = wordsInSegment updatedSegments.append(updatedSegment) } - + return updatedSegments } + + private func filterAlignmentWeights( + alignmentWeights: MLMultiArray, + filteredIndices: [Int], + rowCount: Int, + columnCount: Int + ) throws -> MLMultiArray { + let filteredAlignmentWeights = try MLMultiArray(shape: [rowCount, columnCount] as [NSNumber], dataType: .double) + + let sourcePointer = alignmentWeights.dataPointer.assumingMemoryBound(to: Double.self) + let destinationPointer = filteredAlignmentWeights.dataPointer.assumingMemoryBound(to: Double.self) + + for (newIndex, originalIndex) in filteredIndices.enumerated() { + let sourceRow = sourcePointer.advanced(by: originalIndex * columnCount) + let destinationRow = destinationPointer.advanced(by: newIndex * columnCount) + destinationRow.update(from: sourceRow, count: columnCount) + } + + return filteredAlignmentWeights + } +} + +extension Float { + func rounded(toPlaces places: Int) -> Float { + let divisor = pow(10, Float(places)) + return (self * divisor).rounded() / divisor + } } From 621b1f3d1fd7430531d86e26d88fea6a8d6a9b4f Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sun, 6 Oct 2024 19:57:51 +0900 Subject: [PATCH 22/27] Refactor SegmentSeeker to improve clarity and efficiency --- .../WhisperKit/Core/Text/SegmentSeeker.swift | 88 ++++++++++++------- 1 file changed, 54 insertions(+), 34 deletions(-) diff --git a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift index 9c7a433..3858458 100644 --- a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift +++ b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift @@ -51,7 +51,6 @@ public struct SegmentSeeker: SegmentSeeking { specialToken: Int, tokenizer: WhisperTokenizer ) -> (Int, [TranscriptionSegment]?) { - // Check if we need to skip this segment entirely var seek = currentSeek let timeOffset = Float(seek) / Float(sampleRate) let secondsPerTimeToken = WhisperKit.secondsPerTimeToken @@ -61,12 +60,10 @@ public struct SegmentSeeker: SegmentSeeking { if let logProbThreshold = options.logProbThreshold, decodingResult.avgLogProb > logProbThreshold { - // Confidence in overall segment overrides no speech threshold shouldSkip = false } if shouldSkip { - // Skip one full segment, this one is silent seek += segmentSize return (seek, nil) } @@ -74,12 +71,11 @@ public struct SegmentSeeker: SegmentSeeking { var currentSegments: [TranscriptionSegment] = [] - // Process tokens to identify timestamps and create segments + // トークンを処理してタイムスタンプを特定し、セグメントを作成 let currentTokens = decodingResult.tokens let currentLogProbs = decodingResult.tokenLogProbs let isTimestampToken = currentTokens.map { $0 >= timeToken } - // Find all end indexes of time token pairs var sliceIndexes = [Int]() var previousTokenIsTimestamp = false for (currentIndex, currentTokenIsTimestamp) in isTimestampToken.enumerated() { @@ -89,7 +85,6 @@ public struct SegmentSeeker: SegmentSeeking { previousTokenIsTimestamp = currentTokenIsTimestamp } - // Optimize handling of timestamp endings if !sliceIndexes.isEmpty { let lastTimestampIndex = isTimestampToken.lastIndex(of: true) ?? 
currentTokens.count - 1 sliceIndexes.append(lastTimestampIndex + 1) @@ -106,7 +101,6 @@ public struct SegmentSeeker: SegmentSeeking { let startTimestampSeconds = Float(firstTimestamp - timeToken) * secondsPerTimeToken let endTimestampSeconds = Float(lastTimestamp - timeToken) * secondsPerTimeToken - // Decode segment text let wordTokens = slicedTokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } let slicedTextTokens = options.skipSpecialTokens ? wordTokens : slicedTokens let sliceText = tokenizer.decode(tokens: slicedTextTokens) @@ -128,16 +122,14 @@ public struct SegmentSeeker: SegmentSeeking { lastSliceStart = currentSliceEnd } - // Seek to the last timestamp in the segment - if let lastTimestampToken = currentTokens[lastSliceStart - 1] - timeToken as Int? { - let lastTimestampSeconds = Float(lastTimestampToken) * secondsPerTimeToken + if let lastTimestampToken = currentTokens[lastSliceStart - 1] as Int? { + let lastTimestampSeconds = Float(lastTimestampToken - timeToken) * secondsPerTimeToken let lastTimestampSamples = Int(lastTimestampSeconds * Float(sampleRate)) seek += lastTimestampSamples } else { seek += segmentSize } } else { - // Handle case with no consecutive timestamps let durationSeconds: Float if let lastTimestamp = currentTokens.last(where: { $0 > timeToken }) { durationSeconds = Float(lastTimestamp - timeToken) * secondsPerTimeToken @@ -145,7 +137,6 @@ public struct SegmentSeeker: SegmentSeeking { durationSeconds = Float(segmentSize) / Float(sampleRate) } - // Decode segment text let wordTokens = decodingResult.tokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } let segmentTextTokens = options.skipSpecialTokens ? wordTokens : decodingResult.tokens let segmentText = tokenizer.decode(tokens: segmentTextTokens) @@ -180,12 +171,27 @@ public struct SegmentSeeker: SegmentSeeking { throw WhisperError.segmentingFailed("Invalid alignment matrix shape") } - // Flatten the MLMultiArray to a 1D array for Accelerate functions - let matrixData = Array(UnsafeBufferPointer(start: matrix.dataPointer.assumingMemoryBound(to: Double.self), count: numberOfRows * numberOfColumns)) + let elementCount = numberOfRows * numberOfColumns + var matrixData = [Double](repeating: 0.0, count: elementCount) - // Prepare cost matrix and direction matrix - var costMatrix = [Double](repeating: .infinity, count: (numberOfRows + 1) * (numberOfColumns + 1)) - var directionMatrix = [Int](repeating: -1, count: (numberOfRows + 1) * (numberOfColumns + 1)) + switch matrix.dataType { + case .double: + let pointer = matrix.dataPointer.assumingMemoryBound(to: Double.self) + for i in 0.. [TranscriptionSegment]? 
{ - // Prepare data for alignment + // アライメントのためのデータを準備 var wordTokenIds = [Int]() var filteredLogProbs = [Float]() var filteredIndices = [Int]() + var lastSpeechTimestamp = lastSpeechTimestamp var indexOffset = 0 for segment in segments { @@ -371,7 +377,7 @@ public struct SegmentSeeker: SegmentSeeking { indexOffset += segment.tokens.count } - // Efficiently filter alignmentWeights using filteredIndices + // alignmentWeights を効率的にフィルタリング let shape = alignmentWeights.shape guard let columnCount = shape.last?.intValue else { throw WhisperError.segmentingFailed("Invalid shape in alignmentWeights") @@ -384,7 +390,7 @@ public struct SegmentSeeker: SegmentSeeking { columnCount: columnCount ) - let alignment = try findAlignment( + var alignment = try findAlignment( wordTokenIds: wordTokenIds, alignmentWeights: filteredAlignmentWeights, tokenLogProbs: filteredLogProbs, @@ -392,7 +398,7 @@ public struct SegmentSeeker: SegmentSeeking { timings: timings ) - // Process alignment for punctuations + // 句読点の処理 let mergedAlignment = mergePunctuations(alignment: alignment, prepended: prependPunctuations, appended: appendPunctuations) var wordIndex = 0 @@ -426,7 +432,7 @@ public struct SegmentSeeker: SegmentSeeking { savedTokens += timingTokens.count } - // Create an updated segment with the word timings + // word timings を持つ更新されたセグメントを作成 var updatedSegment = segment updatedSegment.words = wordsInSegment updatedSegments.append(updatedSegment) @@ -441,15 +447,29 @@ public struct SegmentSeeker: SegmentSeeking { rowCount: Int, columnCount: Int ) throws -> MLMultiArray { - let filteredAlignmentWeights = try MLMultiArray(shape: [rowCount, columnCount] as [NSNumber], dataType: .double) + let filteredAlignmentWeights = try MLMultiArray(shape: [rowCount, columnCount] as [NSNumber], dataType: alignmentWeights.dataType) - let sourcePointer = alignmentWeights.dataPointer.assumingMemoryBound(to: Double.self) - let destinationPointer = filteredAlignmentWeights.dataPointer.assumingMemoryBound(to: Double.self) - - for (newIndex, originalIndex) in filteredIndices.enumerated() { - let sourceRow = sourcePointer.advanced(by: originalIndex * columnCount) - let destinationRow = destinationPointer.advanced(by: newIndex * columnCount) - destinationRow.update(from: sourceRow, count: columnCount) + switch alignmentWeights.dataType { + case .double: + let sourcePointer = alignmentWeights.dataPointer.assumingMemoryBound(to: Double.self) + let destinationPointer = filteredAlignmentWeights.dataPointer.assumingMemoryBound(to: Double.self) + + for (newIndex, originalIndex) in filteredIndices.enumerated() { + let sourceRow = sourcePointer.advanced(by: originalIndex * columnCount) + let destinationRow = destinationPointer.advanced(by: newIndex * columnCount) + destinationRow.update(from: sourceRow, count: columnCount) + } + case .float32: + let sourcePointer = alignmentWeights.dataPointer.assumingMemoryBound(to: Float.self) + let destinationPointer = filteredAlignmentWeights.dataPointer.assumingMemoryBound(to: Float.self) + + for (newIndex, originalIndex) in filteredIndices.enumerated() { + let sourceRow = sourcePointer.advanced(by: originalIndex * columnCount) + let destinationRow = destinationPointer.advanced(by: newIndex * columnCount) + destinationRow.update(from: sourceRow, count: columnCount) + } + default: + throw WhisperError.segmentingFailed("Unsupported MLMultiArray data type: \(alignmentWeights.dataType)") } return filteredAlignmentWeights From f646268f64ff2e99c7f88326d1e004871c696e43 Mon Sep 17 00:00:00 2001 From: 1amageek 
Date: Sun, 6 Oct 2024 20:07:43 +0900 Subject: [PATCH 23/27] Refactor SegmentSeeker to simplify alignment handling --- .../WhisperKit/Core/Text/SegmentSeeker.swift | 33 +++++-------------- 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift index 3858458..25772d7 100644 --- a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift +++ b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift @@ -363,9 +363,8 @@ public struct SegmentSeeker: SegmentSeeking { var wordTokenIds = [Int]() var filteredLogProbs = [Float]() var filteredIndices = [Int]() - var lastSpeechTimestamp = lastSpeechTimestamp - var indexOffset = 0 + for segment in segments { for (index, token) in segment.tokens.enumerated() { wordTokenIds.append(token) @@ -390,7 +389,7 @@ public struct SegmentSeeker: SegmentSeeking { columnCount: columnCount ) - var alignment = try findAlignment( + let alignment = try findAlignment( wordTokenIds: wordTokenIds, alignmentWeights: filteredAlignmentWeights, tokenLogProbs: filteredLogProbs, @@ -449,27 +448,13 @@ public struct SegmentSeeker: SegmentSeeking { ) throws -> MLMultiArray { let filteredAlignmentWeights = try MLMultiArray(shape: [rowCount, columnCount] as [NSNumber], dataType: alignmentWeights.dataType) - switch alignmentWeights.dataType { - case .double: - let sourcePointer = alignmentWeights.dataPointer.assumingMemoryBound(to: Double.self) - let destinationPointer = filteredAlignmentWeights.dataPointer.assumingMemoryBound(to: Double.self) - - for (newIndex, originalIndex) in filteredIndices.enumerated() { - let sourceRow = sourcePointer.advanced(by: originalIndex * columnCount) - let destinationRow = destinationPointer.advanced(by: newIndex * columnCount) - destinationRow.update(from: sourceRow, count: columnCount) - } - case .float32: - let sourcePointer = alignmentWeights.dataPointer.assumingMemoryBound(to: Float.self) - let destinationPointer = filteredAlignmentWeights.dataPointer.assumingMemoryBound(to: Float.self) - - for (newIndex, originalIndex) in filteredIndices.enumerated() { - let sourceRow = sourcePointer.advanced(by: originalIndex * columnCount) - let destinationRow = destinationPointer.advanced(by: newIndex * columnCount) - destinationRow.update(from: sourceRow, count: columnCount) - } - default: - throw WhisperError.segmentingFailed("Unsupported MLMultiArray data type: \(alignmentWeights.dataType)") + let sourcePointer = alignmentWeights.dataPointer.assumingMemoryBound(to: UInt16.self) + let destinationPointer = filteredAlignmentWeights.dataPointer.assumingMemoryBound(to: UInt16.self) + + for (newIndex, originalIndex) in filteredIndices.enumerated() { + let sourceRow = sourcePointer.advanced(by: originalIndex * columnCount) + let destinationRow = destinationPointer.advanced(by: newIndex * columnCount) + destinationRow.update(from: sourceRow, count: columnCount) } return filteredAlignmentWeights From db661669274391c9159ed09dfbdee5a2cab04037 Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sun, 6 Oct 2024 20:17:39 +0900 Subject: [PATCH 24/27] Refactor SegmentSeeker to handle Float16 data type --- .../WhisperKit/Core/Text/SegmentSeeker.swift | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift index 25772d7..a411b5d 100644 --- a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift +++ b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift @@ -174,19 +174,11 @@ 
public struct SegmentSeeker: SegmentSeeking { let elementCount = numberOfRows * numberOfColumns var matrixData = [Double](repeating: 0.0, count: elementCount) - switch matrix.dataType { - case .double: - let pointer = matrix.dataPointer.assumingMemoryBound(to: Double.self) - for i in 0.. Date: Sun, 6 Oct 2024 20:18:19 +0900 Subject: [PATCH 25/27] Remove unnecessary comments in SegmentSeeker.swift --- Sources/WhisperKit/Core/Text/SegmentSeeker.swift | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift index a411b5d..befa19b 100644 --- a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift +++ b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift @@ -69,9 +69,7 @@ public struct SegmentSeeker: SegmentSeeking { } } - var currentSegments: [TranscriptionSegment] = [] - - // トークンを処理してタイムスタンプを特定し、セグメントを作成 + var currentSegments: [TranscriptionSegment] = [] let currentTokens = decodingResult.tokens let currentLogProbs = decodingResult.tokenLogProbs let isTimestampToken = currentTokens.map { $0 >= timeToken } From bb66ae149d83d8d83f74a28c6992194f78dcf8f5 Mon Sep 17 00:00:00 2001 From: 1amageek Date: Sun, 6 Oct 2024 20:26:51 +0900 Subject: [PATCH 26/27] Refactor SegmentSeeker for improved clarity and performance --- .../WhisperKit/Core/Text/SegmentSeeker.swift | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift index befa19b..3aa35ff 100644 --- a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift +++ b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift @@ -51,6 +51,7 @@ public struct SegmentSeeker: SegmentSeeking { specialToken: Int, tokenizer: WhisperTokenizer ) -> (Int, [TranscriptionSegment]?) { + // このセグメントをスキップする必要があるか確認 var seek = currentSeek let timeOffset = Float(seek) / Float(sampleRate) let secondsPerTimeToken = WhisperKit.secondsPerTimeToken @@ -69,7 +70,8 @@ public struct SegmentSeeker: SegmentSeeking { } } - var currentSegments: [TranscriptionSegment] = [] + var currentSegments: [TranscriptionSegment] = [] + let currentTokens = decodingResult.tokens let currentLogProbs = decodingResult.tokenLogProbs let isTimestampToken = currentTokens.map { $0 >= timeToken } @@ -99,6 +101,7 @@ public struct SegmentSeeker: SegmentSeeking { let startTimestampSeconds = Float(firstTimestamp - timeToken) * secondsPerTimeToken let endTimestampSeconds = Float(lastTimestamp - timeToken) * secondsPerTimeToken + // セグメントテキストをデコード let wordTokens = slicedTokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } let slicedTextTokens = options.skipSpecialTokens ? wordTokens : slicedTokens let sliceText = tokenizer.decode(tokens: slicedTextTokens) @@ -120,6 +123,7 @@ public struct SegmentSeeker: SegmentSeeking { lastSliceStart = currentSliceEnd } + // セグメント内の最後のタイムスタンプまでシークを進める if let lastTimestampToken = currentTokens[lastSliceStart - 1] as Int? 
{ let lastTimestampSeconds = Float(lastTimestampToken - timeToken) * secondsPerTimeToken let lastTimestampSamples = Int(lastTimestampSeconds * Float(sampleRate)) @@ -128,6 +132,7 @@ public struct SegmentSeeker: SegmentSeeking { seek += segmentSize } } else { + // 連続したタイムスタンプがない場合の処理 let durationSeconds: Float if let lastTimestamp = currentTokens.last(where: { $0 > timeToken }) { durationSeconds = Float(lastTimestamp - timeToken) * secondsPerTimeToken @@ -135,6 +140,7 @@ public struct SegmentSeeker: SegmentSeeking { durationSeconds = Float(segmentSize) / Float(sampleRate) } + // セグメントテキストをデコード let wordTokens = decodingResult.tokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } let segmentTextTokens = options.skipSpecialTokens ? wordTokens : decodingResult.tokens let segmentText = tokenizer.decode(tokens: segmentTextTokens) @@ -169,16 +175,19 @@ public struct SegmentSeeker: SegmentSeeking { throw WhisperError.segmentingFailed("Invalid alignment matrix shape") } + // MLMultiArray を Float16 型として扱う let elementCount = numberOfRows * numberOfColumns - var matrixData = [Double](repeating: 0.0, count: elementCount) + let pointer = matrix.dataPointer.bindMemory(to: UInt16.self, capacity: elementCount) - let pointer = matrix.dataPointer.assumingMemoryBound(to: UInt16.self) + // Float16 から Double に変換しながらデータを読み込む + var matrixData = [Double](repeating: 0.0, count: elementCount) for i in 0.. MLMultiArray { - let filteredAlignmentWeights = try MLMultiArray(shape: [rowCount, columnCount] as [NSNumber], dataType: alignmentWeights.dataType) - - let sourcePointer = alignmentWeights.dataPointer.assumingMemoryBound(to: UInt16.self) - let destinationPointer = filteredAlignmentWeights.dataPointer.assumingMemoryBound(to: UInt16.self) + let filteredAlignmentWeights = try MLMultiArray(shape: [rowCount, columnCount] as [NSNumber], dataType: .float16) + let sourcePointer = alignmentWeights.dataPointer.bindMemory(to: UInt16.self, capacity: alignmentWeights.count) + let destinationPointer = filteredAlignmentWeights.dataPointer.bindMemory(to: UInt16.self, capacity: filteredAlignmentWeights.count) for (newIndex, originalIndex) in filteredIndices.enumerated() { let sourceRow = sourcePointer.advanced(by: originalIndex * columnCount) From 8bfdd882157c479356e6f5627f7f27b59f473e06 Mon Sep 17 00:00:00 2001 From: 1amageek Date: Thu, 24 Oct 2024 03:40:41 +0900 Subject: [PATCH 27/27] Refactor audio processor deinit and improve memory management --- .../WhisperKit/Core/Audio/AudioProcessor.swift | 16 ++++++++-------- Sources/WhisperKit/Core/WhisperKit.swift | 12 ++++++++++-- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index 731cc0a..cd86e81 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -522,7 +522,9 @@ public actor AudioProcessor: @preconcurrency AudioProcessing { } let frameLength = Int(buffer.frameLength) let startPointer = channelData[0] - var result = [Float](unsafeUninitializedCapacity: frameLength) { bufferPointer, initializedCount in + let result = [Float]( + unsafeUninitializedCapacity: frameLength + ) { bufferPointer, initializedCount in vDSP_mmov( startPointer, bufferPointer.baseAddress!, @@ -654,9 +656,11 @@ public actor AudioProcessor: @preconcurrency AudioProcessing { #endif deinit { - Task { - await self.stopRecording() - } + audioEngine?.stop() + audioEngine = nil + + audioSamples.removeAll() + 
audioEnergy.removeAll() } } @@ -770,10 +774,6 @@ public extension AudioProcessor { inputNode.installTap(onBus: 0, bufferSize: bufferSize, format: nodeFormat) { [weak self] (buffer: AVAudioPCMBuffer, _: AVAudioTime) in var buffer = buffer if !buffer.format.sampleRate.isEqual(to: Double(WhisperKit.sampleRate)) { - guard let converter = AVAudioConverter(from: nodeFormat, to: desiredFormat) else { - Logging.error("Failed to create audio converter") - return - } do { buffer = try Self.resampleBuffer(buffer, with: converter) } catch { diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 15e4f84..9aaae39 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -400,9 +400,17 @@ open class WhisperKit { } deinit { - Task { - await audioProcessor.stopRecording() + modelState = .unloading + if let featureExtractor = featureExtractor as? WhisperMLModel { + featureExtractor.unloadModel() } + if let audioEncoder = audioEncoder as? WhisperMLModel { + audioEncoder.unloadModel() + } + if let textDecoder = textDecoder as? WhisperMLModel { + textDecoder.unloadModel() + } + modelState = .unloaded } /// Pass in your own logging callback here
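
The dynamic time warping changes earlier in this series tighten the cost comparisons (<= instead of <) and switch the backtrace to a flattened direction matrix indexed as directionMatrix[i * width + j] with width = numberOfColumns + 1. For context, here is a minimal sketch of the forward pass that would produce such a matrix. The function name and layout are illustrative rather than WhisperKit's actual implementation; it assumes a row-major [rows x columns] input, one padding row and column in the cost/direction matrices, and direction codes 0 = diagonal, 1 = up, 2 = left to match the backtrace in the diff.

    // Sketch only: forward DTW pass producing a flattened, padded direction matrix
    // shaped for backtrace(fromDirectionMatrix:numberOfRows:numberOfColumns:).
    // Assumes `matrix` is row-major with rows * columns elements.
    func dynamicTimeWarpingSketch(matrix: [Double], rows: Int, columns: Int) -> [Int] {
        precondition(rows > 0 && columns > 0 && matrix.count == rows * columns)
        let width = columns + 1
        var cost = [Double](repeating: .infinity, count: (rows + 1) * width)
        var direction = [Int](repeating: -1, count: (rows + 1) * width)
        cost[0] = 0

        for i in 1...rows {
            for j in 1...columns {
                let value = matrix[(i - 1) * columns + (j - 1)]
                let costDiagonal = cost[(i - 1) * width + (j - 1)]
                let costUp = cost[(i - 1) * width + j]
                let costLeft = cost[i * width + (j - 1)]

                // Mirror the fixed <= comparisons: prefer diagonal, then up, then left.
                if costDiagonal <= costUp && costDiagonal <= costLeft {
                    cost[i * width + j] = costDiagonal + value
                    direction[i * width + j] = 0
                } else if costUp <= costDiagonal && costUp <= costLeft {
                    cost[i * width + j] = costUp + value
                    direction[i * width + j] = 1
                } else {
                    cost[i * width + j] = costLeft + value
                    direction[i * width + j] = 2
                }
            }
        }

        // Force the backtrace to walk along the borders once it reaches them,
        // so its while loop terminates at (0, 0).
        for j in 0...columns { direction[j] = 2 }
        for i in 0...rows { direction[i * width] = 1 }

        return direction
    }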
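The Float.rounded(toPlaces:) helper added alongside the word-timing changes scales by a power of ten, rounds to the nearest integer, and scales back, so rounding to two decimal places behaves like this:

    let end: Float = 12.3456
    let display = end.rounded(toPlaces: 2)   // 12.35 (12.3456 * 100 = 1234.56 -> 1235 -> 12.35)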
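The seek advance in findSeekPointAndSegments converts the last timestamp token into seconds and then into samples. A small worked example, assuming the usual 16 kHz sample rate and 20 ms per time token, with illustrative token ids:

    let sampleRate = 16_000
    let secondsPerTimeToken: Float = 0.02
    let timeToken = 50_000                 // illustrative id of the first timestamp token
    let lastTimestampToken = 51_500        // illustrative id of the segment's final timestamp
    let seconds = Float(lastTimestampToken - timeToken) * secondsPerTimeToken   // 30.0 s
    let seekAdvance = Int(seconds * Float(sampleRate))                          // 480_000 samples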
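PATCH 24 and PATCH 26 bind the alignment matrix's data pointer to UInt16 and widen each element to Double while reading, so Float16 model outputs can be consumed directly. A minimal, self-contained sketch of that kind of conversion (not necessarily the patch's exact loop, and assuming the buffer really holds IEEE 754 half-precision values) could look like this:

    import CoreML

    // Sketch only: widen a Float16-backed MLMultiArray into [Double] by
    // reinterpreting each stored UInt16 bit pattern as a half-precision float.
    // Swift's Float16 requires arm64 and iOS 14 / macOS 11 or later.
    func float16ValuesAsDoubles(_ matrix: MLMultiArray) -> [Double] {
        precondition(matrix.dataType == .float16)
        let elementCount = matrix.count
        let pointer = matrix.dataPointer.bindMemory(to: UInt16.self, capacity: elementCount)
        var values = [Double](repeating: 0.0, count: elementCount)
        for i in 0..<elementCount {
            values[i] = Double(Float16(bitPattern: pointer[i]))
        }
        return values
    }

On newer SDKs the same read could go through MLMultiArray's withUnsafeBufferPointer(ofType:) rather than the deprecated dataPointer used throughout these patches.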
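PATCH 27 replaces deinit blocks that spawned a Task (which would have to capture self while it is being deallocated) with synchronous cleanup. The shape of that pattern, using illustrative stand-in types rather than WhisperKit's own, is roughly:

    import AVFoundation

    // Sketch only: synchronous teardown in deinit, as adopted in PATCH 27.
    // RecorderSketch and its properties are stand-ins, not WhisperKit types.
    final class RecorderSketch {
        private var audioEngine: AVAudioEngine?
        private var samples: [Float] = []

        init(engine: AVAudioEngine? = nil) {
            self.audioEngine = engine
        }

        deinit {
            // No Task/await here: everything completes before deallocation finishes.
            audioEngine?.stop()
            audioEngine = nil
            samples.removeAll()
        }
    }

The same idea drives the WhisperKit deinit hunk above, which unloads the Core ML models synchronously instead of awaiting audioProcessor.stopRecording().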