Skip to content

Commit

Permalink
Add simpler generation interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Oct 19, 2024
1 parent 26d2535 commit 3553c91
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 134 deletions.
4 changes: 2 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ let package = Package(
],
path: "Sources/F5TTS",
resources: [
.copy("mel_filters.npy"),
.copy("test_en_1_ref_short.wav")
.copy("Resources/test_en_1_ref_short.wav"),
.copy("Resources/mel_filters.npy")
]
),
.executableTarget(
Expand Down
31 changes: 20 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@

# F5 TTS for Swift (WIP)
# F5 TTS for Swift

Implementation of [F5-TTS](https://arxiv.org/abs/2410.06885) in Swift, using the [MLX Swift](https://github.com/ml-explore/mlx-swift) framework.

You can listen to a [sample here](https://s3.amazonaws.com/lucasnewman.datasets/f5tts/sample.wav) that was generated in ~11 seconds on an M3 Max MacBook Pro.

See the [Python repository](https://github.com/lucasnewman/f5-tts-mlx) for additional details on the model architecture.

This repository is based on the original Pytorch implementation available [here](https://github.com/SWivid/F5-TTS).


Expand All @@ -19,21 +20,29 @@ A pretrained model is available [on Huggingface](https://hf.co/lucasnewman/f5-tt
## Usage

```swift
import Vocos
import F5TTS

let f5tts = try await F5TTS.fromPretrained(repoId: "lucasnewman/f5-tts-mlx")
let vocos = try await Vocos.fromPretrained(repoId: "lucasnewman/vocos-mel-24khz-mlx") // if decoding to audio output

let inputAudio = MLXArray(...)
let generatedAudio = try await f5tts.generate(text: "The quick brown fox jumped over the lazy dog.")
```

The result is an MLXArray with 24kHz audio samples.

If you want to use your own reference audio sample, make sure it's a mono, 24kHz wav file of around 5-10 seconds:

```swift
let generatedAudio = try await f5tts.generate(
text: "The quick brown fox jumped over the lazy dog.",
referenceAudioURL: ...,
referenceAudioText: "This is the caption for the reference audio."
)
```

You can convert an audio file to the correct format with ffmpeg like this:

let (outputAudio, _) = f5tts.sample(
cond: inputAudio,
text: ["This is the caption for the reference audio and generation text."],
duration: ...,
vocoder: vocos.decode) { progress in
print("Progress: \(Int(progress * 100))%")
}
```bash
ffmpeg -i /path/to/audio.wav -ac 1 -ar 24000 -sample_fmt s16 -t 10 /path/to/output_audio.wav
```

## Appreciation
Expand Down
231 changes: 175 additions & 56 deletions Sources/F5TTS/CFM.swift → Sources/F5TTS/F5TTS.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,59 +3,15 @@ import Hub
import MLX
import MLXNN
import MLXRandom
import Vocos

// utilities

func lensToMask(t: MLXArray, length: Int? = nil) -> MLXArray {
let maxLength = length ?? t.max(keepDims: false).item(Int.self)
let seq = MLXArray(0..<maxLength)
let expandedSeq = seq.expandedDimensions(axis: 0)
let expandedT = t.expandedDimensions(axis: 1)
return MLX.less(expandedSeq, expandedT)
}

func padToLength(_ t: MLXArray, length: Int, value: Float? = nil) -> MLXArray {
let ndim = t.ndim

guard let seqLen = t.shape.last, length > seqLen else {
return t[0..., .ellipsis]
}

let paddingValue = MLXArray(value ?? 0.0)

let padded: MLXArray
switch ndim {
case 1:
padded = MLX.padded(t, widths: [.init((0, length - seqLen))], value: paddingValue)
case 2:
padded = MLX.padded(t, widths: [.init((0, 0)), .init((0, length - seqLen))], value: paddingValue)
case 3:
padded = MLX.padded(t, widths: [.init((0, 0)), .init((0, length - seqLen)), .init((0, 0))], value: paddingValue)
default:
fatalError("Unsupported padding dims: \(ndim)")
}

return padded[0..., .ellipsis]
}

func padSequence(_ t: [MLXArray], paddingValue: Float = 0) -> MLXArray {
let maxLen = t.map { $0.shape.last ?? 0 }.max() ?? 0
let t = MLX.stacked(t, axis: 0)
return padToLength(t, length: maxLen, value: paddingValue)
}

func listStrToIdx(_ text: [String], vocabCharMap: [String: Int], paddingValue: Int = -1) -> MLXArray {
let listIdxTensors = text.map { str in str.map { char in vocabCharMap[String(char), default: 0] }}
let mlxArrays = listIdxTensors.map { MLXArray($0) }
let paddedText = padSequence(mlxArrays, paddingValue: Float(paddingValue))
return paddedText.asType(.int32)
}

// MARK: -
// MARK: - F5TTS

public class F5TTS: Module {
enum F5TTSError: Error {
case unableToLoadModel
case unableToLoadReferenceAudio
case unableToDetermineDuration
}

public let melSpec: MelSpec
Expand Down Expand Up @@ -100,20 +56,20 @@ public class F5TTS: Module {
return MLX.stacked(ys, axis: 0)
}

public func sample(
private func sample(
cond: MLXArray,
text: [String],
duration: Any,
lens: MLXArray? = nil,
steps: Int = 32,
cfgStrength: Float = 2.0,
swayCoef: Float? = -1.0,
cfgStrength: Double = 2.0,
swayCoef: Double? = -1.0,
seed: Int? = nil,
maxDuration: Int = 4096,
vocoder: ((MLXArray) -> MLXArray)? = nil,
noRefAudio: Bool = false,
editMask: MLXArray? = nil,
progressHandler: ((Float) -> Void)? = nil
progressHandler: ((Double) -> Void)? = nil
) -> (MLXArray, MLXArray) {
MLX.eval(self.parameters())

Expand Down Expand Up @@ -183,7 +139,7 @@ public class F5TTS: Module {
mask: mask
)

progressHandler?(t)
progressHandler?(Double(t))

return pred + (pred - nullPred) * cfgStrength
}
Expand Down Expand Up @@ -218,13 +174,82 @@ public class F5TTS: Module {

return (out, trajectory)
}

/// Generates speech audio for `text`, conditioned on a reference audio sample.
///
/// - Parameters:
///   - text: The text to synthesize.
///   - referenceAudioURL: Optional mono 24kHz reference recording; when nil,
///     the bundled reference clip is used.
///   - referenceAudioText: Caption matching the reference audio. Ignored when
///     `referenceAudioURL` is nil.
///   - duration: Target duration of the generated speech in seconds;
///     estimated from the reference/text lengths when nil.
///   - cfg: Classifier-free guidance strength.
///   - sway: Sway sampling coefficient for the ODE timesteps.
///   - speed: Speaking-rate factor used by the duration heuristic (>1 is faster).
///   - seed: Optional RNG seed for reproducible output.
///   - progressHandler: Receives sampling progress in 0...1.
/// - Returns: 24kHz audio samples, with the reference portion trimmed off.
/// - Throws: `F5TTSError.unableToDetermineDuration` if no duration could be
///   determined, or errors from loading the reference audio or vocoder.
public func generate(
    text: String,
    referenceAudioURL: URL? = nil,
    referenceAudioText: String? = nil,
    duration: TimeInterval? = nil,
    cfg: Double = 2.0,
    sway: Double = -1.0,
    speed: Double = 1.0,
    seed: Int? = nil,
    progressHandler: ((Double) -> Void)? = nil
) async throws -> MLXArray {
    print("Loading Vocos model...")
    let vocos = try await Vocos.fromPretrained(repoId: "lucasnewman/vocos-mel-24khz-mlx")

    // load the reference audio + text

    let audio: MLXArray
    let referenceText: String

    if let referenceAudioURL {
        audio = try F5TTS.loadAudioArray(url: referenceAudioURL)
        referenceText = referenceAudioText ?? ""
    } else {
        let refAudioAndCaption = try F5TTS.referenceAudio()
        audio = refAudioAndCaption.0
        referenceText = refAudioAndCaption.1
    }

    let refAudioDuration = Double(audio.shape[0]) / Double(F5TTS.sampleRate)
    print("Using reference audio with duration: \(refAudioDuration)")

    // use a heuristic to determine the duration if not provided,
    // honoring the caller's requested speaking rate (previously `speed`
    // was accepted but never forwarded)

    var generatedDuration = duration
    if generatedDuration == nil {
        generatedDuration = F5TTS.estimatedDuration(refAudio: audio, refText: referenceText, text: text, speed: speed)
    }

    guard let generatedDuration else {
        throw F5TTSError.unableToDetermineDuration
    }
    print("Using generated duration: \(generatedDuration)")

    // generate the audio

    let normalizedAudio = F5TTS.normalizeAudio(audio: audio)

    let processedText = referenceText + " " + text
    let frameDuration = Int((refAudioDuration + generatedDuration) * F5TTS.framesPerSecond)
    print("Generating \(generatedDuration) seconds (\(frameDuration) total frames) of audio...")

    let (outputAudio, _) = self.sample(
        cond: normalizedAudio.expandedDimensions(axis: 0),
        text: [processedText],
        duration: frameDuration,
        steps: 32,
        cfgStrength: cfg,
        swayCoef: sway,
        seed: seed,
        vocoder: vocos.decode
    ) { progress in
        // forward progress to the caller (previously the handler was
        // accepted but never invoked — only printed)
        progressHandler?(progress)
    }

    // strip the reference-audio prefix so only the newly generated speech remains
    let generatedAudio = outputAudio[audio.shape[0]...]
    return generatedAudio
}
}

// MARK: -
// MARK: - Pretrained Models

public extension F5TTS {
static func fromPretrained(repoId: String) async throws -> F5TTS {
let modelDirectoryURL = try await Hub.snapshot(from: repoId, matching: ["*.safetensors", "*.txt"])
static func fromPretrained(repoId: String, downloadProgress: ((Progress) -> Void)? = nil) async throws -> F5TTS {
let modelDirectoryURL = try await Hub.snapshot(from: repoId, matching: ["*.safetensors", "*.txt"]) { progress in
downloadProgress?(progress)
}
return try self.fromPretrained(modelDirectoryURL: modelDirectoryURL)
}

Expand Down Expand Up @@ -273,3 +298,97 @@ public extension F5TTS {
return f5tts
}
}

// MARK: - Utilities

public extension F5TTS {
    /// Output sample rate of the model, in Hz.
    static var sampleRate: Int = 24000
    /// Hop length of the mel spectrogram, in audio samples per frame.
    static var hopLength: Int = 256
    /// Number of mel frames per second of audio.
    static var framesPerSecond: Double = .init(sampleRate) / Double(hopLength)

    /// Loads an audio file into a 1-D array of samples.
    static func loadAudioArray(url: URL) throws -> MLXArray {
        return try AudioUtilities.loadAudioFile(url: url)
    }

    /// Returns the bundled reference audio clip and its matching caption.
    /// - Throws: `F5TTSError.unableToLoadReferenceAudio` if the bundle resource is missing.
    static func referenceAudio() throws -> (MLXArray, String) {
        guard let url = Bundle.module.url(forResource: "test_en_1_ref_short", withExtension: "wav") else {
            throw F5TTSError.unableToLoadReferenceAudio
        }

        return try (
            self.loadAudioArray(url: url),
            "Some call me nature, others call me mother nature."
        )
    }

    /// Boosts quiet audio so its RMS level reaches `targetRMS`.
    /// Audio already at or above the target — or fully silent — is returned unchanged.
    static func normalizeAudio(audio: MLXArray, targetRMS: Double = 0.1) -> MLXArray {
        let rms = Double(audio.square().mean().sqrt().item(Float.self))
        // rms > 0 guards against division by zero on silent input.
        if rms > 0, rms < targetRMS {
            return audio * targetRMS / rms
        }
        return audio
    }

    /// Estimates how long the generated speech should be by applying the
    /// reference audio's frames-per-text-byte rate to the new text.
    ///
    /// Pause punctuation is weighted with 3 extra bytes per occurrence since it
    /// maps to silence rather than speech (matches the upstream Python heuristic;
    /// the previous code added the byte length of the punctuation alphabet itself
    /// to both texts instead of counting occurrences).
    static func estimatedDuration(refAudio: MLXArray, refText: String, text: String, speed: Double = 1.0) -> TimeInterval {
        let refDurationInFrames = refAudio.shape[0] / self.hopLength
        let pausePunctuation = Set("。,、;:?!")

        let refPauseCount = refText.filter { pausePunctuation.contains($0) }.count
        let genPauseCount = text.filter { pausePunctuation.contains($0) }.count

        // max(1, ...) keeps the ratio finite for an empty reference caption.
        let refTextLength = max(1, refText.utf8.count + 3 * refPauseCount)
        let genTextLength = text.utf8.count + 3 * genPauseCount

        let refAudioToTextRatio = Double(refDurationInFrames) / Double(refTextLength)
        let textLength = Double(genTextLength) / speed
        let estimatedDurationInFrames = Int(refAudioToTextRatio * textLength)

        let estimatedDuration = TimeInterval(estimatedDurationInFrames) / Self.framesPerSecond
        print("Using duration of \(estimatedDuration) seconds (\(estimatedDurationInFrames) frames) for generated speech.")

        return estimatedDuration
    }
}

// MLX utilities

/// Builds a boolean mask from per-row sequence lengths.
/// Row i is true at positions 0..<t[i] and false beyond, out to `length`
/// (or the maximum value in `t` when `length` is nil).
func lensToMask(t: MLXArray, length: Int? = nil) -> MLXArray {
    let maskWidth = length ?? t.max(keepDims: false).item(Int.self)
    // Broadcast a position row [1, maskWidth] against a length column [rows, 1].
    let positions = MLXArray(0..<maskWidth).expandedDimensions(axis: 0)
    let lengths = t.expandedDimensions(axis: 1)
    return MLX.less(positions, lengths)
}

/// Pads `t` along its last axis with `value` (default 0) until that axis
/// reaches `length`. Returns `t` unchanged when it is already long enough.
/// Supports 1-D, 2-D, and 3-D inputs; other ranks are a programmer error.
func padToLength(_ t: MLXArray, length: Int, value: Float? = nil) -> MLXArray {
    // Nothing to do when the target is not longer than the current sequence.
    guard let currentLen = t.shape.last, length > currentLen else {
        return t[0..., .ellipsis]
    }

    let fill = MLXArray(value ?? 0.0)
    let growth = length - currentLen

    let result: MLXArray
    switch t.ndim {
    case 1:
        result = MLX.padded(t, widths: [.init((0, growth))], value: fill)
    case 2:
        result = MLX.padded(t, widths: [.init((0, 0)), .init((0, growth))], value: fill)
    case 3:
        result = MLX.padded(t, widths: [.init((0, 0)), .init((0, growth)), .init((0, 0))], value: fill)
    default:
        fatalError("Unsupported padding dims: \(t.ndim)")
    }

    return result[0..., .ellipsis]
}

/// Pads a list of sequences to a common length and stacks them along a new
/// leading (batch) axis.
///
/// - Parameters:
///   - t: Sequences to combine; they may differ in length along the last axis.
///   - paddingValue: Fill value used to extend shorter sequences.
/// - Returns: One array of shape [t.count, maxLen, ...].
func padSequence(_ t: [MLXArray], paddingValue: Float = 0) -> MLXArray {
    let maxLen = t.map { $0.shape.last ?? 0 }.max() ?? 0
    // Pad each element first: stacking requires identical shapes, so stacking
    // before padding (as before) fails for ragged inputs such as a batch of
    // different-length texts.
    let padded = t.map { padToLength($0, length: maxLen, value: paddingValue) }
    return MLX.stacked(padded, axis: 0)
}

/// Converts a batch of strings into a padded int32 index array using a
/// character vocabulary. Unknown characters map to index 0; shorter strings
/// are padded with `paddingValue`.
func listStrToIdx(_ text: [String], vocabCharMap: [String: Int], paddingValue: Int = -1) -> MLXArray {
    let indexArrays = text.map { str -> MLXArray in
        let indices = str.map { vocabCharMap[String($0), default: 0] }
        return MLXArray(indices)
    }
    return padSequence(indexArrays, paddingValue: Float(paddingValue)).asType(.int32)
}
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 3553c91

Please sign in to comment.