// src/model.ts
import { AutoModel, Tensor, env } from "@huggingface/transformers";
// Where the ONNX model lives. By default fetches from HF Hub:
// onnx/model.onnx + config.json from `shreyask/bol-tts-marathi-onnx`.
// For local testing of a re-export, set `?local` query param OR override
// `LOCAL_MODEL_PATH` — loads from /public/models/bol-tts/{config.json,onnx/model.onnx}.
const HF_MODEL_REPO = "shreyask/bol-tts-marathi-onnx";
const LOCAL_MODEL_PATH = "models/bol-tts";
const USE_LOCAL_MODEL = typeof window !== "undefined" && new URLSearchParams(window.location.search).has("local");
export interface KokoroVocab { [phone: string]: number; }
export interface KokoroConfig { vocab: KokoroVocab; n_token: number; }
export async function loadConfig(): Promise<KokoroConfig> {
const res = await fetch("/config.json");
if (!res.ok) throw new Error(`config.json fetch failed: ${res.status}`);
return res.json();
}
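
// Illustrative config shape (field names match KokoroConfig above; the
// concrete vocab entries and n_token value shown here are assumptions, not
// values read from the real repo):
//
//   { "vocab": { "$": 0, "a": 43, "ˈ": 156 }, "n_token": 178 }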
export async function loadVoicepack(voiceId: string): Promise<Float32Array> {
const res = await fetch(`/voices/${voiceId}.bin`);
if (!res.ok) throw new Error(`${voiceId}.bin fetch failed: ${res.status}`);
const buf = await res.arrayBuffer();
const f32 = new Float32Array(buf);
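  // Expected flat layout: [510, 1, 256] — one 256-dim style vector per
  // possible input length (see the selectStyle sketch below).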
if (f32.length !== 510 * 1 * 256) {
throw new Error(`voicepack ${voiceId}: expected ${510 * 1 * 256} floats, got ${f32.length}`);
}
return f32;
}
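
// Sketch of a style-row selector (not part of the original file): kokoro-js
// and the Python KPipeline index the voicepack by input token count to pick
// the [1, 256] style vector passed to synthesize(). Off-by-one conventions
// differ slightly across ports, so treat the exact index as an assumption;
// the clamping is defensive.
export function selectStyle(pack: Float32Array, tokenCount: number): Float32Array {
  const row = Math.min(Math.max(tokenCount - 1, 0), 510 - 1);
  return pack.subarray(row * 256, (row + 1) * 256);
}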
export interface SynthesizeResult {
audio: Float32Array;
predDur: Int32Array;
sampleRate: number;
  // Seconds of audio at the start of `audio` occupied by the BOS token,
  // before the first content phoneme. Callers must add this offset to
  // phoneme start/end times when computing timeline-aligned timings, since
  // `predDur` covers ONLY content tokens (BOS and EOS positions stripped).
  leadOffsetSec: number;
}
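
// Usage sketch (hypothetical helper, not part of the original API): turn
// predDur + leadOffsetSec into timeline-aligned per-phoneme timings, using
// the 600-sample hop at 24 kHz documented in KokoroSession.synthesize below.
export function phonemeTimings(result: SynthesizeResult): { start: number; end: number }[] {
  const frameSec = 600 / result.sampleRate; // 1 predictor frame = 25 ms
  let t = result.leadOffsetSec;
  return Array.from(result.predDur, (frames) => {
    const start = t;
    t += frames * frameSec;
    return { start, end: t };
  });
}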
export class KokoroSession {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
private model: any;
private isWebGpu: boolean;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private constructor(model: any, isWebGpu: boolean) { this.model = model; this.isWebGpu = isWebGpu; }
get backend(): string { return this.isWebGpu ? "webgpu" : "wasm"; }
static async create(onProgress?: (msg: string) => void): Promise<KokoroSession> {
if (USE_LOCAL_MODEL) {
env.allowRemoteModels = false;
env.allowLocalModels = true;
// localModelPath is prepended to the repo string; "/" + "models/bol-tts" → "/models/bol-tts"
env.localModelPath = "/";
} else {
env.allowRemoteModels = true;
env.allowLocalModels = false;
}
    const hasWebGpu = typeof navigator !== "undefined" && "gpu" in navigator;
const device = hasWebGpu ? "webgpu" : "wasm";
// fp32 only — Kokoro's ISTFTNet decoder is sensitive to int8/int4 weight
// quantization (produces NaN or clipped audio). kokoro.js and HeadTTS both
// default to fp32 on WebGPU for this reason. Size is the cost of quality.
const repoOrPath = USE_LOCAL_MODEL ? LOCAL_MODEL_PATH : HF_MODEL_REPO;
onProgress?.(`loading model from ${USE_LOCAL_MODEL ? "local" : "HF"} (${device}, ~325 MB one-time)…`);
const model = await AutoModel.from_pretrained(repoOrPath, {
dtype: "fp32",
device,
progress_callback: (p: any) => {
if (p.status === "progress" && p.file?.endsWith(".onnx")) {
const mb = (p.loaded / 1e6).toFixed(1);
          const pct = typeof p.progress === "number" ? `${p.progress.toFixed(1)}%` : "";
onProgress?.(`downloading ${p.file}: ${mb} MB ${pct}`);
} else if (p.status === "done") {
onProgress?.(`loaded ${p.file}`);
}
},
} as any);
return new KokoroSession(model, hasWebGpu);
}
async synthesize(inputIds: number[], refS: Float32Array, speed: number = 1.0): Promise<SynthesizeResult> {
// Kokoro's KModel.forward (Python) prepends BOS=0 + appends EOS=0 before
// calling forward_with_tokens. Our ONNX export exposes forward_with_tokens
// directly so the caller has to wrap. Without BOS, the iSTFTNet decoder's
// window-startup transient (~30-50 ms unreliable output) lands on the
// first content phoneme — leading consonants like /m/ in मुंबई get
// perceptually eaten. With BOS, the predictor allocates real duration to
// the boundary token and the decoder produces clean pre-content audio
// before the first phoneme.
const wrappedIds = [0, ...inputIds, 0];
const ids = BigInt64Array.from(wrappedIds.map(v => BigInt(v)));
const idTensor = new Tensor("int64", ids, [1, ids.length]);
// Input is named "style" to match kokoro-js + thewh1teagle/kokoro-onnx
// ecosystem convention. The `refS` argument name is preserved on this
// function for backward compatibility — it's the same [1, 256] voicepack
// slice either way.
const styleTensor = new Tensor("float32", refS, [1, 256]);
const speedTensor = new Tensor("float32", new Float32Array([speed]), [1]);
const out = await this.model({ input_ids: idTensor, style: styleTensor, speed: speedTensor });
const fullAudio = out.audio.data as Float32Array;
const raw = out.pred_dur.data as BigInt64Array;
// Strip BOS (index 0) + EOS (last index) from pred_dur so the returned
// array aligns 1:1 with the caller's original inputIds. ALSO strip the
// BOS audio prefix and EOS audio suffix from the buffer:
// - we need BOS in the *input* so the predictor allocates real duration
// to the boundary token (this is what fixes Marathi /m/ getting eaten
// in 'मुंबई' — the predictor with BOS context gives the first content
// phoneme proper duration);
// - but we DON'T want the BOS *audio* in the output: Rasa-trained voices
// learned BOS audio = soft breathy pre-roll (training data had natural
// pre-speech sounds), which surfaces as an audible "umm" at the start
// of every utterance. SpringLab voices learned BOS = silence so they
// wouldn't show this either way. Stripping the BOS audio gives both
// voice families the predictor benefit without the umm.
//
// 1 predictor frame = 600 audio samples at 24 kHz = 25 ms.
const HOP = 600;
const SR = 24000;
const bosFrames = raw.length > 0 ? Number(raw[0]) : 0;
const eosFrames = raw.length > 1 ? Number(raw[raw.length - 1]) : 0;
const innerLen = Math.max(0, raw.length - 2);
const predDur = new Int32Array(innerLen);
for (let i = 0; i < innerLen; i++) predDur[i] = Number(raw[i + 1]);
const startSample = Math.min(bosFrames * HOP, fullAudio.length);
const endSample = Math.max(startSample, fullAudio.length - eosFrames * HOP);
const audio = fullAudio.subarray(startSample, endSample);
return {
audio,
predDur,
sampleRate: SR,
// BOS audio already stripped from `audio`, so no offset for the caller.
leadOffsetSec: 0,
};
}
}
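
// End-to-end usage sketch (illustrative; `phonemesToIds` and the "mr_voice"
// voice id are hypothetical — the real phoneme→ID mapping and voice ids
// live in the caller):
//
//   const config = await loadConfig();
//   const pack = await loadVoicepack("mr_voice");
//   const session = await KokoroSession.create((msg) => console.log(msg));
//   const ids = phonemesToIds("...", config.vocab);
//   const style = selectStyle(pack, ids.length);
//   const { audio, sampleRate } = await session.synthesize(ids, style, 1.0);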