Skip to content

Commit

Permalink
Use updated duration predictor.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Nov 29, 2024
1 parent 6ae8953 commit a772687
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 13 deletions.
38 changes: 30 additions & 8 deletions Sources/F5TTS/Duration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,38 @@ class DurationInputEmbedding: Module {
}
}

public class DurationBlock: Module {
let attn_norm: LayerNorm
let attn: Attention
let ff_norm: LayerNorm
let ff: FeedForward

init(dim: Int, heads: Int, dimHead: Int, ffMult: Int = 4, dropout: Float = 0.1) {
self.attn_norm = LayerNorm(dimensions: dim)
self.attn = Attention(dim: dim, heads: heads, dimHead: dimHead, dropout: dropout)
self.ff_norm = LayerNorm(dimensions: dim, eps: 1e-6, affine: false)
self.ff = FeedForward(dim: dim, mult: ffMult, dropout: dropout, approximate: "tanh")

super.init()
}

func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, rope: (MLXArray, Float)? = nil) -> MLXArray {
let norm = attn_norm(x)
let attnOutput = attn(norm, mask: mask, rope: rope)
var output = x + attnOutput
let normedOutput = ff_norm(output)
let ffOutput = ff(normedOutput)
output = output + ffOutput
return output
}
}

public class DurationTransformer: Module {
let dim: Int
let time_embed: TimestepEmbedding
let text_embed: TextEmbedding
let input_embed: DurationInputEmbedding
let rotary_embed: RotaryEmbedding
let transformer_blocks: [DiTBlock]
let transformer_blocks: [DurationBlock]
let norm_out: RMSNorm
let depth: Int

Expand All @@ -46,14 +71,13 @@ public class DurationTransformer: Module {
) {
self.dim = dim
let actualTextDim = textDim ?? melDim
self.time_embed = TimestepEmbedding(dim: dim)
self.text_embed = TextEmbedding(textNumEmbeds: textNumEmbeds, textDim: actualTextDim, convLayers: convLayers)
self.input_embed = DurationInputEmbedding(melDim: melDim, textDim: actualTextDim, outDim: dim)
self.rotary_embed = RotaryEmbedding(dim: dimHead)
self.depth = depth

self.transformer_blocks = (0 ..< depth).map { _ in
DiTBlock(dim: dim, heads: heads, dimHead: dimHead, ffMult: ffMult, dropout: dropout)
DurationBlock(dim: dim, heads: heads, dimHead: dimHead, ffMult: ffMult, dropout: dropout)
}

self.norm_out = RMSNorm(dimensions: dim)
Expand All @@ -66,17 +90,15 @@ public class DurationTransformer: Module {
text: MLXArray,
mask: MLXArray? = nil
) -> MLXArray {
let batchSize = cond.shape[0]
let seqLen = cond.shape[1]

let t = time_embed(MLX.ones([batchSize], type: Float32.self))
let textEmbed = text_embed(text, seqLen: seqLen)
var x = input_embed(cond: cond, textEmbed: textEmbed)

let rope = rotary_embed.forwardFromSeqLen(seqLen)

for block in transformer_blocks {
x = block(x, t: t, mask: mask, rope: rope)
x = block(x, mask: mask, rope: rope)
}

return norm_out(x)
Expand Down Expand Up @@ -137,7 +159,7 @@ public class DurationPredictor: Module {
lens = MLX.maximum(textLens, lens)

var output = transformer(cond: cond, text: inputText)
output = to_pred(output).mean()
output = to_pred(output).mean().reshaped([batch, -1])
output.eval()

return output
Expand Down
11 changes: 6 additions & 5 deletions Sources/F5TTS/F5TTS.swift
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,14 @@ public class F5TTS: Module {
if resolvedDuration == nil, let durationPredictor = self._durationPredictor {
let estimatedDurationInSeconds = durationPredictor(cond, text: text).item(Float32.self)
resolvedDuration = MLXArray(Int(Double(estimatedDurationInSeconds) * F5TTS.framesPerSecond))
print("Generating \(estimatedDurationInSeconds) seconds (\(resolvedDuration) total frames) of audio...")
}

guard let resolvedDuration else {
throw F5TTSError.unableToDetermineDuration
}

print("Generating \(Double(resolvedDuration.item(Float32.self)) / F5TTS.framesPerSecond) seconds of audio...")

var duration = resolvedDuration
duration = MLX.clip(MLX.maximum(lens + 1, duration), min: 0, max: maxDuration)
let maxDuration = duration.max().item(Int.self)
Expand Down Expand Up @@ -281,18 +282,18 @@ public extension F5TTS {
// duration model

var durationPredictor: DurationPredictor?
let durationModelURL = modelDirectoryURL.appendingPathComponent("duration_model.safetensors")
let durationModelURL = modelDirectoryURL.appendingPathComponent("duration_v2.safetensors")
do {
let durationModelWeights = try loadArrays(url: durationModelURL)

let durationTransformer = DurationTransformer(
dim: 256,
dim: 512,
depth: 8,
heads: 8,
dimHead: 64,
ffMult: 2,
textNumEmbeds: vocab.count,
textDim: 256,
textDim: 512,
convLayers: 2
)
let predictor = DurationPredictor(
Expand All @@ -304,7 +305,7 @@ public extension F5TTS {

durationPredictor = predictor
} catch {
print("Warning: no duration predictor model found.")
print("Warning: no duration predictor model found: \(error)")
}

// model
Expand Down

0 comments on commit a772687

Please sign in to comment.