Skip to content

Commit

Permalink
Add rk4 sampling and use it by default.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Dec 11, 2024
1 parent a772687 commit 3689910
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 42 deletions.
12 changes: 6 additions & 6 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/maiqingqiang/Jinja",
"state" : {
"revision" : "b435eb62b0d3d5f34167ec70a128355486981712",
"version" : "1.0.5"
"revision" : "6dbe4c449469fb586d0f7339f900f0dd4d78b167",
"version" : "1.0.6"
}
},
{
"identity" : "mlx-swift",
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift",
"state" : {
"revision" : "78a7cfe6701d6e9c88e9d4a0d1f7990af84b2146",
"version" : "0.18.0"
"revision" : "70dbb62128a5a1471a5ab80363430adb33470cab",
"version" : "0.21.2"
}
},
{
Expand All @@ -41,8 +41,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers",
"state" : {
"revision" : "4d25d20e49d2269aec1556231f8e278db7b2a4f0",
"version" : "0.1.13"
"revision" : "d42fdae473c49ea216671da8caae58e102d28709",
"version" : "0.1.14"
}
},
{
Expand Down
119 changes: 83 additions & 36 deletions Sources/F5TTS/F5TTS.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,74 @@ import Vocos

// MARK: - F5TTS

func odeint_euler(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray {
var ys = [y0]
var yCurrent = y0

for i in 0..<(t.shape[0] - 1) {
let tCurrent = t[i].item(Float.self)
let dt = t[i + 1].item(Float.self) - tCurrent

let k = fun(tCurrent, yCurrent)
let yNext = yCurrent + dt * k

ys.append(yNext)
yCurrent = yNext
}

return MLX.stacked(ys, axis: 0)
}

func odeint_midpoint(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray {
var ys = [y0]
var yCurrent = y0

for i in 0..<(t.shape[0] - 1) {
let tCurrent = t[i].item(Float.self)
let dt = t[i + 1].item(Float.self) - tCurrent

let k1 = fun(tCurrent, yCurrent)
let mid = yCurrent + 0.5 * dt * k1

let k2 = fun(tCurrent + 0.5 * dt, mid)
let yNext = yCurrent + dt * k2

ys.append(yNext)
yCurrent = yNext
}

return MLX.stacked(ys, axis: 0)
}

func odeint_rk4(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray {
var ys = [y0]
var yCurrent = y0

for i in 0..<(t.shape[0] - 1) {
let tCurrent = t[i].item(Float.self)
let dt = t[i + 1].item(Float.self) - tCurrent

let k1 = fun(tCurrent, yCurrent)
let k2 = fun(tCurrent + 0.5 * dt, yCurrent + 0.5 * dt * k1)
let k3 = fun(tCurrent + 0.5 * dt, yCurrent + 0.5 * dt * k2)
let k4 = fun(tCurrent + dt, yCurrent + dt * k3)

let yNext = yCurrent + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)

ys.append(yNext)
yCurrent = yNext
}

return MLX.stacked(ys)
}

public class F5TTS: Module {
public enum ODEMethod: String {
case euler
case midpoint
case rk4
}

enum F5TTSError: Error {
case unableToLoadModel
case unableToLoadReferenceAudio
Expand Down Expand Up @@ -38,40 +105,18 @@ public class F5TTS: Module {
super.init()
}

private func odeint(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray {
var ys = [y0]
var yCurrent = y0

for i in 0..<(t.shape[0] - 1) {
let tCurrent = t[i].item(Float.self)
let dt = t[i + 1].item(Float.self) - tCurrent

let k1 = fun(tCurrent, yCurrent)
let mid = yCurrent + 0.5 * dt * k1

let k2 = fun(tCurrent + 0.5 * dt, mid)
let yNext = yCurrent + dt * k2

ys.append(yNext)
yCurrent = yNext
}

return MLX.stacked(ys, axis: 0)
}

private func sample(
cond: MLXArray,
text: [String],
duration: Int? = nil,
lens: MLXArray? = nil,
steps: Int = 32,
steps: Int = 8,
method: ODEMethod = .rk4,
cfgStrength: Double = 2.0,
swayCoef: Double? = -1.0,
seed: Int? = nil,
maxDuration: Int = 4096,
vocoder: ((MLXArray) -> MLXArray)? = nil,
noRefAudio: Bool = false,
editMask: MLXArray? = nil,
progressHandler: ((Double) -> Void)? = nil
) throws -> (MLXArray, MLXArray) {
MLX.eval(self.parameters())
Expand All @@ -96,9 +141,6 @@ public class F5TTS: Module {
lens = MLX.maximum(textLens, lens)

var condMask = lensToMask(t: lens)
if let editMask = editMask {
condMask = condMask & editMask
}

// duration
var resolvedDuration: MLXArray? = (duration != nil) ? MLXArray(duration!) : nil
Expand All @@ -125,10 +167,6 @@ public class F5TTS: Module {

let mask: MLXArray? = (batch > 1) ? lensToMask(t: duration) : nil

if noRefAudio {
cond = MLX.zeros(like: cond)
}

// neural ode

let fn: (Float, MLXArray) -> MLXArray = { t, x in
Expand Down Expand Up @@ -169,7 +207,7 @@ public class F5TTS: Module {

var y0: [MLXArray] = []
for dur in duration {
if let seed = seed {
if let seed {
MLXRandom.seed(UInt64(seed))
}
let noise = MLXRandom.normal([dur.item(Int.self), self.numChannels])
Expand All @@ -183,11 +221,17 @@ public class F5TTS: Module {
t = t + coef * (MLX.cos(MLXArray(.pi) / 2 * t) - 1 + t)
}

let trajectory = self.odeint(fun: fn, y0: y0Padded, t: t)
let odeintFn = switch method {
case .euler: odeint_euler
case .midpoint: odeint_midpoint
case .rk4: odeint_rk4
}

let trajectory = odeintFn(fn, y0Padded, t)
let sampled = trajectory[-1]
var out = MLX.where(condMask, cond, sampled)

if let vocoder = vocoder {
if let vocoder {
out = vocoder(out)
}
out.eval()
Expand All @@ -200,6 +244,8 @@ public class F5TTS: Module {
referenceAudioURL: URL? = nil,
referenceAudioText: String? = nil,
duration: TimeInterval? = nil,
steps: Int = 8,
method: ODEMethod = .rk4,
cfg: Double = 2.0,
sway: Double = -1.0,
speed: Double = 1.0,
Expand Down Expand Up @@ -234,7 +280,8 @@ public class F5TTS: Module {
cond: normalizedAudio.expandedDimensions(axis: 0),
text: [processedText],
duration: nil,
steps: 32,
steps: steps,
method: method,
cfgStrength: cfg,
swayCoef: sway,
seed: seed,
Expand Down Expand Up @@ -339,7 +386,7 @@ public extension F5TTS {
static var framesPerSecond: Double = .init(sampleRate) / Double(hopLength)

static func loadAudioArray(url: URL) throws -> MLXArray {
return try AudioUtilities.loadAudioFile(url: url)
try AudioUtilities.loadAudioFile(url: url)
}

static func referenceAudio() throws -> (MLXArray, String) {
Expand Down
8 changes: 8 additions & 0 deletions Sources/f5-tts-generate/GenerateCommand.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ struct GenerateAudio: AsyncParsableCommand {
@Option(name: .long, help: "Output path for the generated audio")
var outputPath: String = "output.wav"

@Option(name: .long, help: "The number of steps to use for ODE sampling")
var steps: Int = 8

@Option(name: .long, help: "Method to use for ODE sampling. Options are 'euler', 'midpoint', and 'rk4'.")
var method: String = "rk4"

@Option(name: .long, help: "Strength of classifier free guidance")
var cfg: Double = 2.0

Expand All @@ -49,6 +55,8 @@ struct GenerateAudio: AsyncParsableCommand {
referenceAudioURL: refAudioPath != nil ? URL(filePath: refAudioPath!) : nil,
referenceAudioText: refAudioText,
duration: duration,
steps: steps,
method: F5TTS.ODEMethod(rawValue: method)!,
cfg: cfg,
sway: sway,
speed: speed,
Expand Down

0 comments on commit 3689910

Please sign in to comment.