pnnbao-ump committed on
Commit
79d5798
·
verified ·
1 Parent(s): 41e7f51

Upload 44 files

Browse files
Files changed (45) hide show
  1. .gitattributes +11 -0
  2. app.py +339 -0
  3. config.yaml +77 -0
  4. packages.txt +3 -0
  5. requirements.txt +9 -0
  6. sample/Bình (nam miền Bắc).pt +3 -0
  7. sample/Bình (nam miền Bắc).txt +1 -0
  8. sample/Bình (nam miền Bắc).wav +3 -0
  9. sample/Dung (nữ miền Nam).pt +3 -0
  10. sample/Dung (nữ miền Nam).txt +1 -0
  11. sample/Dung (nữ miền Nam).wav +3 -0
  12. sample/Hương (nữ miền Bắc).pt +3 -0
  13. sample/Hương (nữ miền Bắc).txt +1 -0
  14. sample/Hương (nữ miền Bắc).wav +3 -0
  15. sample/Ly (nữ miền Bắc).pt +3 -0
  16. sample/Ly (nữ miền Bắc).txt +1 -0
  17. sample/Ly (nữ miền Bắc).wav +3 -0
  18. sample/Nguyên (nam miền Nam).pt +3 -0
  19. sample/Nguyên (nam miền Nam).txt +1 -0
  20. sample/Nguyên (nam miền Nam).wav +3 -0
  21. sample/Ngọc (nữ miền Bắc).pt +3 -0
  22. sample/Ngọc (nữ miền Bắc).txt +1 -0
  23. sample/Ngọc (nữ miền Bắc).wav +3 -0
  24. sample/Sơn (nam miền Nam).pt +3 -0
  25. sample/Sơn (nam miền Nam).txt +1 -0
  26. sample/Sơn (nam miền Nam).wav +3 -0
  27. sample/Tuyên (nam miền Bắc).pt +3 -0
  28. sample/Tuyên (nam miền Bắc).txt +1 -0
  29. sample/Tuyên (nam miền Bắc).wav +3 -0
  30. sample/Vĩnh (nam miền Nam).pt +3 -0
  31. sample/Vĩnh (nam miền Nam).txt +1 -0
  32. sample/Vĩnh (nam miền Nam).wav +3 -0
  33. sample/Đoan (nữ miền Nam).pt +3 -0
  34. sample/Đoan (nữ miền Nam).txt +1 -0
  35. sample/Đoan (nữ miền Nam).wav +3 -0
  36. utils/__init__.py +0 -0
  37. utils/__pycache__/__init__.cpython-312.pyc +0 -0
  38. utils/__pycache__/core_utils.cpython-312.pyc +0 -0
  39. utils/__pycache__/normalize_text.cpython-312.pyc +0 -0
  40. utils/__pycache__/phonemize_text.cpython-312.pyc +0 -0
  41. utils/core_utils.py +53 -0
  42. utils/normalize_text.py +407 -0
  43. utils/phoneme_dict.json +3 -0
  44. utils/phonemize_text.py +346 -0
  45. vieneu_tts.py +859 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ sample/Bình[[:space:]](nam[[:space:]]miền[[:space:]]Bắc).wav filter=lfs diff=lfs merge=lfs -text
37
+ sample/Dung[[:space:]](nữ[[:space:]]miền[[:space:]]Nam).wav filter=lfs diff=lfs merge=lfs -text
38
+ sample/Đoan[[:space:]](nữ[[:space:]]miền[[:space:]]Nam).wav filter=lfs diff=lfs merge=lfs -text
39
+ sample/Hương[[:space:]](nữ[[:space:]]miền[[:space:]]Bắc).wav filter=lfs diff=lfs merge=lfs -text
40
+ sample/Ly[[:space:]](nữ[[:space:]]miền[[:space:]]Bắc).wav filter=lfs diff=lfs merge=lfs -text
41
+ sample/Ngọc[[:space:]](nữ[[:space:]]miền[[:space:]]Bắc).wav filter=lfs diff=lfs merge=lfs -text
42
+ sample/Nguyên[[:space:]](nam[[:space:]]miền[[:space:]]Nam).wav filter=lfs diff=lfs merge=lfs -text
43
+ sample/Sơn[[:space:]](nam[[:space:]]miền[[:space:]]Nam).wav filter=lfs diff=lfs merge=lfs -text
44
+ sample/Tuyên[[:space:]](nam[[:space:]]miền[[:space:]]Bắc).wav filter=lfs diff=lfs merge=lfs -text
45
+ sample/Vĩnh[[:space:]](nam[[:space:]]miền[[:space:]]Nam).wav filter=lfs diff=lfs merge=lfs -text
46
+ utils/phoneme_dict.json filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import spaces  # MUST be imported BEFORE everything else on HF Spaces ZeroGPU
import os
os.environ['SPACES_ZERO_GPU'] = '1'  # Set environment variable explicitly

import gradio as gr
import soundfile as sf
import tempfile
import torch
from vieneu_tts import VieNeuTTS
import time

print("⏳ Đang khởi động VieNeu-TTS...")

# --- 1. SETUP MODEL ---
print("📦 Đang tải model...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Sử dụng thiết bị: {device.upper()}")

try:
    # Load the TTS backbone and the neural codec onto the selected device.
    tts = VieNeuTTS(
        backbone_repo="pnnbao-ump/VieNeu-TTS-0.3B",
        backbone_device=device,
        codec_repo="neuphonic/distill-neucodec",
        codec_device=device
    )
    print("✅ Model đã tải xong!")
except Exception as e:
    # Fall back to a mock object so the UI can still be demoed when the
    # real model cannot be downloaded/loaded.
    print(f"⚠️ Không thể tải model (Chế độ UI Demo): {e}")
    class MockTTS:
        # Mirrors the VieNeuTTS surface used below: encode_reference() + infer().
        def encode_reference(self, path): return None
        def infer(self, text, ref, ref_text):
            import numpy as np
            # Simulated latency so the timing feature can be exercised.
            time.sleep(1.5)
            return np.random.uniform(-0.5, 0.5, 24000*3)
    tts = MockTTS()
37
+
38
+ # --- 2. DATA ---
39
+ VOICE_SAMPLES = {
40
+ "Tuyên (nam miền Bắc)": {"audio": "./sample/Tuyên (nam miền Bắc).wav", "text": "./sample/Tuyên (nam miền Bắc).txt"},
41
+ "Vĩnh (nam miền Nam)": {"audio": "./sample/Vĩnh (nam miền Nam).wav", "text": "./sample/Vĩnh (nam miền Nam).txt"},
42
+ "Bình (nam miền Bắc)": {"audio": "./sample/Bình (nam miền Bắc).wav", "text": "./sample/Bình (nam miền Bắc).txt"},
43
+ "Nguyên (nam miền Nam)": {"audio": "./sample/Nguyên (nam miền Nam).wav", "text": "./sample/Nguyên (nam miền Nam).txt"},
44
+ "Sơn (nam miền Nam)": {"audio": "./sample/Sơn (nam miền Nam).wav", "text": "./sample/Sơn (nam miền Nam).txt"},
45
+ "Đoan (nữ miền Nam)": {"audio": "./sample/Đoan (nữ miền Nam).wav", "text": "./sample/Đoan (nữ miền Nam).txt"},
46
+ "Ngọc (nữ miền Bắc)": {"audio": "./sample/Ngọc (nữ miền Bắc).wav", "text": "./sample/Ngọc (nữ miền Bắc).txt"},
47
+ "Ly (nữ miền Bắc)": {"audio": "./sample/Ly (nữ miền Bắc).wav", "text": "./sample/Ly (nữ miền Bắc).txt"},
48
+ "Dung (nữ miền Nam)": {"audio": "./sample/Dung (nữ miền Nam).wav", "text": "./sample/Dung (nữ miền Nam).txt"}
49
+ }
50
+
51
+ # --- 3. HELPER FUNCTIONS ---
52
+ def load_reference_info(voice_choice):
53
+ if voice_choice in VOICE_SAMPLES:
54
+ audio_path = VOICE_SAMPLES[voice_choice]["audio"]
55
+ text_path = VOICE_SAMPLES[voice_choice]["text"]
56
+ try:
57
+ if os.path.exists(text_path):
58
+ with open(text_path, "r", encoding="utf-8") as f:
59
+ ref_text = f.read()
60
+ return audio_path, ref_text
61
+ else:
62
+ return audio_path, "⚠️ Không tìm thấy file text mẫu."
63
+ except Exception as e:
64
+ return None, f"❌ Lỗi: {str(e)}"
65
+ return None, ""
66
+
67
@spaces.GPU(duration=120)  # request a ZeroGPU slot for at most 120 s per call
def synthesize_speech(text, voice_choice, custom_audio, custom_text, mode_tab):
    """Synthesize Vietnamese speech from `text` using a cloned reference voice.

    Parameters:
        text: text to synthesize (hard-capped at 250 characters below).
        voice_choice: key into VOICE_SAMPLES (used unless mode_tab is "custom_mode").
        custom_audio: filepath of a user-uploaded reference WAV (custom mode).
        custom_text: transcript of the uploaded reference audio (custom mode).
        mode_tab: "custom_mode" or "preset_mode", tracked by the active UI tab.

    Returns:
        (output_wav_path, status_message); output_wav_path is None on any error.
    """
    try:
        if not text or text.strip() == "":
            return None, "⚠️ Vui lòng nhập văn bản cần tổng hợp!"

        # --- LOGIC CHECK LIMIT 250 ---
        # The cap keeps output quality acceptable on the shared ZeroGPU backend.
        if len(text) > 250:
            return None, f"❌ Văn bản quá dài ({len(text)}/250 ký tự)! Vui lòng cắt ngắn lại để đảm bảo chất lượng."

        # Resolve the reference audio + transcript used for voice cloning.
        if mode_tab == "custom_mode":
            if custom_audio is None or not custom_text:
                return None, "⚠️ Vui lòng tải lên Audio và nhập nội dung Audio đó."
            ref_audio_path = custom_audio
            ref_text_raw = custom_text
            print("🎨 Mode: Custom Voice")
        else:  # Preset
            if voice_choice not in VOICE_SAMPLES:
                return None, "⚠️ Vui lòng chọn một giọng mẫu."
            ref_audio_path = VOICE_SAMPLES[voice_choice]["audio"]
            ref_text_path = VOICE_SAMPLES[voice_choice]["text"]

            if not os.path.exists(ref_audio_path):
                return None, f"❌ Không tìm thấy file audio: {ref_audio_path}"

            with open(ref_text_path, "r", encoding="utf-8") as f:
                ref_text_raw = f.read()
            print(f"🎤 Mode: Preset Voice ({voice_choice})")

        # Inference, timed so the elapsed seconds can be reported to the user.
        print(f"📝 Text: {text[:50]}...")

        start_time = time.time()

        ref_codes = tts.encode_reference(ref_audio_path)
        wav = tts.infer(text, ref_codes, ref_text_raw)

        end_time = time.time()
        process_time = end_time - start_time

        # Persist to a temp WAV at 24000 Hz; delete=False because Gradio
        # serves the file after this function returns.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            sf.write(tmp_file.name, wav, 24000)
            output_path = tmp_file.name

        return output_path, f"✅ Thành công! (Thời gian: {process_time:.2f}s)"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ Lỗi hệ thống: {str(e)}"
119
+
120
# --- 4. UI SETUP ---
# Gradio Soft theme with an indigo→cyan gradient accent on primary buttons.
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="cyan",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont('Inter'), 'ui-sans-serif', 'system-ui'],
).set(
    button_primary_background_fill="linear-gradient(90deg, #6366f1 0%, #0ea5e9 100%)",
    button_primary_background_fill_hover="linear-gradient(90deg, #4f46e5 0%, #0284c7 100%)",
    block_shadow="0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06)",
)

