File size: 6,931 Bytes
2bb8806
 
761f477
 
 
 
 
 
 
 
2bb8806
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73b1551
 
 
 
 
2bb8806
 
 
 
 
 
 
 
 
 
 
761f477
 
 
 
 
 
 
 
 
2bb8806
 
 
 
 
 
 
761f477
 
 
2bb8806
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc4d90f
73b1551
 
 
 
 
 
 
 
 
 
cc4d90f
cc0aa08
 
 
 
 
cc4d90f
2bb8806
cc0aa08
2bb8806
b367e53
2bb8806
 
73b1551
b367e53
 
 
 
 
 
 
 
 
 
 
 
 
 
73b1551
 
 
b367e53
73b1551
 
 
 
b367e53
 
 
 
73b1551
 
 
 
b367e53
 
73b1551
2bb8806
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import { AutoModel, Tensor, env } from "@huggingface/transformers";

// Where the ONNX model lives. By default fetches from HF Hub:
//   onnx/model.onnx + config.json from `shreyask/bol-tts-marathi-onnx`.
// For local testing of a re-export, set `?local` query param OR override
// `LOCAL_MODEL_PATH` — loads from /public/models/bol-tts/{config.json,onnx/model.onnx}.
const HF_MODEL_REPO = "shreyask/bol-tts-marathi-onnx";
const LOCAL_MODEL_PATH = "models/bol-tts";
// `window` is guarded so this module can be imported outside a browser
// (tests, SSR) without a ReferenceError; `?local` then simply never matches.
const USE_LOCAL_MODEL =
  typeof window !== "undefined" && new URLSearchParams(window.location.search).has("local");

/** Phoneme-string → token-id lookup, as found in the model's config.json. */
export interface KokoroVocab { [phone: string]: number; }
/** Subset of the model's config.json that this pipeline reads. */
export interface KokoroConfig { vocab: KokoroVocab; n_token: number; }

/**
 * Fetch and validate the model's config.json from the site root.
 *
 * `res.json()` is untyped at runtime, so the two fields this pipeline
 * depends on (`vocab`, `n_token`) are checked explicitly instead of
 * being blindly cast — a malformed deploy fails loudly here rather than
 * as a confusing downstream error.
 *
 * @returns the parsed config
 * @throws Error on a non-2xx response or a malformed payload
 */
export async function loadConfig(): Promise<KokoroConfig> {
  const res = await fetch("/config.json");
  if (!res.ok) throw new Error(`config.json fetch failed: ${res.status}`);
  const data: unknown = await res.json();
  const rec = data as { vocab?: unknown; n_token?: unknown } | null;
  if (
    typeof rec !== "object" || rec === null ||
    typeof rec.n_token !== "number" ||
    typeof rec.vocab !== "object" || rec.vocab === null
  ) {
    throw new Error("config.json: missing or malformed 'vocab'/'n_token'");
  }
  return data as KokoroConfig;
}

/**
 * Fetch a voicepack binary (`/voices/{voiceId}.bin`) and validate its size.
 *
 * The pack is expected to be 510 × 1 × 256 fp32 values (per-length style
 * rows). A byte length that is not a multiple of 4 would otherwise surface
 * as an opaque RangeError from the Float32Array constructor, so it is
 * checked explicitly with a named error first.
 *
 * @param voiceId basename of the .bin file under /voices/
 * @returns the raw fp32 data, 130560 floats
 * @throws Error on a non-2xx response or an unexpected payload size
 */
export async function loadVoicepack(voiceId: string): Promise<Float32Array> {
  const EXPECTED_FLOATS = 510 * 1 * 256;
  const res = await fetch(`/voices/${voiceId}.bin`);
  if (!res.ok) throw new Error(`${voiceId}.bin fetch failed: ${res.status}`);
  const buf = await res.arrayBuffer();
  if (buf.byteLength % Float32Array.BYTES_PER_ELEMENT !== 0) {
    throw new Error(`voicepack ${voiceId}: byte length ${buf.byteLength} is not a multiple of 4`);
  }
  const f32 = new Float32Array(buf);
  if (f32.length !== EXPECTED_FLOATS) {
    throw new Error(`voicepack ${voiceId}: expected ${EXPECTED_FLOATS} floats, got ${f32.length}`);
  }
  return f32;
}

