From a772687b7ba73050442cba4d06f8996dd295e4b0 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Fri, 29 Nov 2024 13:14:56 -0800 Subject: [PATCH] Use updated duration predictor. --- Sources/F5TTS/Duration.swift | 38 ++++++++++++++++++++++++++++-------- Sources/F5TTS/F5TTS.swift | 11 ++++++----- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/Sources/F5TTS/Duration.swift b/Sources/F5TTS/Duration.swift index cd9d5c9..c64ce05 100644 --- a/Sources/F5TTS/Duration.swift +++ b/Sources/F5TTS/Duration.swift @@ -22,13 +22,38 @@ class DurationInputEmbedding: Module { } } +public class DurationBlock: Module { + let attn_norm: LayerNorm + let attn: Attention + let ff_norm: LayerNorm + let ff: FeedForward + + init(dim: Int, heads: Int, dimHead: Int, ffMult: Int = 4, dropout: Float = 0.1) { + self.attn_norm = LayerNorm(dimensions: dim) + self.attn = Attention(dim: dim, heads: heads, dimHead: dimHead, dropout: dropout) + self.ff_norm = LayerNorm(dimensions: dim, eps: 1e-6, affine: false) + self.ff = FeedForward(dim: dim, mult: ffMult, dropout: dropout, approximate: "tanh") + + super.init() + } + + func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, rope: (MLXArray, Float)? = nil) -> MLXArray { + let norm = attn_norm(x) + let attnOutput = attn(norm, mask: mask, rope: rope) + var output = x + attnOutput + let normedOutput = ff_norm(output) + let ffOutput = ff(normedOutput) + output = output + ffOutput + return output + } +} + public class DurationTransformer: Module { let dim: Int - let time_embed: TimestepEmbedding let text_embed: TextEmbedding let input_embed: DurationInputEmbedding let rotary_embed: RotaryEmbedding - let transformer_blocks: [DiTBlock] + let transformer_blocks: [DurationBlock] let norm_out: RMSNorm let depth: Int @@ -46,14 +71,13 @@ public class DurationTransformer: Module { ) { self.dim = dim let actualTextDim = textDim ?? melDim - self.time_embed = TimestepEmbedding(dim: dim) self.text_embed = TextEmbedding(textNumEmbeds: textNumEmbeds, textDim: actualTextDim, convLayers: convLayers) self.input_embed = DurationInputEmbedding(melDim: melDim, textDim: actualTextDim, outDim: dim) self.rotary_embed = RotaryEmbedding(dim: dimHead) self.depth = depth self.transformer_blocks = (0 ..< depth).map { _ in - DiTBlock(dim: dim, heads: heads, dimHead: dimHead, ffMult: ffMult, dropout: dropout) + DurationBlock(dim: dim, heads: heads, dimHead: dimHead, ffMult: ffMult, dropout: dropout) } self.norm_out = RMSNorm(dimensions: dim) @@ -66,17 +90,15 @@ public class DurationTransformer: Module { text: MLXArray, mask: MLXArray? = nil ) -> MLXArray { - let batchSize = cond.shape[0] let seqLen = cond.shape[1] - let t = time_embed(MLX.ones([batchSize], type: Float32.self)) let textEmbed = text_embed(text, seqLen: seqLen) var x = input_embed(cond: cond, textEmbed: textEmbed) let rope = rotary_embed.forwardFromSeqLen(seqLen) for block in transformer_blocks { - x = block(x, t: t, mask: mask, rope: rope) + x = block(x, mask: mask, rope: rope) } return norm_out(x) @@ -137,7 +159,7 @@ public class DurationPredictor: Module { lens = MLX.maximum(textLens, lens) var output = transformer(cond: cond, text: inputText) - output = to_pred(output).mean() + output = to_pred(output).mean().reshaped([batch, -1]) output.eval() return output diff --git a/Sources/F5TTS/F5TTS.swift b/Sources/F5TTS/F5TTS.swift index f084a39..a5fa6e3 100644 --- a/Sources/F5TTS/F5TTS.swift +++ b/Sources/F5TTS/F5TTS.swift @@ -106,13 +106,14 @@ public class F5TTS: Module { if resolvedDuration == nil, let durationPredictor = self._durationPredictor { let estimatedDurationInSeconds = durationPredictor(cond, text: text).item(Float32.self) resolvedDuration = MLXArray(Int(Double(estimatedDurationInSeconds) * F5TTS.framesPerSecond)) - print("Generating \(estimatedDurationInSeconds) seconds (\(resolvedDuration) total frames) of audio...") } guard let resolvedDuration else { throw F5TTSError.unableToDetermineDuration } + print("Generating \(Double(resolvedDuration.item(Float32.self)) / F5TTS.framesPerSecond) seconds of audio...") + var duration = resolvedDuration duration = MLX.clip(MLX.maximum(lens + 1, duration), min: 0, max: maxDuration) let maxDuration = duration.max().item(Int.self) @@ -281,18 +282,18 @@ public extension F5TTS { // duration model var durationPredictor: DurationPredictor? - let durationModelURL = modelDirectoryURL.appendingPathComponent("duration_model.safetensors") + let durationModelURL = modelDirectoryURL.appendingPathComponent("duration_v2.safetensors") do { let durationModelWeights = try loadArrays(url: durationModelURL) let durationTransformer = DurationTransformer( - dim: 256, + dim: 512, depth: 8, heads: 8, dimHead: 64, ffMult: 2, textNumEmbeds: vocab.count, - textDim: 256, + textDim: 512, convLayers: 2 ) let predictor = DurationPredictor( @@ -304,7 +305,7 @@ public extension F5TTS { durationPredictor = predictor } catch { - print("Warning: no duration predictor model found.") + print("Warning: no duration predictor model found: \(error)") } // model