# Custom CSS for the header card, external links, status box and the
# performance-warning banner rendered via gr.HTML below.
css = """
.container { max-width: 1200px; margin: auto; }
.header-box {
    text-align: center;
    margin-bottom: 25px;
    padding: 25px;
    background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%);
    border-radius: 12px;
    border: 1px solid #334155;
    box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.3);
}
.header-title {
    font-size: 2.5rem;
    font-weight: 800;
    color: white;
    background: -webkit-linear-gradient(45deg, #60A5FA, #22D3EE);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    margin-bottom: 10px;
}
.header-desc {
    font-size: 1.1rem;
    color: #cbd5e1;
    margin-bottom: 15px;
}
.link-group a {
    text-decoration: none;
    margin: 0 10px;
    font-weight: 600;
    color: #94a3b8;
    transition: color 0.2s;
}
.link-group a:hover { color: #38bdf8; text-shadow: 0 0 5px rgba(56, 189, 248, 0.5); }
.status-box { font-weight: bold; text-align: center; border: none; background: transparent; }
.warning-banner {
    background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
    border: 2px solid #f59e0b;
    border-radius: 8px;
    padding: 15px 20px;
    margin: 15px 0;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.warning-banner-title {
    font-size: 1.1rem;
    font-weight: 700;
    color: #92400e;
    margin-bottom: 8px;
    display: flex;
    align-items: center;
    gap: 8px;
}
.warning-banner-content {
    color: #78350f;
    font-size: 0.95rem;
    line-height: 1.6;
}
.warning-banner-content strong {
    color: #92400e;
    font-weight: 600;
}
.warning-banner-content code {
    background: #fef3c7;
    padding: 2px 6px;
    border-radius: 3px;
    font-family: monospace;
    color: #92400e;
    font-weight: 500;
}
"""
201
+
202
# Quick-start examples: [input text, preset voice name] pairs fed to gr.Examples.
EXAMPLES_LIST = [
    ["Về miền Tây không chỉ để ngắm nhìn sông nước hữu tình, mà còn để cảm nhận tấm chân tình của người dân nơi đây. Cùng ngồi xuồng ba lá len lỏi qua rặng dừa nước, nghe câu vọng cổ ngọt ngào thì còn gì bằng.", "Vĩnh (nam miền Nam)"],
    ["Hà Nội những ngày vào thu mang một vẻ đẹp trầm mặc và cổ kính đến lạ thường. Đi dạo quanh Hồ Gươm vào sáng sớm, hít hà mùi hoa sữa nồng nàn và thưởng thức chút cốm làng Vòng là trải nghiệm khó quên.", "Bình (nam miền Bắc)"],
    ["Sự bùng nổ của trí tuệ nhân tạo đang định hình lại cách chúng ta làm việc và sinh sống. Từ xe tự lái đến trợ lý ảo thông minh, công nghệ đang dần xóa nhòa ranh giới giữa thực tại và những bộ phim viễn tưởng.", "Tuyên (nam miền Bắc)"],
    ["Sài Gòn hối hả là thế, nhưng chỉ cần tấp vào một quán cà phê ven đường, gọi ly bạc xỉu đá và ngắm nhìn dòng người qua lại, bạn sẽ thấy thành phố này cũng có những khoảng lặng thật bình yên và đáng yêu.", "Nguyên (nam miền Nam)"],
    ["Ngày xửa ngày xưa, ở một ngôi làng nọ có cô Tấm xinh đẹp, nết na nhưng sớm mồ côi mẹ. Dù bị mẹ kế và Cám hãm hại đủ đường, Tấm vẫn giữ được tấm lòng lương thiện và cuối cùng tìm được hạnh phúc xứng đáng.", "Đoan (nữ miền Nam)"],
    ["Dạ em chào anh chị, hiện tại bên em đang có chương trình ưu đãi đặc biệt cho căn hộ hướng sông này. Với thiết kế hiện đại và không gian xanh mát, đây chắc chắn là tổ ấm lý tưởng mà gia đình mình đang tìm kiếm.", "Ly (nữ miền Bắc)"],
]
210
+
211
with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS Studio") as demo:

    with gr.Column(elem_classes="container"):
        # Header
        gr.HTML("""
        <div class="header-box">
            <div class="header-title">🦜 VieNeu-TTS Studio</div>
            <div class="header-desc">
                Phiên bản: VieNeu-TTS (model mới nhất, train trên 1000 giờ dữ liệu)
            </div>
            <div class="link-group">
                <a href="https://huggingface.co/pnnbao-ump/VieNeu-TTS" target="_blank">🤗 Model Card</a> •
                <a href="https://huggingface.co/datasets/pnnbao-ump/VieNeu-TTS-1000h" target="_blank">📖 Dataset 1000h</a> •
                <a href="https://github.com/pnnbao97/VieNeu-TTS" target="_blank">🦜 GitHub</a>
            </div>
        </div>
        """)

        # Performance Warning Banner
        gr.HTML("""
        <div class="warning-banner">
            <div class="warning-banner-title">
                ⚠️ Lưu ý về hiệu năng
            </div>
            <div class="warning-banner-content">
                <strong>Demo này chạy trên HF Spaces với ZeroGPU (shared GPU)</strong> nên tốc độ sẽ <strong>chậm hơn</strong> và <strong>bị giới hạn 250 ký tự</strong> vì không thể triển khai lmdeploy trên HF space.<br><br>

                💡 <strong>Muốn tốc độ cực nhanh và không giới hạn ký tự?</strong> Hãy clone mã nguồn từ <a href="https://github.com/pnnbao97/VieNeu-TTS" target="_blank" style="color: #92400e; text-decoration: underline;">GitHub</a> và cài <code>lmdeploy</code> để chạy trên GPU của bạn:<br>

                🚀 Với LMDeploy + GPU local, tốc độ sẽ <strong>nhanh hơn 5-10 lần</strong> so với demo này!
            </div>
        </div>
        """)

    with gr.Row(elem_classes="container", equal_height=False):

        # --- LEFT: INPUT ---
        with gr.Column(scale=3, variant="panel"):
            gr.Markdown("### 📝 Văn bản đầu vào")
            text_input = gr.Textbox(
                label="Nhập văn bản",
                placeholder="Nhập nội dung tiếng Việt cần chuyển thành giọng nói...",
                lines=4,
                value="Sự bùng nổ của trí tuệ nhân tạo đang định hình lại cách chúng ta làm việc và sinh sống. Từ xe tự lái đến trợ lý ảo thông minh, công nghệ đang dần xóa nhòa ranh giới giữa thực tại và những bộ phim viễn tưởng.",
                show_label=False
            )

            # Counter: live character count, refreshed by update_count() below.
            with gr.Row():
                char_count = gr.HTML("<div style='text-align: right; color: #64748B; font-size: 0.8rem;'>0 / 250 ký tự</div>")

            gr.Markdown("### 🗣️ Chọn giọng đọc")
            with gr.Tabs() as tabs:
                with gr.TabItem("👤 Giọng có sẵn (Preset)", id="preset_mode"):
                    voice_select = gr.Dropdown(
                        choices=list(VOICE_SAMPLES.keys()),
                        value="Tuyên (nam miền Bắc)",
                        label="Danh sách giọng",
                        interactive=True
                    )
                    with gr.Accordion("Thông tin giọng mẫu", open=False):
                        ref_audio_preview = gr.Audio(label="Audio mẫu", interactive=False, type="filepath")
                        ref_text_preview = gr.Markdown("...")

                with gr.TabItem("🎙️ Giọng tùy chỉnh (Custom)", id="custom_mode"):
                    gr.Markdown("Tải lên giọng của bạn (Zero-shot Cloning)")
                    custom_audio = gr.Audio(label="File ghi âm (.wav)", type="filepath")
                    custom_text = gr.Textbox(label="Nội dung ghi âm", placeholder="Nhập chính xác lời thoại...")

            # Tracks which tab is active; consumed by synthesize_speech as mode_tab.
            current_mode = gr.State(value="preset_mode")
            btn_generate = gr.Button("Tổng hợp giọng nói", variant="primary", size="lg")

        # --- RIGHT: OUTPUT ---
        with gr.Column(scale=2):
            gr.Markdown("### 🎧 Kết quả")
            with gr.Group():
                audio_output = gr.Audio(label="Audio đầu ra", type="filepath", autoplay=True)
                status_output = gr.Textbox(label="Trạng thái", show_label=False, elem_classes="status-box", placeholder="Sẵn sàng...")

    # --- EXAMPLES ---
    with gr.Row(elem_classes="container"):
        with gr.Column():
            gr.Markdown("### 📚 Ví dụ mẫu")
            gr.Examples(examples=EXAMPLES_LIST, inputs=[text_input, voice_select], label="Thử nghiệm nhanh")

    # --- LOGIC ---
    def update_count(text):
        # Render the counter HTML, color-coded by proximity to the 250-char cap.
        l = len(text)
        if l > 250:
            color = "#dc2626"
            msg = f"⚠️ <b>{l} / 250</b> - Quá giới hạn!"
        elif l > 200:
            color = "#ea580c"
            msg = f"{l} / 250"
        else:
            color = "#64748B"
            msg = f"{l} / 250 ký tự"
        return f"<div style='text-align: right; color: {color}; font-size: 0.8rem; font-weight: bold'>{msg}</div>"

    text_input.change(update_count, text_input, char_count)

    def update_ref_preview(voice):
        # Show the selected preset's sample audio and its transcript as a quote.
        audio, text = load_reference_info(voice)
        return audio, f"> *\"{text}\"*"

    voice_select.change(update_ref_preview, voice_select, [ref_audio_preview, ref_text_preview])
    demo.load(update_ref_preview, voice_select, [ref_audio_preview, ref_text_preview])

    # Tab handling: keep current_mode in sync with the selected tab.
    def set_preset_mode():
        return "preset_mode"

    def set_custom_mode():
        return "custom_mode"

    # NOTE(review): relies on Tabs.children ordering matching the declaration
    # order above — confirm this holds across Gradio versions.
    tabs.children[0].select(fn=set_preset_mode, outputs=current_mode)
    tabs.children[1].select(fn=set_custom_mode, outputs=current_mode)

    btn_generate.click(
        fn=synthesize_speech,
        inputs=[text_input, voice_select, custom_audio, custom_text, current_mode],
        outputs=[audio_output, status_output]
    )
334
+
335
if __name__ == "__main__":
    # queue() enables request queuing, which ZeroGPU scheduling requires;
    # bind on all interfaces at the HF Spaces default port.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860
    )