/** Output of a single `KokoroSession.synthesize()` call. */
export interface SynthesizeResult {
  /** Mono fp32 PCM at `sampleRate`, with BOS/EOS audio already stripped. */
  audio: Float32Array;
  /** Per-token predicted durations in predictor frames, aligned 1:1 with the caller's inputIds (BOS/EOS positions removed). */
  predDur: Int32Array;
  /** Sample rate of `audio` in Hz. */
  sampleRate: number;
  /**
   * Time (seconds) the BOS token's audio occupies at the start of `audio`,
   * before the first content phoneme. Caller must add this to phoneme.start
   * / phoneme.end when computing timeline-aligned timings, since `predDur`
   * returned here covers ONLY content (BOS+EOS positions stripped).
   */
  leadOffsetSec: number;
}

/**
 * Thin wrapper around the Kokoro ONNX model: owns the transformers.js
 * model handle and exposes one `synthesize()` call per utterance.
 * Construct via the async factory `KokoroSession.create()`.
 */
export class KokoroSession {
  // transformers.js AutoModel instances are effectively untyped upstream,
  // so `any` is scoped to this private field.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private model: any;
  private readonly isWebGpu: boolean;

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private constructor(model: any, isWebGpu: boolean) {
    this.model = model;
    this.isWebGpu = isWebGpu;
  }

  /** Execution backend actually selected at load time. */
  get backend(): string { return this.isWebGpu ? "webgpu" : "wasm"; }

  /**
   * Load the model (HF Hub by default, local re-export with `?local`) on
   * WebGPU when available, WASM otherwise.
   *
   * @param onProgress optional sink for human-readable load/download status
   */
  static async create(onProgress?: (msg: string) => void): Promise<KokoroSession> {
    if (USE_LOCAL_MODEL) {
      env.allowRemoteModels = false;
      env.allowLocalModels = true;
      // localModelPath is prepended to the repo string; "/" + "models/bol-tts" → "/models/bol-tts"
      env.localModelPath = "/";
    } else {
      env.allowRemoteModels = true;
      env.allowLocalModels = false;
    }

    // Guard `navigator` itself, not just `navigator.gpu`: this module may be
    // evaluated where the global is absent (tests, older runtimes), and the
    // bare reference would otherwise throw a ReferenceError.
    const hasWebGpu =
      typeof navigator !== "undefined" &&
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (navigator as any).gpu !== undefined;
    const device = hasWebGpu ? "webgpu" : "wasm";

    // fp32 only — Kokoro's ISTFTNet decoder is sensitive to int8/int4 weight
    // quantization (produces NaN or clipped audio). kokoro.js and HeadTTS both
    // default to fp32 on WebGPU for this reason. Size is the cost of quality.
    const repoOrPath = USE_LOCAL_MODEL ? LOCAL_MODEL_PATH : HF_MODEL_REPO;
    onProgress?.(`loading model from ${USE_LOCAL_MODEL ? "local" : "HF"} (${device}, ~325 MB one-time)…`);
    const model = await AutoModel.from_pretrained(repoOrPath, {
      dtype: "fp32",
      device,
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      progress_callback: (p: any) => {
        if (p.status === "progress" && p.file?.endsWith(".onnx")) {
          const mb = (p.loaded / 1e6).toFixed(1);
          // Explicit number check: a legitimate 0% is falsy and must not be dropped.
          const pct = typeof p.progress === "number" ? `${p.progress.toFixed(1)}%` : "";
          onProgress?.(`downloading ${p.file}: ${mb} MB ${pct}`);
        } else if (p.status === "done") {
          onProgress?.(`loaded ${p.file}`);
        }
      },
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
    } as any);
    return new KokoroSession(model, hasWebGpu);
  }