config.yaml ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ text_settings:
2
+ max_chars_per_chunk: 256
3
+ max_total_chars_streaming: 3000
4
+
5
+ backbone_configs:
6
+ "VieNeu-TTS (GPU)":
7
+ repo: pnnbao-ump/VieNeu-TTS
8
+ supports_streaming: false
9
+ description: Chất lượng cao nhất, yêu cầu GPU
10
+ "VieNeu-TTS-0.3B (GPU)":
11
+ repo: pnnbao-ump/VieNeu-TTS-0.3B
12
+ supports_streaming: false
13
+ description: Phiên bản nhẹ cho GPU, tốc độ nhanh x2 so với phiên bản gốc
14
+ "VieNeu-TTS-q8-gguf":
15
+ repo: pnnbao-ump/VieNeu-TTS-q8-gguf
16
+ supports_streaming: true
17
+ description: Phiên bản GGUF có chất lượng cao nhất
18
+ "VieNeu-TTS-q4-gguf":
19
+ repo: pnnbao-ump/VieNeu-TTS-q4-gguf
20
+ supports_streaming: true
21
+ description: Cân bằng giữa chất lượng và tốc độ
22
+ "VieNeu-TTS-0.3B-q4-gguf":
23
+ repo: pnnbao-ump/VieNeu-TTS-0.3B-q4-gguf
24
+ supports_streaming: true
25
+ description: Phiên bản cực nhẹ, chạy mượt trên CPU
26
+
27
+ codec_configs:
28
+ "NeuCodec (Standard)":
29
+ repo: neuphonic/neucodec
30
+ description: Codec chuẩn, tốc độ trung bình
31
+ use_preencoded: false
32
+ "NeuCodec (Distill)":
33
+ repo: neuphonic/distill-neucodec
34
+ description: Codec tối ưu, tốc độ cao
35
+ use_preencoded: false
36
+ "NeuCodec ONNX (Fast CPU)":
37
+ repo: neuphonic/neucodec-onnx-decoder-int8
38
+ description: Tối ưu cho CPU, cần pre-encoded codes
39
+ use_preencoded: true
40
+
41
+ voice_samples:
42
+ "Tuyên (nam miền Bắc)":
43
+ audio: ./sample/Tuyên (nam miền Bắc).wav
44
+ text: ./sample/Tuyên (nam miền Bắc).txt
45
+ codes: ./sample/Tuyên (nam miền Bắc).pt
46
+ "Vĩnh (nam miền Nam)":
47
+ audio: ./sample/Vĩnh (nam miền Nam).wav
48
+ text: ./sample/Vĩnh (nam miền Nam).txt
49
+ codes: ./sample/Vĩnh (nam miền Nam).pt
50
+ "Bình (nam miền Bắc)":
51
+ audio: ./sample/Bình (nam miền Bắc).wav
52
+ text: ./sample/Bình (nam miền Bắc).txt
53
+ codes: ./sample/Bình (nam miền Bắc).pt
54
+ "Nguyên (nam miền Nam)":
55
+ audio: ./sample/Nguyên (nam miền Nam).wav
56
+ text: ./sample/Nguyên (nam miền Nam).txt
57
+ codes: ./sample/Nguyên (nam miền Nam).pt
58
+ "Sơn (nam miền Nam)":
59
+ audio: ./sample/Sơn (nam miền Nam).wav
60
+ text: ./sample/Sơn (nam miền Nam).txt
61
+ codes: ./sample/Sơn (nam miền Nam).pt
62
+ "Đoan (nữ miền Nam)":
63
+ audio: ./sample/Đoan (nữ miền Nam).wav
64
+ text: ./sample/Đoan (nữ miền Nam).txt
65
+ codes: ./sample/Đoan (nữ miền Nam).pt
66
+ "Ngọc (nữ miền Bắc)":
67
+ audio: ./sample/Ngọc (nữ miền Bắc).wav
68
+ text: ./sample/Ngọc (nữ miền Bắc).txt
69
+ codes: ./sample/Ngọc (nữ miền Bắc).pt
70
+ "Ly (nữ miền Bắc)":
71
+ audio: ./sample/Ly (nữ miền Bắc).wav
72
+ text: ./sample/Ly (nữ miền Bắc).txt
73
+ codes: ./sample/Ly (nữ miền Bắc).pt
74
+ "Dung (nữ miền Nam)":
75
+ audio: ./sample/Dung (nữ miền Nam).wav
76
+ text: ./sample/Dung (nữ miền Nam).txt
77
+ codes: ./sample/Dung (nữ miền Nam).pt
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ espeak-ng
2
+ libespeak-ng1
3
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ spaces
3
+ torchaudio
4
+ transformers
5
+ librosa
6
+ soundfile
7
+ numpy
8
+ phonemizer
9
+ neucodec
sample/Bình (nam miền Bắc).pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f896d618fc46c3e131eda7b4168e25e9c2fb2d7ea0e864bedff2577fbd0bd30
3
+ size 2089
sample/Bình (nam miền Bắc).txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Anh chỉ muốn được nhìn nhận như là một huấn luyện viên.
sample/Bình (nam miền Bắc).wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:135f087ced48606c4d406b770a11e344d4d9aa6bd7adfb3e5c26f69cd9cc6df1
3
+ size 127054
sample/Dung (nữ miền Nam).pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc4d65b6504470cb00e46763915060590595fbe4d47912eeacecd2bf1bade262
3
+ size 2153
sample/Dung (nữ miền Nam).txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Tục ngữ có câu, sai một li, đi một dặm.
sample/Dung (nữ miền Nam).wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56e42039d0c96ad19e9f78ecb7218853202022b2a8460010d34ffb7879b17409
3
+ size 143438
sample/Hương (nữ miền Bắc).pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:919035b7c762956a7d568cebc6e69fea22eb9be02bf906c1d32c1db1d8c7b9ff
3
+ size 2217
sample/Hương (nữ miền Bắc).txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Tuy nhiên, lúc này có một vấn đề khó khăn nảy sinh.
sample/Hương (nữ miền Bắc).wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c064b4ec64df44ea1306e87b25b84905e540d0fe29885629d2fe8bc8a5e53bc
3
+ size 155756
sample/Ly (nữ miền Bắc).pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69b6bc9bb1062122dc3755be907d87f232fa8be5129b54f6994dead35f4935c6
3
+ size 2153
sample/Ly (nữ miền Bắc).txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Chúng ta có thể áp dụng logic tương tự với người khác.
sample/Ly (nữ miền Bắc).wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d4e47cfa5ed0b753c2bed07c58e26da89ee2977ca5e941244a6bbafd8869d5e
3
+ size 147534
sample/Nguyên (nam miền Nam).pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e6ebaa0b2977589afa7e7f811b0553151bd8312c96a70b1b666bd9d0fd50edf
3
+ size 2345
sample/Nguyên (nam miền Nam).txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Hiểu biết về bản thân và người khác bắt đầu từ chính cơ thể mình.
sample/Nguyên (nam miền Nam).wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fc9655e24a7048c3c908494f2cbaf4c42d3d139d68c21d4f50c60e03aa19727
3
+ size 196124
sample/Ngọc (nữ miền Bắc).pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78ab670f177092dc8586e45536faea20fdb84471dc8d8a8b1b95dd76a4ed3d0d
3
+ size 2281
sample/Ngọc (nữ miền Bắc).txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Trong phòng rất tù mù, nên có thể dễ dàng che dấu nó.
sample/Ngọc (nữ miền Bắc).wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:475a73298fbe86e5d92e7fb95c6c26e897e5a2ffbdc3fa9e062df4025767af93
3
+ size 174956
sample/Sơn (nam miền Nam).pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:114cb04ee2357d06de2f038853bbeb0dc57fc8ed30e085118a9e0bf5a70f7857
3
+ size 2281
sample/Sơn (nam miền Nam).txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Trên thực tế, các nghi ngờ đã bắt đầu xuất hiện.
sample/Sơn (nam miền Nam).wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f6733df04e5f3477a00136c6baeaf7a196c93df0ee13b9bfa3d8ba61034f063
3
+ size 174044
sample/Tuyên (nam miền Bắc).pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e79eb6ee9cc7cd35cb4fbbef107249ed3209608b59644c52f55a34941a531873
3
+ size 2473
sample/Tuyên (nam miền Bắc).txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Bạn cầm khúc cây, và ném vào bãi cỏ xanh tươi rậm rạp ở đằng xa.
sample/Tuyên (nam miền Bắc).wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6b7ac2605db0a2cf634ce0f3a55a87a89f4c2e3bc06f83433e6af583c1f3692
3
+ size 217166
sample/Vĩnh (nam miền Nam).pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c87342d3a6a8cbaaf2139c21e7554eea19aba6aa03248e4426238a1c2507e447
3
+ size 2217
sample/Vĩnh (nam miền Nam).txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Đến cuối thế kỷ 19, ngành đánh bắt cá được thương mại hóa.
sample/Vĩnh (nam miền Nam).wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:632a5c8fa34fe03001cc3c44427b5e0ee70f767377bc788b59a5dc9afa9fba49
3
+ size 164492
sample/Đoan (nữ miền Nam).pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28b48dbae193adc88aa26243086ba3ce862def7035d9793613c2967df29f9afe
3
+ size 2793
sample/Đoan (nữ miền Nam).txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Nuôi con theo phong cách Do Thái, không chỉ tốt cho đứa trẻ, mà còn tốt cho cả các bậc cha mẹ.
sample/Đoan (nữ miền Nam).wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e319ed45dd2a1458a52edfe43a83a36eff813f19399ac2e59ee3f93cace74be
3
+ size 294830
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (133 Bytes). View file
 
utils/__pycache__/core_utils.cpython-312.pyc ADDED
Binary file (2.22 kB). View file
 
utils/__pycache__/normalize_text.cpython-312.pyc ADDED
Binary file (24.6 kB). View file
 
utils/__pycache__/phonemize_text.cpython-312.pyc ADDED
Binary file (13 kB). View file
 
utils/core_utils.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ from typing import List
4
+
5
def split_text_into_chunks(text: str, max_chars: int = 256) -> List[str]:
    """Split raw text into chunks no longer than max_chars.

    Sentences (delimited by . ! ? … or newline) are packed greedily into
    chunks; a sentence longer than max_chars is itself packed word by word.
    A single word longer than max_chars is kept whole.
    """
    pieces = re.split(r"(?<=[\.\!\?\…\n])\s+|(?<=\n)", text.strip())
    out: List[str] = []
    pending = ""  # sentences accumulated for the current chunk

    for piece in pieces:
        piece = piece.strip()
        if not piece:
            continue

        if len(piece) <= max_chars:
            # Try to append the sentence to the pending chunk.
            merged = f"{pending} {piece}".strip() if pending else piece
            if len(merged) <= max_chars:
                pending = merged
            else:
                if pending:
                    out.append(pending.strip())
                pending = piece
            continue

        # Oversize sentence: emit what is pending, then pack word by word.
        if pending:
            out.append(pending.strip())
            pending = ""
        run = ""
        for word in piece.split():
            candidate = f"{run} {word}".strip() if run else word
            if len(candidate) > max_chars and run:
                out.append(run.strip())
                run = word
            else:
                run = candidate
        if run:
            out.append(run.strip())

    if pending:
        out.append(pending.strip())
    return [chunk for chunk in out if chunk]
48
+
49
def env_bool(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset; otherwise True iff the
    (trimmed, lowercased) value is one of the usual truthy spellings.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "y", "on"}
utils/normalize_text.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ class VietnameseTTSNormalizer:
4
+ """
5
+ A text normalizer for Vietnamese Text-to-Speech systems.
6
+ Converts numbers, dates, units, and special characters into readable Vietnamese text.
7
+ """
8
+
9
    def __init__(self):
        """Build the lookup tables used by the normalization passes."""
        # Unit abbreviation -> Vietnamese pronunciation. Keys are lowercase
        # because normalize() lowercases text before unit expansion.
        # NOTE(review): after lowercasing, milli- vs mega- prefixes collide
        # (e.g. 'mv' = 'mi li vôn' but 'mw' = 'mê ga oát') — confirm these
        # single mappings are the intended defaults.
        self.units = {
            'km': 'ki lô mét', 'dm': 'đê xi mét', 'cm': 'xen ti mét',
            'mm': 'mi li mét', 'nm': 'na nô mét', 'µm': 'mic rô mét',
            'μm': 'mic rô mét', 'm': 'mét',

            'kg': 'ki lô gam', 'g': 'gam', 'mg': 'mi li gam',

            'km²': 'ki lô mét vuông', 'km2': 'ki lô mét vuông',
            'm²': 'mét vuông', 'm2': 'mét vuông',
            'cm²': 'xen ti mét vuông', 'cm2': 'xen ti mét vuông',
            'mm²': 'mi li mét vuông', 'mm2': 'mi li mét vuông',
            'ha': 'héc ta',

            'km³': 'ki lô mét khối', 'km3': 'ki lô mét khối',
            'm³': 'mét khối', 'm3': 'mét khối',
            'cm³': 'xen ti mét khối', 'cm3': 'xen ti mét khối',
            'mm³': 'mi li mét khối', 'mm3': 'mi li mét khối',
            'l': 'lít', 'dl': 'đê xi lít', 'ml': 'mi li lít', 'hl': 'héc tô lít',

            'v': 'vôn', 'kv': 'ki lô vôn', 'mv': 'mi li vôn',
            'a': 'am pe', 'ma': 'mi li am pe', 'ka': 'ki lô am pe',
            'w': 'oát', 'kw': 'ki lô oát', 'mw': 'mê ga oát', 'gw': 'gi ga oát',
            'kwh': 'ki lô oát giờ', 'mwh': 'mê ga oát giờ', 'wh': 'oát giờ',
            'ω': 'ôm', 'ohm': 'ôm', 'kω': 'ki lô ôm', 'mω': 'mê ga ôm',

            'hz': 'héc', 'khz': 'ki lô héc', 'mhz': 'mê ga héc', 'ghz': 'gi ga héc',

            'pa': 'pát cal', 'kpa': 'ki lô pát cal', 'mpa': 'mê ga pát cal',
            'bar': 'ba', 'mbar': 'mi li ba', 'atm': 'át mốt phia', 'psi': 'pi ét xai',

            'j': 'giun', 'kj': 'ki lô giun',
            'cal': 'ca lo', 'kcal': 'ki lô ca lo',
        }

        # Vietnamese digit names, indexed by digit value 0-9.
        self.digits = ['không', 'một', 'hai', 'ba', 'bốn',
                       'năm', 'sáu', 'bảy', 'tám', 'chín']
46
+
47
    def normalize(self, text):
        """Main normalization pipeline with EN tag protection.

        <en>...</en> spans are swapped out for placeholders before the
        Vietnamese passes run, then restored verbatim afterwards.
        """
        # Step 1: Extract and protect EN tags
        en_contents = []
        placeholder_pattern = "___EN_PLACEHOLDER_{}___ "

        def extract_en(match):
            # Keep the full tagged span (including <en></en>) for restoration.
            en_contents.append(match.group(0))
            return placeholder_pattern.format(len(en_contents) - 1)

        text = re.sub(r'<en>.*?</en>', extract_en, text, flags=re.IGNORECASE)

        # Step 2: Normal normalization pipeline. Order matters: specific
        # patterns (temperature, currency, percent, units, time, date,
        # phone) run before the generic number-to-words pass.
        text = text.lower()
        text = self._normalize_temperature(text)
        text = self._normalize_currency(text)
        text = self._normalize_percentage(text)
        text = self._normalize_units(text)
        text = self._normalize_time(text)
        text = self._normalize_date(text)
        text = self._normalize_phone(text)
        text = self._normalize_numbers(text)
        text = self._number_to_words(text)
        text = self._normalize_special_chars(text)
        text = self._normalize_whitespace(text)

        # Step 3: Restore EN tags.
        # NOTE(review): the numeric passes above may rewrite the index digit
        # inside "___EN_PLACEHOLDER_0___ "; if so, the replace() below would
        # silently miss and leave the placeholder in the output — confirm
        # placeholders survive _normalize_numbers/_number_to_words.
        for idx, en_content in enumerate(en_contents):
            text = text.replace(placeholder_pattern.format(idx).lower(), en_content + ' ')

        # Final whitespace cleanup
        text = self._normalize_whitespace(text)

        return text
81
+
82
+ def _normalize_temperature(self, text):
83
+ """Convert temperature notation to words."""
84
+ text = re.sub(r'-(\d+(?:[.,]\d+)?)\s*°\s*c\b', r'âm \1 độ xê', text, flags=re.IGNORECASE)
85
+ text = re.sub(r'-(\d+(?:[.,]\d+)?)\s*°\s*f\b', r'âm \1 độ ép', text, flags=re.IGNORECASE)
86
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*°\s*c\b', r'\1 độ xê', text, flags=re.IGNORECASE)
87
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*°\s*f\b', r'\1 độ ép', text, flags=re.IGNORECASE)
88
+ text = re.sub(r'°', ' độ ', text)
89
+ return text
90
+
91
+ def _normalize_currency(self, text):
92
+ """Convert currency notation to words."""
93
+ def decimal_currency(match):
94
+ whole = match.group(1)
95
+ decimal = match.group(2)
96
+ unit = match.group(3)
97
+ decimal_words = ' '.join([self.digits[int(d)] for d in decimal])
98
+ unit_map = {'k': 'nghìn', 'm': 'triệu', 'b': 'tỷ'}
99
+ unit_word = unit_map.get(unit.lower(), unit)
100
+ return f"{whole} phẩy {decimal_words} {unit_word}"
101
+
102
+ text = re.sub(r'(\d+)[.,](\d+)\s*([kmb])\b', decimal_currency, text, flags=re.IGNORECASE)
103
+ text = re.sub(r'(\d+)\s*k\b', r'\1 nghìn', text, flags=re.IGNORECASE)
104
+ text = re.sub(r'(\d+)\s*m\b', r'\1 triệu', text, flags=re.IGNORECASE)
105
+ text = re.sub(r'(\d+)\s*b\b', r'\1 tỷ', text, flags=re.IGNORECASE)
106
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*đ\b', r'\1 đồng', text)
107
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*vnd\b', r'\1 đồng', text, flags=re.IGNORECASE)
108
+ text = re.sub(r'\$\s*(\d+(?:[.,]\d+)?)', r'\1 đô la', text)
109
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*\$', r'\1 đô la', text)
110
+ return text
111
+
112
+ def _normalize_percentage(self, text):
113
+ """Convert percentage to words."""
114
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*%', r'\1 phần trăm', text)
115
+ return text
116
+
117
+ def _normalize_units(self, text):
118
+ """Convert measurement units to words."""
119
+ def expand_compound_with_number(match):
120
+ number = match.group(1)
121
+ unit1 = match.group(2).lower()
122
+ unit2 = match.group(3).lower()
123
+ full_unit1 = self.units.get(unit1, unit1)
124
+ full_unit2 = self.units.get(unit2, unit2)
125
+ return f"{number} {full_unit1} trên {full_unit2}"
126
+
127
+ def expand_compound_without_number(match):
128
+ unit1 = match.group(1).lower()
129
+ unit2 = match.group(2).lower()
130
+ full_unit1 = self.units.get(unit1, unit1)
131
+ full_unit2 = self.units.get(unit2, unit2)
132
+ return f"{full_unit1} trên {full_unit2}"
133
+
134
+ text = re.sub(r'(\d+(?:[.,]\d+)?)\s*([a-zA-Zμµ²³°]+)/([a-zA-Zμµ²³°0-9]+)\b',
135
+ expand_compound_with_number, text)
136
+ text = re.sub(r'\b([a-zA-Zμµ²³°]+)/([a-zA-Zμµ²³°0-9]+)\b',
137
+ expand_compound_without_number, text)
138
+
139
+ sorted_units = sorted(self.units.items(), key=lambda x: len(x[0]), reverse=True)
140
+ for unit, full_name in sorted_units:
141
+ pattern = r'(\d+(?:[.,]\d+)?)\s*' + re.escape(unit) + r'\b'
142
+ text = re.sub(pattern, rf'\1 {full_name}', text, flags=re.IGNORECASE)
143
+
144
+ for unit, full_name in sorted_units:
145
+ if any(c in unit for c in '²³°'):
146
+ pattern = r'\b' + re.escape(unit) + r'\b'
147
+ text = re.sub(pattern, full_name, text, flags=re.IGNORECASE)
148
+
149
+ return text
150
+
151
+ def _normalize_time(self, text):
152
+ """Convert time notation to words with validation."""
153
+
154
+ def validate_and_convert_time(match):
155
+ """Validate time components before converting."""
156
+ groups = match.groups()
157
+
158
+ # HH:MM:SS format
159
+ if len(groups) == 3:
160
+ hour, minute, second = groups
161
+ hour_int, minute_int, second_int = int(hour), int(minute), int(second)
162
+
163
+ if not (0 <= hour_int <= 23):
164
+ return match.group(0)
165
+ if not (0 <= minute_int <= 59):
166
+ return match.group(0)
167
+ if not (0 <= second_int <= 59):
168
+ return match.group(0)
169
+
170
+ return f"{hour} giờ {minute} phút {second} giây"
171
+
172
+ # HH:MM or HHhMM format
173
+ elif len(groups) == 2:
174
+ hour, minute = groups
175
+ hour_int, minute_int = int(hour), int(minute)
176
+
177
+ if not (0 <= hour_int <= 23):
178
+ return match.group(0)
179
+ if not (0 <= minute_int <= 59):
180
+ return match.group(0)
181
+
182
+ return f"{hour} giờ {minute} phút"
183
+
184
+ # HHh format
185
+ else:
186
+ hour = groups[0]
187
+ hour_int = int(hour)
188
+
189
+ if not (0 <= hour_int <= 23):
190
+ return match.group(0)
191
+
192
+ return f"{hour} giờ"
193
+
194
+ text = re.sub(r'(\d{1,2}):(\d{2}):(\d{2})', validate_and_convert_time, text)
195
+ text = re.sub(r'(\d{1,2}):(\d{2})', validate_and_convert_time, text)
196
+ text = re.sub(r'(\d{1,2})h(\d{2})', validate_and_convert_time, text)
197
+ text = re.sub(r'(\d{1,2})h\b', validate_and_convert_time, text)
198
+
199
+ return text
200
+
201
+ def _normalize_date(self, text):
202
+ """Convert date notation to words with validation."""
203
+
204
+ def is_valid_date(day, month, year):
205
+ """Check if date components are valid."""
206
+ day, month, year = int(day), int(month), int(year)
207
+
208
+ if not (1 <= day <= 31):
209
+ return False
210
+ if not (1 <= month <= 12):
211
+ return False
212
+
213
+ return True
214
+
215
+ def date_to_text(match):
216
+ day, month, year = match.groups()
217
+ if is_valid_date(day, month, year):
218
+ return f"ngày {day} tháng {month} năm {year}"
219
+ return match.group(0)
220
+
221
+ def date_iso_to_text(match):
222
+ year, month, day = match.groups()
223
+ if is_valid_date(day, month, year):
224
+ return f"ngày {day} tháng {month} năm {year}"
225
+ return match.group(0)
226
+
227
+ def date_short_year(match):
228
+ day, month, year = match.groups()
229
+ full_year = f"20{year}" if int(year) < 50 else f"19{year}"
230
+ if is_valid_date(day, month, full_year):
231
+ return f"ngày {day} tháng {month} năm {full_year}"
232
+ return match.group(0)
233
+
234
+ text = re.sub(r'\bngày\s+(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})\b',
235
+ lambda m: date_to_text(m).replace('ngày ngày', 'ngày'), text)
236
+ text = re.sub(r'\bngày\s+(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})\b',
237
+ lambda m: date_short_year(m).replace('ngày ngày', 'ngày'), text)
238
+ text = re.sub(r'\b(\d{4})-(\d{1,2})-(\d{1,2})\b', date_iso_to_text, text)
239
+ text = re.sub(r'\b(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})\b', date_to_text, text)
240
+ text = re.sub(r'\b(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})\b', date_short_year, text)
241
+
242
+ return text
243
+
244
+ def _normalize_phone(self, text):
245
+ """Convert phone numbers to digit-by-digit reading."""
246
+ def phone_to_text(match):
247
+ phone = match.group(0)
248
+ phone = re.sub(r'[^\d]', '', phone)
249
+
250
+ if phone.startswith('84') and len(phone) >= 10:
251
+ phone = '0' + phone[2:]
252
+
253
+ if 10 <= len(phone) <= 11:
254
+ words = [self.digits[int(d)] for d in phone]
255
+ return ' '.join(words) + ' '
256
+
257
+ return match.group(0)
258
+
259
+ text = re.sub(r'(\+84|84)[\s\-\.]?\d[\d\s\-\.]{7,}', phone_to_text, text)
260
+ text = re.sub(r'\b0\d[\d\s\-\.]{8,}', phone_to_text, text)
261
+ return text
262
+
263
+ def _normalize_numbers(self, text):
264
+ text = re.sub(r'(\d+(?:[,.]\d+)?)%', lambda m: f'{m.group(1)} phần trăm', text)
265
+ text = re.sub(r'(\d{1,3})(?:\.(\d{3}))+', lambda m: m.group(0).replace('.', ''), text)
266
+
267
+ def decimal_to_words(match):
268
+ whole = match.group(1)
269
+ decimal = match.group(2)
270
+ decimal_words = ' '.join([self.digits[int(d)] for d in decimal])
271
+ separator = 'phẩy' if ',' in match.group(0) else 'chấm'
272
+ return f"{whole} {separator} {decimal_words}"
273
+
274
+ text = re.sub(r'(\d+),(\d+)', decimal_to_words, text)
275
+ text = re.sub(r'(\d+)\.(\d{1,2})\b', decimal_to_words, text)
276
+
277
+ return text
278
+
279
+ def _read_two_digits(self, n):
280
+ """Read two-digit numbers in Vietnamese."""
281
+ if n < 10:
282
+ return self.digits[n]
283
+ elif n == 10:
284
+ return "mười"
285
+ elif n < 20:
286
+ if n == 15:
287
+ return "mười lăm"
288
+ return f"mười {self.digits[n % 10]}"
289
+ else:
290
+ tens = n // 10
291
+ ones = n % 10
292
+ if ones == 0:
293
+ return f"{self.digits[tens]} mươi"
294
+ elif ones == 1:
295
+ return f"{self.digits[tens]} mươi mốt"
296
+ elif ones == 5:
297
+ return f"{self.digits[tens]} mươi lăm"
298
+ else:
299
+ return f"{self.digits[tens]} mươi {self.digits[ones]}"
300
+
301
+ def _read_three_digits(self, n):
302
+ """Read three-digit numbers in Vietnamese."""
303
+ if n < 100:
304
+ return self._read_two_digits(n)
305
+
306
+ hundreds = n // 100
307
+ remainder = n % 100
308
+ result = f"{self.digits[hundreds]} trăm"
309
+
310
+ if remainder == 0:
311
+ return result
312
+ elif remainder < 10:
313
+ result += f" lẻ {self.digits[remainder]}"
314
+ else:
315
+ result += f" {self._read_two_digits(remainder)}"
316
+
317
+ return result
318
+
319
+ def _convert_number_to_words(self, num):
320
+ """Convert a number to Vietnamese words."""
321
+ if num == 0:
322
+ return "không"
323
+
324
+ if num < 0:
325
+ return f"âm {self._convert_number_to_words(-num)}"
326
+
327
+ if num >= 1000000000:
328
+ billion = num // 1000000000
329
+ remainder = num % 1000000000
330
+ result = f"{self._read_three_digits(billion)} tỷ"
331
+ if remainder > 0:
332
+ result += f" {self._convert_number_to_words(remainder)}"
333
+ return result
334
+
335
+ elif num >= 1000000:
336
+ million = num // 1000000
337
+ remainder = num % 1000000
338
+ result = f"{self._read_three_digits(million)} triệu"
339
+ if remainder > 0:
340
+ result += f" {self._convert_number_to_words(remainder)}"
341
+ return result
342
+
343
+ elif num >= 1000:
344
+ thousand = num // 1000
345
+ remainder = num % 1000
346
+ result = f"{self._read_three_digits(thousand)} nghìn"
347
+ if remainder > 0:
348
+ if remainder < 10:
349
+ result += f" không trăm lẻ {self.digits[remainder]}"
350
+ elif remainder < 100:
351
+ result += f" không trăm {self._read_two_digits(remainder)}"
352
+ else:
353
+ result += f" {self._read_three_digits(remainder)}"
354
+ return result
355
+
356
+ else:
357
+ return self._read_three_digits(num)
358
+
359
+ def _number_to_words(self, text):
360
+ """Convert all remaining numbers to words."""
361
+ def convert_number(match):
362
+ num = int(match.group(0))
363
+ return self._convert_number_to_words(num)
364
+
365
+ text = re.sub(r'\b\d+\b', convert_number, text)
366
+ return text
367
+
368
+ def _normalize_special_chars(self, text):
369
+ """Handle special characters."""
370
+ text = text.replace('&', ' và ')
371
+ text = text.replace('+', ' cộng ')
372
+ text = text.replace('=', ' bằng ')
373
+ text = text.replace('#', ' thăng ')
374
+ text = re.sub(r'[\[\]\(\)\{\}]', ' ', text)
375
+ text = re.sub(r'\s+[-–—]+\s+', ' ', text)
376
+ text = re.sub(r'\.{2,}', ' ', text)
377
+ text = re.sub(r'\s+\.\s+', ' ', text)
378
+ text = re.sub(r'[^\w\sàáảãạăắằẳẵặâấầẩẫậèéẻẽẹêếềểễệìíỉĩịòóỏõọôốồổỗộơớờởỡợùúủũụưứừửữựỳýỷỹỵđ.,!?;:@%_]', ' ', text)
379
+ return text
380
+
381
+ def _normalize_whitespace(self, text):
382
+ """Normalize whitespace."""
383
+ text = re.sub(r'\s+', ' ', text)
384
+ text = text.strip()
385
+ return text
386
+
387
+
388
if __name__ == "__main__":
    # Manual smoke test for the normalizer, including <en> spans.
    normalizer = VietnameseTTSNormalizer()

    test_texts = [
        "Chào mừng <en>hello world</en> đến với AI",
        "Công nghệ <en>machine learning</en> và <en>deep learning</en>",
        "Giá 2.500.000đ với <en>discount</en> 50%",
        "Nhiệt độ 25°C, <en>temperature</en> cao",
        "Hệ thống <en>text-to-speech</en> tiếng Việt",
    ]

    banner = "=" * 80
    print(banner)
    print("VIETNAMESE TTS NORMALIZATION TEST (WITH EN TAG)")
    print(banner)

    for sample in test_texts:
        print(f"\n📝 Input: {sample}")
        print(f"🎵 Output: {normalizer.normalize(sample)}")
        print("-" * 80)