  /**
   * Run one utterance through the model.
   *
   * @param inputIds content phoneme token ids (no BOS/EOS — added here)
   * @param refS     [1, 256] voicepack style slice (see loadVoicepack)
   * @param speed    speaking-rate multiplier, 1.0 = model default
   * @returns audio + per-token durations aligned to `inputIds`
   */
  async synthesize(inputIds: number[], refS: Float32Array, speed: number = 1.0): Promise<SynthesizeResult> {
    // Kokoro's KModel.forward (Python) prepends BOS=0 + appends EOS=0 before
    // calling forward_with_tokens. Our ONNX export exposes forward_with_tokens
    // directly so the caller has to wrap. Without BOS, the iSTFTNet decoder's
    // window-startup transient (~30-50 ms unreliable output) lands on the
    // first content phoneme — leading consonants like /m/ in मुंबई get
    // perceptually eaten. With BOS, the predictor allocates real duration to
    // the boundary token and the decoder produces clean pre-content audio
    // before the first phoneme.
    const wrappedIds = [0, ...inputIds, 0];
    // map-fn overload avoids materializing an intermediate bigint[].
    const ids = BigInt64Array.from(wrappedIds, v => BigInt(v));
    const idTensor    = new Tensor("int64",   ids,  [1, ids.length]);
    // Input is named "style" to match kokoro-js + thewh1teagle/kokoro-onnx
    // ecosystem convention. The `refS` argument name is preserved on this
    // function for backward compatibility — it's the same [1, 256] voicepack
    // slice either way.
    const styleTensor = new Tensor("float32", refS, [1, 256]);
    const speedTensor = new Tensor("float32", new Float32Array([speed]), [1]);

    const out = await this.model({ input_ids: idTensor, style: styleTensor, speed: speedTensor });

    const fullAudio = out.audio.data as Float32Array;
    const raw = out.pred_dur.data as BigInt64Array;

    // Strip BOS (index 0) + EOS (last index) from pred_dur so the returned
    // array aligns 1:1 with the caller's original inputIds. ALSO strip the
    // BOS audio prefix and EOS audio suffix from the buffer:
    //   - we need BOS in the *input* so the predictor allocates real duration
    //     to the boundary token (this is what fixes Marathi /m/ getting eaten
    //     in 'मुंबई' — the predictor with BOS context gives the first content
    //     phoneme proper duration);
    //   - but we DON'T want the BOS *audio* in the output: Rasa-trained voices
    //     learned BOS audio = soft breathy pre-roll (training data had natural
    //     pre-speech sounds), which surfaces as an audible "umm" at the start
    //     of every utterance. SpringLab voices learned BOS = silence so they
    //     wouldn't show this either way. Stripping the BOS audio gives both
    //     voice families the predictor benefit without the umm.
    //
    // 1 predictor frame = 600 audio samples at 24 kHz = 25 ms.
    const HOP = 600;
    const SR = 24000;
    const bosFrames = raw.length > 0 ? Number(raw[0]) : 0;
    const eosFrames = raw.length > 1 ? Number(raw[raw.length - 1]) : 0;
    const innerLen = Math.max(0, raw.length - 2);
    const predDur = new Int32Array(innerLen);
    for (let i = 0; i < innerLen; i++) predDur[i] = Number(raw[i + 1]);

    // Clamp both cut points so a degenerate prediction (BOS/EOS longer than
    // the buffer) can never produce a negative-length or inverted subarray.
    const startSample = Math.min(bosFrames * HOP, fullAudio.length);
    const endSample = Math.max(startSample, fullAudio.length - eosFrames * HOP);
    const audio = fullAudio.subarray(startSample, endSample);

    return {
      audio,
      predDur,
      sampleRate: SR,
      // BOS audio already stripped from `audio`, so no offset for the caller.
      leadOffsetSec: 0,
    };
  }
}