utils/phoneme_dict.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:331f9583a0ac0c795000b569e141e0c9d50c3005d02c49b631fb30edb4b407d4
3
+ size 18078190
utils/phonemize_text.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import platform
4
+ import glob
5
+ import re
6
+ from phonemizer import phonemize
7
+ from phonemizer.backend.espeak.espeak import EspeakWrapper
8
+ from utils.normalize_text import VietnameseTTSNormalizer
9
+
10
+ # Configuration
11
+ PHONEME_DICT_PATH = os.getenv(
12
+ 'PHONEME_DICT_PATH',
13
+ os.path.join(os.path.dirname(__file__), "phoneme_dict.json")
14
+ )
15
+
16
+ def load_phoneme_dict(path=PHONEME_DICT_PATH):
17
+ """Load phoneme dictionary from JSON file."""
18
+ try:
19
+ with open(path, "r", encoding="utf-8") as f:
20
+ return json.load(f)
21
+ except FileNotFoundError:
22
+ raise FileNotFoundError(
23
+ f"Phoneme dictionary not found at {path}. "
24
+ "Please create it or set PHONEME_DICT_PATH environment variable."
25
+ )
26
+
27
def setup_espeak_library():
    """Point phonemizer at the eSpeak NG shared library for this OS."""
    handlers = {
        "Windows": _setup_windows_espeak,
        "Linux": _setup_linux_espeak,
        "Darwin": _setup_macos_espeak,
    }
    system = platform.system()
    handler = handlers.get(system)
    if handler is None:
        raise OSError(
            f"Unsupported OS: {system}. "
            "Only Windows, Linux, and macOS are supported."
        )
    handler()
42
+
43
+ def _setup_windows_espeak():
44
+ """Setup eSpeak for Windows."""
45
+ default_path = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"
46
+ if os.path.exists(default_path):
47
+ EspeakWrapper.set_library(default_path)
48
+ else:
49
+ raise FileNotFoundError(
50
+ f"eSpeak library not found at {default_path}. "
51
+ "Please install eSpeak NG from: https://github.com/espeak-ng/espeak-ng/releases"
52
+ )
53
+
54
def _setup_linux_espeak():
    """Find and register libespeak-ng on common Linux library paths."""
    search_patterns = [
        "/usr/lib/x86_64-linux-gnu/libespeak-ng.so*",
        "/usr/lib/x86_64-linux-gnu/libespeak.so*",
        "/usr/lib/libespeak-ng.so*",
        "/usr/lib64/libespeak-ng.so*",
        "/usr/local/lib/libespeak-ng.so*",
    ]

    for pattern in search_patterns:
        found = glob.glob(pattern)
        if found:
            # Shortest match first: prefer the bare .so over versioned names.
            EspeakWrapper.set_library(sorted(found, key=len)[0])
            return

    raise RuntimeError(
        "eSpeak NG library not found. Install with:\n"
        "  Ubuntu/Debian: sudo apt-get install espeak-ng\n"
        "  Fedora: sudo dnf install espeak-ng\n"
        "  Arch: sudo pacman -S espeak-ng\n"
        "See: https://github.com/pnnbao97/VieNeu-TTS/issues/5"
    )
77
+
78
+ def _setup_macos_espeak():
79
+ """Setup eSpeak for macOS."""
80
+ espeak_lib = os.environ.get('PHONEMIZER_ESPEAK_LIBRARY')
81
+
82
+ paths_to_check = [
83
+ espeak_lib,
84
+ "/opt/homebrew/lib/libespeak-ng.dylib", # Apple Silicon
85
+ "/usr/local/lib/libespeak-ng.dylib", # Intel
86
+ "/opt/local/lib/libespeak-ng.dylib", # MacPorts
87
+ ]
88
+
89
+ for path in paths_to_check:
90
+ if path and os.path.exists(path):
91
+ EspeakWrapper.set_library(path)
92
+ return
93
+
94
+ raise FileNotFoundError(
95
+ "eSpeak library not found. Install with:\n"
96
+ " brew install espeak-ng\n"
97
+ "Or set: export PHONEMIZER_ESPEAK_LIBRARY=/path/to/libespeak-ng.dylib"
98
+ )
99
+
100
# Initialize the module-level singletons at import time: the eSpeak
# backend, the cached phoneme dictionary and the text normalizer.
try:
    setup_espeak_library()
    phoneme_dict = load_phoneme_dict()
    normalizer = VietnameseTTSNormalizer()
except Exception as err:
    print(f"Initialization error: {err}")
    raise
108
+
109
def phonemize_text(text: str) -> str:
    """
    Convert text to phonemes (simple version without dict, without EN tag).
    Kept for backward compatibility.
    """
    normalized = normalizer.normalize(text)
    espeak_options = dict(
        language="vi",
        backend="espeak",
        preserve_punctuation=True,
        with_stress=True,
        language_switch="remove-flags",
    )
    return phonemize(normalized, **espeak_options)
123
+
124
+
125
def phonemize_with_dict(text: str, phoneme_dict=phoneme_dict) -> str:
    """
    Phonemize single text with dictionary lookup and EN tag support.

    The text is normalized, split on ``<en>...</en>`` spans, and each span
    is phonemized with the matching eSpeak voice: English spans with
    ``en-us``, Vietnamese word-by-word with ``vi``. Vietnamese words found
    in *phoneme_dict* reuse cached phonemes; misses are phonemized and
    written back into the dict. NOTE: *phoneme_dict* defaults to the shared
    module-level cache and is mutated in place across calls.

    Args:
        text: Raw input text, possibly containing ``<en>`` spans.
        phoneme_dict: word -> phoneme cache (defaults to the module cache).

    Returns:
        str: The phonemized string with punctuation re-attached.
    """
    text = normalizer.normalize(text)

    # Split by EN tags; the capturing group keeps the tags as list items.
    parts = re.split(r'(<en>.*?</en>)', text, flags=re.IGNORECASE)

    en_texts = []       # raw English span contents
    en_indices = []     # their positions in `parts`
    vi_texts = []       # Vietnamese words missing from the dict
    vi_indices = []     # part index per miss (collected but not read below)
    vi_word_maps = []   # (part_idx, word_idx) slot to fill for each miss

    # Per part: a string (EN), a list of words (VI), or None placeholder.
    processed_parts = []

    for part_idx, part in enumerate(parts):
        if re.match(r'<en>.*</en>', part, re.IGNORECASE):
            # English part
            en_content = re.sub(r'</?en>', '', part, flags=re.IGNORECASE).strip()
            en_texts.append(en_content)
            en_indices.append(part_idx)
            processed_parts.append(None)
        else:
            # Vietnamese part
            words = part.split()
            processed_words = []

            for word_idx, word in enumerate(words):
                # Split leading/trailing punctuation from the word core.
                match = re.match(r'^(\W*)(.*?)(\W*)$', word)
                pre, core, suf = match.groups() if match else ("", word, "")

                if not core:
                    processed_words.append(word)
                elif core in phoneme_dict:
                    # Cache hit: reuse phonemes, keep surrounding punctuation.
                    processed_words.append(f"{pre}{phoneme_dict[core]}{suf}")
                else:
                    # Cache miss: queue the (still punctuated) word for batch
                    # phonemization and leave a None slot to fill later.
                    vi_texts.append(word)
                    vi_indices.append(part_idx)
                    vi_word_maps.append((part_idx, len(processed_words)))
                    processed_words.append(None)

            processed_parts.append(processed_words)

    if en_texts:
        try:
            en_phonemes = phonemize(
                en_texts,
                language='en-us',
                backend='espeak',
                preserve_punctuation=True,
                with_stress=True,
                language_switch="remove-flags"
            )

            # phonemize() may return a bare string for a single input.
            if isinstance(en_phonemes, str):
                en_phonemes = [en_phonemes]

            for idx, (part_idx, phoneme) in enumerate(zip(en_indices, en_phonemes)):
                processed_parts[part_idx] = phoneme.strip()
        except Exception as e:
            # Fallback: keep the raw English text for the failed spans.
            print(f"Warning: Could not phonemize EN texts: {e}")
            for part_idx in en_indices:
                processed_parts[part_idx] = en_texts[en_indices.index(part_idx)]

    if vi_texts:
        try:
            vi_phonemes = phonemize(
                vi_texts,
                language='vi',
                backend='espeak',
                preserve_punctuation=True,
                with_stress=True,
                language_switch='remove-flags'
            )

            if isinstance(vi_phonemes, str):
                vi_phonemes = [vi_phonemes]

            for idx, (part_idx, word_idx) in enumerate(vi_word_maps):
                phoneme = vi_phonemes[idx].strip()

                # Words starting with 'r' get a forced /ɹ/ onset —
                # presumably to override eSpeak's Vietnamese rendering;
                # TODO confirm the intent.
                original_word = vi_texts[idx]
                if original_word.lower().startswith('r'):
                    phoneme = 'ɹ' + phoneme[1:] if len(phoneme) > 0 else phoneme

                # Cache write-back. NOTE(review): keyed by the punctuated
                # word, while lookups above use the stripped core — these
                # entries can never be hit again; verify this is intended.
                phoneme_dict[original_word] = phoneme

                if processed_parts[part_idx] is not None:
                    processed_parts[part_idx][word_idx] = phoneme
        except Exception as e:
            # Fallback: splice the raw words back into their slots.
            print(f"Warning: Could not phonemize VI texts: {e}")
            for idx, (part_idx, word_idx) in enumerate(vi_word_maps):
                if processed_parts[part_idx] is not None:
                    processed_parts[part_idx][word_idx] = vi_texts[idx]

    # Flatten: join word lists, keep EN strings, drop unfilled None slots.
    final_parts = []
    for part in processed_parts:
        if isinstance(part, list):
            final_parts.append(' '.join(str(w) for w in part if w is not None))
        elif part is not None:
            final_parts.append(part)

    result = ' '.join(final_parts)

    # Re-attach punctuation that ended up after a space.
    result = re.sub(r'\s+([.,!?;:])', r'\1', result)

    return result
234
+
235
+
236
def phonemize_batch(texts: list, phoneme_dict=phoneme_dict) -> list:
    """
    Phonemize multiple texts with optimal batching.

    All English spans and all dictionary-miss Vietnamese words across the
    whole batch are gathered first, so eSpeak is invoked at most once per
    language. Misses are written back into *phoneme_dict* (the shared
    module-level cache by default — mutated in place).

    Args:
        texts: List of text strings to phonemize
        phoneme_dict: Phoneme dictionary for lookup

    Returns:
        List of phonemized texts
    """
    normalized_texts = [normalizer.normalize(text) for text in texts]

    all_en_texts = []   # English span contents across the whole batch
    all_en_maps = []    # (text_idx, part_idx) per English span

    all_vi_texts = []   # Vietnamese dictionary misses across the batch
    all_vi_maps = []    # (text_idx, part_idx, word_idx) slot per miss

    # Per text: list of parts, each a word list (VI) or None (EN pending).
    results = []

    for text_idx, text in enumerate(normalized_texts):
        # Capturing group keeps the <en> tags as list items.
        parts = re.split(r'(<en>.*?</en>)', text, flags=re.IGNORECASE)
        processed_parts = []

        for part_idx, part in enumerate(parts):
            if re.match(r'<en>.*</en>', part, re.IGNORECASE):
                en_content = re.sub(r'</?en>', '', part, flags=re.IGNORECASE).strip()
                all_en_texts.append(en_content)
                all_en_maps.append((text_idx, part_idx))
                processed_parts.append(None)
            else:
                words = part.split()
                processed_words = []

                for word in words:
                    # Split leading/trailing punctuation from the word core.
                    match = re.match(r'^(\W*)(.*?)(\W*)$', word)
                    pre, core, suf = match.groups() if match else ("", word, "")

                    if not core:
                        processed_words.append(word)
                    elif core in phoneme_dict:
                        # Cache hit: splice cached phonemes between the
                        # original punctuation.
                        processed_words.append(f"{pre}{phoneme_dict[core]}{suf}")
                    else:
                        # Cache miss: queue for the batched eSpeak call and
                        # leave a None slot to fill in afterwards.
                        all_vi_texts.append(word)
                        all_vi_maps.append((text_idx, part_idx, len(processed_words)))
                        processed_words.append(None)

                processed_parts.append(processed_words)

        results.append(processed_parts)

    if all_en_texts:
        try:
            en_phonemes = phonemize(
                all_en_texts,
                language='en-us',
                backend='espeak',
                preserve_punctuation=True,
                with_stress=True,
                language_switch="remove-flags"
            )

            # phonemize() may return a bare string for a single input.
            if isinstance(en_phonemes, str):
                en_phonemes = [en_phonemes]

            for (text_idx, part_idx), phoneme in zip(all_en_maps, en_phonemes):
                results[text_idx][part_idx] = phoneme.strip()
        except Exception as e:
            # NOTE(review): unlike phonemize_with_dict, failed EN spans are
            # not restored here — their parts stay None and are dropped.
            print(f"Warning: Batch EN phonemization failed: {e}")

    if all_vi_texts:
        try:
            vi_phonemes = phonemize(
                all_vi_texts,
                language='vi',
                backend='espeak',
                preserve_punctuation=True,
                with_stress=True,
                language_switch='remove-flags'
            )

            if isinstance(vi_phonemes, str):
                vi_phonemes = [vi_phonemes]

            for idx, (text_idx, part_idx, word_idx) in enumerate(all_vi_maps):
                phoneme = vi_phonemes[idx].strip()

                # Words starting with 'r' get a forced /ɹ/ onset —
                # presumably to override eSpeak's Vietnamese rendering;
                # TODO confirm the intent.
                original_word = all_vi_texts[idx]
                if original_word.lower().startswith('r'):
                    phoneme = 'ɹ' + phoneme[1:] if len(phoneme) > 0 else phoneme

                # Cache write-back keyed by the punctuated word (same
                # caveat as phonemize_with_dict: lookups use the core).
                phoneme_dict[original_word] = phoneme
                results[text_idx][part_idx][word_idx] = phoneme
        except Exception as e:
            # On failure the None slots remain and are dropped below.
            print(f"Warning: Batch VI phonemization failed: {e}")

    # Flatten each text: join word lists, keep EN strings, skip None slots.
    final_results = []
    for processed_parts in results:
        final_parts = []
        for part in processed_parts:
            if isinstance(part, list):
                final_parts.append(' '.join(str(w) for w in part if w is not None))
            elif part is not None:
                final_parts.append(part)

        result = ' '.join(final_parts)
        # Re-attach punctuation that ended up after a space.
        result = re.sub(r'\s+([.,!?;:])', r'\1', result)
        final_results.append(result)

    return final_results
vieneu_tts.py ADDED
@@ -0,0 +1,859 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Generator
3
+ import librosa
4
+ import numpy as np
5
+ import torch
6
+ from neucodec import NeuCodec, DistillNeuCodec
7
+ from utils.phonemize_text import phonemize_with_dict
8
+ from collections import defaultdict
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ import re
11
+ import gc
12
+
13
+ # ============================================================================
14
+ # Shared Utilities
15
+ # ============================================================================
16
+
17
+ def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
18
+ """Linear overlap-add for smooth audio concatenation"""
19
+ assert len(frames)
20
+ dtype = frames[0].dtype
21
+ shape = frames[0].shape[:-1]
22
+
23
+ total_size = 0
24
+ for i, frame in enumerate(frames):
25
+ frame_end = stride * i + frame.shape[-1]
26
+ total_size = max(total_size, frame_end)
27
+
28
+ sum_weight = np.zeros(total_size, dtype=dtype)
29
+ out = np.zeros(*shape, total_size, dtype=dtype)
30
+
31
+ offset: int = 0
32
+ for frame in frames:
33
+ frame_length = frame.shape[-1]
34
+ t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
35
+ weight = np.abs(0.5 - (t - 0.5))
36
+
37
+ out[..., offset : offset + frame_length] += weight * frame
38
+ sum_weight[offset : offset + frame_length] += weight
39
+ offset += stride
40
+ assert sum_weight.min() > 0
41
+ return out / sum_weight
42
+
43
+
44
+ def _compile_codec_with_triton(codec):
45
+ """Compile codec with Triton for faster decoding (Windows/Linux compatible)"""
46
+ try:
47
+ import triton
48
+
49
+ if hasattr(codec, 'dec') and hasattr(codec.dec, 'resblocks'):
50
+ if len(codec.dec.resblocks) > 2:
51
+ codec.dec.resblocks[2].forward = torch.compile(
52
+ codec.dec.resblocks[2].forward,
53
+ mode="reduce-overhead",
54
+ dynamic=True
55
+ )
56
+ print(" ✅ Triton compilation enabled for codec")
57
+ return True
58
+
59
+ except ImportError:
60
+ print(" ⚠️ Triton not found. Install for faster speed:")
61
+ print(" • Linux: pip install triton")
62
+ print(" • Windows: pip install triton-windows")
63
+ print(" (Optional but recommended)")
64
+ return False
65
+
66
+
67
+ # ============================================================================
68
+ # VieNeuTTS - Standard implementation (CPU/GPU compatible)
69
+ # Supports: PyTorch Transformers, GGUF/GGML quantized models
70
+ # ============================================================================
71
+
72
+ class VieNeuTTS:
73
+ """
74
+ Standard VieNeu-TTS implementation.
75
+
76
+ Supports:
77
+ - PyTorch + Transformers backend (CPU/GPU)
78
+ - GGUF quantized models via llama-cpp-python (CPU optimized)
79
+
80
+ Use this for:
81
+ - CPU-only environments
82
+ - Standard PyTorch workflows
83
+ - GGUF quantized models
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ backbone_repo="pnnbao-ump/VieNeu-TTS",
89
+ backbone_device="cpu",
90
+ codec_repo="neuphonic/neucodec",
91
+ codec_device="cpu",
92
+ ):
93
+ """
94
+ Initialize VieNeu-TTS.
95
+
96
+ Args:
97
+ backbone_repo: Model repository or path to GGUF file
98
+ backbone_device: Device for backbone ('cpu', 'cuda', 'gpu')
99
+ codec_repo: Codec repository
100
+ codec_device: Device for codec
101
+ """
102
+
103
+ # Constants
104
+ self.sample_rate = 24_000
105
+ self.max_context = 2048
106
+ self.hop_length = 480
107
+ self.streaming_overlap_frames = 1
108
+ self.streaming_frames_per_chunk = 25
109
+ self.streaming_lookforward = 5
110
+ self.streaming_lookback = 50
111
+ self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
112
+
113
+ # Flags
114
+ self._is_quantized_model = False
115
+ self._is_onnx_codec = False
116
+
117
+ # HF tokenizer
118
+ self.tokenizer = None
119
+
120
+ # Load models
121
+ self._load_backbone(backbone_repo, backbone_device)
122
+ self._load_codec(codec_repo, codec_device)
123
+
124
    def _load_backbone(self, backbone_repo, backbone_device):
        """Load the LM backbone: GGUF via llama.cpp, otherwise transformers.

        Args:
            backbone_repo: HF repo id or path; any value containing "gguf"
                selects the llama.cpp quantized path.
            backbone_device: 'cpu' / 'gpu' for llama.cpp; any torch device
                string (incl. 'mps') for the transformers path.
        """
        # MPS device validation
        if backbone_device == "mps":
            if not torch.backends.mps.is_available():
                print("Warning: MPS not available, falling back to CPU")
                backbone_device = "cpu"

        print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")

        if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower():
            try:
                from llama_cpp import Llama
            except ImportError as e:
                raise ImportError(
                    "Failed to import `llama_cpp`. "
                    "Xem hướng dẫn cài đặt llama_cpp_python phiên bản tối thiểu 0.3.16 tại: https://llama-cpp-python.readthedocs.io/en/latest/"
                ) from e
            # llama.cpp backend: offload all layers to GPU only when the
            # caller asked for 'gpu'; flash attention likewise.
            self.backbone = Llama.from_pretrained(
                repo_id=backbone_repo,
                filename="*.gguf",
                verbose=False,
                n_gpu_layers=-1 if backbone_device == "gpu" else 0,
                n_ctx=self.max_context,
                mlock=True,
                flash_attn=True if backbone_device == "gpu" else False,
            )
            self._is_quantized_model = True

        else:
            # Transformers backend: the tokenizer is only needed here,
            # the GGUF path tokenizes via llama.cpp itself.
            from transformers import AutoTokenizer, AutoModelForCausalLM
            self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
            self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
                torch.device(backbone_device)
            )
158
+
159
    def _load_codec(self, codec_repo, codec_device):
        """Load the neural audio codec used to encode/decode speech tokens.

        Args:
            codec_repo: one of the supported neucodec repos (full,
                distilled, or int8 ONNX decoder).
            codec_device: torch device string; the ONNX decoder is CPU-only.

        Raises:
            ValueError: unsupported repo, or ONNX decoder on a non-CPU device.
            ImportError: ONNX decoder requested without onnxruntime /
                neucodec >= 0.0.4.
        """
        # MPS device validation
        if codec_device == "mps":
            if not torch.backends.mps.is_available():
                print("Warning: MPS not available for codec, falling back to CPU")
                codec_device = "cpu"

        print(f"Loading codec from: {codec_repo} on {codec_device} ...")
        match codec_repo:
            case "neuphonic/neucodec":
                self.codec = NeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)
            case "neuphonic/distill-neucodec":
                self.codec = DistillNeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)
            case "neuphonic/neucodec-onnx-decoder-int8":
                if codec_device != "cpu":
                    raise ValueError("Onnx decoder only currently runs on CPU.")
                try:
                    from neucodec import NeuCodecOnnxDecoder
                except ImportError as e:
                    # NOTE(review): the two message fragments below
                    # concatenate without a separating space.
                    raise ImportError(
                        "Failed to import the onnx decoder."
                        "Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
                    ) from e
                self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
                # Flag consumed by the decode path (ONNX API differs).
                self._is_onnx_codec = True
            case _:
                raise ValueError(f"Unsupported codec repository: {codec_repo}")
188
+
189
+ def encode_reference(self, ref_audio_path: str | Path):
190
+ """Encode reference audio to codes"""
191
+ wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
192
+ wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0) # [1, 1, T]
193
+ with torch.no_grad():
194
+ ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
195
+ return ref_codes
196
+
197
+ def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
198
+ """
199
+ Perform inference to generate speech from text using the TTS model and reference audio.
200
+
201
+ Args:
202
+ text (str): Input text to be converted to speech.
203
+ ref_codes (np.ndarray | torch.tensor): Encoded reference.
204
+ ref_text (str): Reference text for reference audio.
205
+ Returns:
206
+ np.ndarray: Generated speech waveform.
207
+ """
208
+
209
+ # Generate tokens
210
+ if self._is_quantized_model:
211
+ output_str = self._infer_ggml(ref_codes, ref_text, text)
212
+ else:
213
+ prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
214
+ output_str = self._infer_torch(prompt_ids)
215
+
216
+ # Decode
217
+ wav = self._decode(output_str)
218
+
219
+ return wav
220
+
221
+ def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
222
+ """
223
+ Perform streaming inference to generate speech from text using the TTS model and reference audio.
224
+
225
+ Args:
226
+ text (str): Input text to be converted to speech.
227
+ ref_codes (np.ndarray | torch.tensor): Encoded reference.
228
+ ref_text (str): Reference text for reference audio.
229
+ Yields:
230
+ np.ndarray: Generated speech waveform.
231
+ """
232
+
233
+ if self._is_quantized_model:
234
+ return self._infer_stream_ggml(ref_codes, ref_text, text)
235
+ else:
236
+ raise NotImplementedError("Streaming is not implemented for the torch backend!")
237
+
238
+ def _decode(self, codes: str):
239
+ """Decode speech tokens to audio waveform."""
240
+ # Extract speech token IDs using regex
241
+ speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
242
+
243
+ if len(speech_ids) == 0:
244
+ raise ValueError(
245
+ "No valid speech tokens found in the output. "
246
+ "The model may not have generated proper speech tokens."
247
+ )
248
+
249
+ # Onnx decode
250
+ if self._is_onnx_codec:
251
+ codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
252
+ recon = self.codec.decode_code(codes)
253
+ # Torch decode
254
+ else:
255
+ with torch.no_grad():
256
+ codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
257
+ self.codec.device
258
+ )
259
+ recon = self.codec.decode_code(codes).cpu().numpy()
260
+
261
+ return recon[0, 0, :]
262
+
263
+ def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
264
+ input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)
265
+
266
+ speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
267
+ speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
268
+ text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
269
+ text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
270
+ text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
271
+
272
+ input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
273
+ chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
274
+ ids = self.tokenizer.encode(chat)
275
+
276
+ text_replace_idx = ids.index(text_replace)
277
+ ids = (
278
+ ids[:text_replace_idx]
279
+ + [text_prompt_start]
280
+ + input_ids
281
+ + [text_prompt_end]
282
+ + ids[text_replace_idx + 1 :] # noqa
283
+ )
284
+
285
+ speech_replace_idx = ids.index(speech_replace)
286
+ codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
287
+ codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
288
+ ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
289
+
290
+ return ids
291
+
292
+ def _infer_torch(self, prompt_ids: list[int]) -> str:
293
+ prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
294
+ speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
295
+ with torch.no_grad():
296
+ output_tokens = self.backbone.generate(
297
+ prompt_tensor,
298
+ max_length=self.max_context,
299
+ eos_token_id=speech_end_id,
300
+ do_sample=True,
301
+ temperature=0.7,
302
+ top_k=50,
303
+ use_cache=True,
304
+ min_new_tokens=50,
305
+ )
306
+ input_length = prompt_tensor.shape[-1]
307
+ output_str = self.tokenizer.decode(
308
+ output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
309
+ )
310
+ return output_str
311
+
312
+ def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
313
+ ref_text = phonemize_with_dict(ref_text)
314
+ input_text = phonemize_with_dict(input_text)
315
+
316
+ codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
317
+ prompt = (
318
+ f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
319
+ f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
320
+ )
321
+ output = self.backbone(
322
+ prompt,
323
+ max_tokens=self.max_context,
324
+ temperature=0.7,
325
+ top_k=50,
326
+ stop=["<|SPEECH_GENERATION_END|>"],
327
+ )
328
+ output_str = output["choices"][0]["text"]
329
+ return output_str
330
+
331
    def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
        """Stream audio chunks from the GGUF backbone.

        Tokens are accumulated as they arrive; once a full chunk plus
        lookforward context is available, a window that also includes
        lookback and overlap frames is decoded, the chunk is cut out of it,
        and overlapping windows are cross-faded via linear overlap-add.
        """
        ref_text = phonemize_with_dict(ref_text)
        input_text = phonemize_with_dict(input_text)

        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
        prompt = (
            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
        )

        audio_cache: list[np.ndarray] = []  # decoded (overlapping) audio windows
        token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
        n_decoded_samples: int = 0          # samples already yielded to the caller
        n_decoded_tokens: int = len(ref_codes)  # tokens already rendered to audio

        # NOTE(review): each streamed item is appended as one list entry and the
        # chunking below counts list entries as codec frames — this assumes every
        # streamed piece is exactly one <|speech_N|> token; confirm for the
        # tokenizer/model in use.
        for item in self.backbone(
            prompt,
            max_tokens=self.max_context,
            temperature=0.7,
            top_k=50,
            stop=["<|SPEECH_GENERATION_END|>"],
            stream=True
        ):
            output_str = item["choices"][0]["text"]
            token_cache.append(output_str)

            if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:

                # decode chunk: include lookback/overlap context before the chunk
                # and lookforward/overlap context after it
                tokens_start = max(
                    n_decoded_tokens
                    - self.streaming_lookback
                    - self.streaming_overlap_frames,
                    0
                )
                tokens_end = (
                    n_decoded_tokens
                    + self.streaming_frames_per_chunk
                    + self.streaming_lookforward
                    + self.streaming_overlap_frames
                )
                # sample offsets of the chunk inside the decoded window
                sample_start = (
                    n_decoded_tokens - tokens_start
                ) * self.hop_length
                sample_end = (
                    sample_start
                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
                )
                curr_codes = token_cache[tokens_start:tokens_end]
                recon = self._decode("".join(curr_codes))
                recon = recon[sample_start:sample_end]
                audio_cache.append(recon)

                # postprocess: cross-fade all windows, emit only the new samples
                processed_recon = _linear_overlap_add(
                    audio_cache, stride=self.streaming_stride_samples
                )
                new_samples_end = len(audio_cache) * self.streaming_stride_samples
                processed_recon = processed_recon[
                    n_decoded_samples:new_samples_end
                ]
                n_decoded_samples = new_samples_end
                n_decoded_tokens += self.streaming_frames_per_chunk
                yield processed_recon

        # final decoding handled separately as non-constant chunk size
        remaining_tokens = len(token_cache) - n_decoded_tokens
        if len(token_cache) > n_decoded_tokens:
            tokens_start = max(
                len(token_cache)
                - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
                0
            )
            sample_start = (
                len(token_cache)
                - tokens_start
                - remaining_tokens
                - self.streaming_overlap_frames
            ) * self.hop_length
            curr_codes = token_cache[tokens_start:]
            recon = self._decode("".join(curr_codes))
            recon = recon[sample_start:]
            audio_cache.append(recon)

            processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
            processed_recon = processed_recon[n_decoded_samples:]
            yield processed_recon
+
419
+
420
+ # ============================================================================
421
+ # FastVieNeuTTS - GPU-optimized implementation
422
+ # Requires: LMDeploy with CUDA
423
+ # ============================================================================
424
+
425
+ class FastVieNeuTTS:
426
+ """
427
+ GPU-optimized VieNeu-TTS using LMDeploy TurbomindEngine.
428
+ """
429
+
430
+ def __init__(
431
+ self,
432
+ backbone_repo="pnnbao-ump/VieNeu-TTS",
433
+ backbone_device="cuda",
434
+ codec_repo="neuphonic/neucodec",
435
+ codec_device="cuda",
436
+ memory_util=0.3,
437
+ tp=1,
438
+ enable_prefix_caching=True,
439
+ quant_policy=0,
440
+ enable_triton=True,
441
+ max_batch_size=8,
442
+ ):
443
+ """
444
+ Initialize FastVieNeuTTS with LMDeploy backend and optimizations.
445
+
446
+ Args:
447
+ backbone_repo: Model repository
448
+ backbone_device: Device for backbone (must be CUDA)
449
+ codec_repo: Codec repository
450
+ codec_device: Device for codec
451
+ memory_util: GPU memory utilization (0.0-1.0)
452
+ tp: Tensor parallel size for multi-GPU
453
+ enable_prefix_caching: Enable prefix caching for faster batch processing
454
+ quant_policy: KV cache quantization (0=off, 8=int8, 4=int4)
455
+ enable_triton: Enable Triton compilation for codec
456
+ max_batch_size: Maximum batch size for inference (prevent GPU overload)
457
+ """
458
+
459
+ if backbone_device != "cuda" and not backbone_device.startswith("cuda:"):
460
+ raise ValueError("LMDeploy backend requires CUDA device")
461
+
462
+ # Constants
463
+ self.sample_rate = 24_000
464
+ self.max_context = 2048
465
+ self.hop_length = 480
466
+ self.streaming_overlap_frames = 1
467
+ self.streaming_frames_per_chunk = 50
468
+ self.streaming_lookforward = 5
469
+ self.streaming_lookback = 50
470
+ self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
471
+
472
+ self.max_batch_size = max_batch_size
473
+
474
+ self._ref_cache = {}
475
+
476
+ self.stored_dict = defaultdict(dict)
477
+
478
+ # Flags
479
+ self._is_onnx_codec = False
480
+ self._triton_enabled = False
481
+
482
+ # Load models
483
+ self._load_backbone_lmdeploy(backbone_repo, memory_util, tp, enable_prefix_caching, quant_policy)
484
+ self._load_codec(codec_repo, codec_device, enable_triton)
485
+
486
+ self._warmup_model()
487
+
488
+ print("✅ FastVieNeuTTS with optimizations loaded successfully!")
489
+ print(f" Max batch size: {self.max_batch_size} (adjustable to prevent GPU overload)")
490
+
491
+ def _load_backbone_lmdeploy(self, repo, memory_util, tp, enable_prefix_caching, quant_policy):
492
+ """Load backbone using LMDeploy's TurbomindEngine"""
493
+ print(f"Loading backbone with LMDeploy from: {repo}")
494
+
495
+ try:
496
+ from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
497
+ except ImportError as e:
498
+ raise ImportError(
499
+ "Failed to import `lmdeploy`. "
500
+ "Xem hướng dẫn cài đặt lmdeploy để tối ưu hiệu suất GPU tại: https://github.com/pnnbao97/VieNeu-TTS"
501
+ ) from e
502
+
503
+ backend_config = TurbomindEngineConfig(
504
+ cache_max_entry_count=memory_util,
505
+ tp=tp,
506
+ enable_prefix_caching=enable_prefix_caching,
507
+ dtype='bfloat16',
508
+ quant_policy=quant_policy
509
+ )
510
+
511
+ self.backbone = pipeline(repo, backend_config=backend_config)
512
+
513
+ self.gen_config = GenerationConfig(
514
+ top_p=0.95,
515
+ top_k=50,
516
+ temperature=0.7,
517
+ max_new_tokens=2048,
518
+ do_sample=True,
519
+ min_new_tokens=40,
520
+ )
521
+
522
+ print(f" LMDeploy TurbomindEngine initialized")
523
+ print(f" - Memory util: {memory_util}")
524
+ print(f" - Tensor Parallel: {tp}")
525
+ print(f" - Prefix caching: {enable_prefix_caching}")
526
+ print(f" - KV quant: {quant_policy} ({'Enabled' if quant_policy > 0 else 'Disabled'})")
527
+
528
+ def _load_codec(self, codec_repo, codec_device, enable_triton):
529
+ """Load codec with optional Triton compilation"""
530
+ print(f"Loading codec from: {codec_repo} on {codec_device}")
531
+
532
+ match codec_repo:
533
+ case "neuphonic/neucodec":
534
+ self.codec = NeuCodec.from_pretrained(codec_repo)
535
+ self.codec.eval().to(codec_device)
536
+ case "neuphonic/distill-neucodec":
537
+ self.codec = DistillNeuCodec.from_pretrained(codec_repo)
538
+ self.codec.eval().to(codec_device)
539
+ case "neuphonic/neucodec-onnx-decoder-int8":
540
+ if codec_device != "cpu":
541
+ raise ValueError("ONNX decoder only runs on CPU")
542
+ try:
543
+ from neucodec import NeuCodecOnnxDecoder
544
+ except ImportError as e:
545
+ raise ImportError(
546
+ "Failed to import ONNX decoder. "
547
+ "Ensure onnxruntime and neucodec >= 0.0.4 are installed."
548
+ ) from e
549
+ self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
550
+ self._is_onnx_codec = True
551
+ case _:
552
+ raise ValueError(f"Unsupported codec repository: {codec_repo}")
553
+
554
+ if enable_triton and not self._is_onnx_codec and codec_device != "cpu":
555
+ self._triton_enabled = _compile_codec_with_triton(self.codec)
556
+
557
+ def _warmup_model(self):
558
+ """Warmup inference pipeline to reduce first-token latency"""
559
+ print("🔥 Warming up model...")
560
+ try:
561
+ dummy_codes = list(range(10))
562
+ dummy_prompt = self._format_prompt(dummy_codes, "warmup", "test")
563
+ _ = self.backbone([dummy_prompt], gen_config=self.gen_config, do_preprocess=False)
564
+ print(" ✅ Warmup complete")
565
+ except Exception as e:
566
+ print(f" ⚠️ Warmup failed (non-critical): {e}")
567
+
568
+ def encode_reference(self, ref_audio_path: str | Path):
569
+ """Encode reference audio to codes"""
570
+ wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
571
+ wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)
572
+ with torch.no_grad():
573
+ ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
574
+ return ref_codes
575
+
576
+ def get_cached_reference(self, voice_name: str, audio_path: str, ref_text: str = None):
577
+ """
578
+ Get or create cached reference codes.
579
+
580
+ Args:
581
+ voice_name: Unique identifier for this voice
582
+ audio_path: Path to reference audio
583
+ ref_text: Optional reference text (stored with codes)
584
+
585
+ Returns:
586
+ ref_codes: Encoded reference codes
587
+ """
588
+ cache_key = f"{voice_name}_{audio_path}"
589
+
590
+ if cache_key not in self._ref_cache:
591
+ ref_codes = self.encode_reference(audio_path)
592
+ self._ref_cache[cache_key] = {
593
+ 'codes': ref_codes,
594
+ 'ref_text': ref_text
595
+ }
596
+
597
+ return self._ref_cache[cache_key]['codes']
598
+
599
+ def add_speaker(self, user_id: int, audio_file: str, ref_text: str):
600
+ """
601
+ Add a speaker to the stored dictionary for easy access.
602
+
603
+ Args:
604
+ user_id: Unique user ID
605
+ audio_file: Reference audio file path
606
+ ref_text: Reference text
607
+
608
+ Returns:
609
+ user_id: The user ID for use in streaming
610
+ """
611
+ codes = self.encode_reference(audio_file)
612
+
613
+ if isinstance(codes, torch.Tensor):
614
+ codes = codes.cpu().numpy()
615
+ if isinstance(codes, np.ndarray):
616
+ codes = codes.flatten().tolist()
617
+
618
+ self.stored_dict[f"{user_id}"]['codes'] = codes
619
+ self.stored_dict[f"{user_id}"]['ref_text'] = ref_text
620
+
621
+ return user_id
622
+
623
+ def _decode(self, codes: str):
624
+ """Decode speech tokens to audio waveform"""
625
+ speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
626
+
627
+ if len(speech_ids) == 0:
628
+ raise ValueError("No valid speech tokens found in output")
629
+
630
+ if self._is_onnx_codec:
631
+ codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
632
+ recon = self.codec.decode_code(codes)
633
+ else:
634
+ with torch.no_grad():
635
+ codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
636
+ self.codec.device
637
+ )
638
+ recon = self.codec.decode_code(codes).cpu().numpy()
639
+
640
+ return recon[0, 0, :]
641
+
642
+ def _decode_batch(self, codes_list: list[str], max_workers: int = None):
643
+ """
644
+ Decode multiple code strings in parallel.
645
+
646
+ Args:
647
+ codes_list: List of code strings to decode
648
+ max_workers: Number of parallel workers (auto-tuned if None)
649
+
650
+ Returns:
651
+ List of decoded audio arrays
652
+ """
653
+ # Auto-tune workers based on GPU memory and batch size
654
+ if max_workers is None:
655
+ if torch.cuda.is_available():
656
+ gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
657
+ # 1 worker per 4GB VRAM, max 4 workers
658
+ max_workers = min(max(1, int(gpu_mem_gb / 4)), 4)
659
+ else:
660
+ max_workers = 2
661
+
662
+ # For small batches, use sequential to avoid overhead
663
+ if len(codes_list) <= 2:
664
+ return [self._decode(codes) for codes in codes_list]
665
+
666
+ # Parallel decoding with controlled workers
667
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
668
+ futures = [executor.submit(self._decode, codes) for codes in codes_list]
669
+ results = [f.result() for f in futures]
670
+ return results
671
+
672
+ def _format_prompt(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
673
+ """Format prompt for LMDeploy"""
674
+ ref_text_phones = phonemize_with_dict(ref_text)
675
+ input_text_phones = phonemize_with_dict(input_text)
676
+
677
+ codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
678
+
679
+ prompt = (
680
+ f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text_phones} {input_text_phones}"
681
+ f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
682
+ )
683
+
684
+ return prompt
685
+
686
+ def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
687
+ """
688
+ Single inference.
689
+
690
+ Args:
691
+ text: Input text to synthesize
692
+ ref_codes: Encoded reference audio codes
693
+ ref_text: Reference text for reference audio
694
+
695
+ Returns:
696
+ Generated speech waveform as numpy array
697
+ """
698
+ if isinstance(ref_codes, torch.Tensor):
699
+ ref_codes = ref_codes.cpu().numpy()
700
+ if isinstance(ref_codes, np.ndarray):
701
+ ref_codes = ref_codes.flatten().tolist()
702
+
703
+ prompt = self._format_prompt(ref_codes, ref_text, text)
704
+
705
+ # Use LMDeploy pipeline for generation
706
+ responses = self.backbone([prompt], gen_config=self.gen_config, do_preprocess=False)
707
+ output_str = responses[0].text
708
+
709
+ # Decode to audio
710
+ wav = self._decode(output_str)
711
+
712
+ return wav
713
+
714
+ def infer_batch(self, texts: list[str], ref_codes: np.ndarray | torch.Tensor, ref_text: str, max_batch_size: int = None) -> list[np.ndarray]:
715
+ """
716
+ Batch inference for multiple texts.
717
+ """
718
+ if max_batch_size is None:
719
+ max_batch_size = self.max_batch_size
720
+
721
+ if not isinstance(texts, list):
722
+ texts = [texts]
723
+
724
+ if isinstance(ref_codes, torch.Tensor):
725
+ ref_codes = ref_codes.cpu().numpy()
726
+ if isinstance(ref_codes, np.ndarray):
727
+ ref_codes = ref_codes.flatten().tolist()
728
+
729
+ all_wavs = []
730
+
731
+ for i in range(0, len(texts), max_batch_size):
732
+ batch_texts = texts[i:i+max_batch_size]
733
+ prompts = [self._format_prompt(ref_codes, ref_text, text) for text in batch_texts]
734
+ responses = self.backbone(prompts, gen_config=self.gen_config, do_preprocess=False)
735
+ batch_codes = [response.text for response in responses]
736
+
737
+ if len(batch_codes) > 3:
738
+ batch_wavs = self._decode_batch(batch_codes)
739
+ else:
740
+ batch_wavs = [self._decode(codes) for codes in batch_codes]
741
+ all_wavs.extend(batch_wavs)
742
+
743
+ if i + max_batch_size < len(texts):
744
+ if torch.cuda.is_available():
745
+ torch.cuda.empty_cache()
746
+
747
+ return all_wavs
748
+
749
    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
        """
        Streaming inference with low latency.

        Args:
            text: Input text to synthesize
            ref_codes: Encoded reference audio codes
            ref_text: Reference text for reference audio

        Yields:
            Audio chunks as numpy arrays
        """
        # Normalize the reference codes to a flat list of ints.
        if isinstance(ref_codes, torch.Tensor):
            ref_codes = ref_codes.cpu().numpy()
        if isinstance(ref_codes, np.ndarray):
            ref_codes = ref_codes.flatten().tolist()

        prompt = self._format_prompt(ref_codes, ref_text, text)

        audio_cache = []  # decoded (overlapping) audio windows
        token_cache = [f"<|speech_{idx}|>" for idx in ref_codes]
        n_decoded_samples = 0            # samples already yielded
        n_decoded_tokens = len(ref_codes)  # tokens already rendered to audio

        for response in self.backbone.stream_infer([prompt], gen_config=self.gen_config, do_preprocess=False):
            output_str = response.text

            # Extract new tokens
            # NOTE(review): this slicing assumes `response.text` is the
            # cumulative generation so far (the already-seen suffix of
            # token_cache is stripped off). If LMDeploy yields per-step
            # deltas instead, this would drop text — confirm against the
            # lmdeploy version in use. Also note the chunking below counts
            # token_cache entries as codec frames, which assumes each
            # appended piece is one speech token.
            new_tokens = output_str[len("".join(token_cache[len(ref_codes):])):] if len(token_cache) > len(ref_codes) else output_str

            if new_tokens:
                token_cache.append(new_tokens)

            # Check if we have enough tokens to decode a chunk
            if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:

                # Decode chunk with context (lookback before, lookforward after)
                tokens_start = max(
                    n_decoded_tokens - self.streaming_lookback - self.streaming_overlap_frames,
                    0
                )
                tokens_end = (
                    n_decoded_tokens
                    + self.streaming_frames_per_chunk
                    + self.streaming_lookforward
                    + self.streaming_overlap_frames
                )
                # Sample offsets of the chunk inside the decoded window.
                sample_start = (n_decoded_tokens - tokens_start) * self.hop_length
                sample_end = (
                    sample_start
                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
                )

                curr_codes = token_cache[tokens_start:tokens_end]
                recon = self._decode("".join(curr_codes))
                recon = recon[sample_start:sample_end]
                audio_cache.append(recon)

                # Overlap-add processing: cross-fade windows, emit only new samples.
                processed_recon = _linear_overlap_add(
                    audio_cache, stride=self.streaming_stride_samples
                )
                new_samples_end = len(audio_cache) * self.streaming_stride_samples
                processed_recon = processed_recon[n_decoded_samples:new_samples_end]
                n_decoded_samples = new_samples_end
                n_decoded_tokens += self.streaming_frames_per_chunk

                yield processed_recon

        # Final chunk — variable size, so handled outside the loop.
        remaining_tokens = len(token_cache) - n_decoded_tokens
        if remaining_tokens > 0:
            tokens_start = max(
                len(token_cache) - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
                0
            )
            sample_start = (
                len(token_cache) - tokens_start - remaining_tokens - self.streaming_overlap_frames
            ) * self.hop_length

            curr_codes = token_cache[tokens_start:]
            recon = self._decode("".join(curr_codes))
            recon = recon[sample_start:]
            audio_cache.append(recon)

            processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
            processed_recon = processed_recon[n_decoded_samples:]
            yield processed_recon
838
+ def cleanup_memory(self):
839
+ """Clean up GPU memory"""
840
+ if torch.cuda.is_available():
841
+ torch.cuda.empty_cache()
842
+ gc.collect()
843
+ print("🧹 Memory cleaned up")
844
+
845
+ def get_optimization_stats(self) -> dict:
846
+ """
847
+ Get current optimization statistics.
848
+
849
+ Returns:
850
+ Dictionary with optimization info
851
+ """
852
+ return {
853
+ 'triton_enabled': self._triton_enabled,
854
+ 'max_batch_size': self.max_batch_size,
855
+ 'cached_references': len(self._ref_cache),
856
+ 'active_sessions': len(self.stored_dict),
857
+ 'kv_quant': self.gen_config.__dict__.get('quant_policy', 0),
858
+ 'prefix_caching': True, # Always enabled in our config
859
+ }