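Fixing offline loading of the qwen-tts model in the Voicebox project
Abstract of "Voicebox voice cloning studio: a fix for Qwen-TTS offline model loading": This article documents the full process of fixing an offline-loading problem with HuggingFace models. Voicebox kept contacting the network to download files even though the models were already cached; analysis showed that the Qwen-TTS library forces a network check. The fix came in two phases: first patch-style attempts (setting environment variables, disabling SSL verification), then a closer look at the HuggingFace cache mechanism, and finally loading the local snapshot…
Voicebox is a local-first voice cloning studio with DAW-like features for professional speech synthesis. Think of it as a local, free, open-source alternative to ElevenLabs: download models, clone voices, and generate speech on your own machine.
I downloaded the project source and installed and configured bun, rust, and python as the project requires. When I ran the project with bun run dev:server and bun run tauri dev in two separate consoles, the server still contacted HuggingFace to download model files even though the models were already fully downloaded to the local cache. Working in Trae IDE with the help of Trae's backend AI, it took most of an evening and several phases to fix the qwen-tts offline-loading problem. What follows is Trae's write-up of the whole troubleshooting session, shared here for anyone who hits the same issue.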
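Part 1: Fixing Qwen-TTS offline model loading
Problem background
Initial requirements
At runtime, the project still needed to reach the HuggingFace website even though the models had already been downloaded to the local cache. The requirements were:
- Disable certificate verification when downloading models
- Use a HuggingFace mirror in mainland China when downloading models
- Once a model is on disk, make no further network requests for model-related files
Environment
- Operating system: Windows
- Model cache path: C:\Users\{username}\.cache\huggingface\hub\
- Models involved: Qwen3-TTS-12Hz-1.7B-Base, Qwen3-TTS-12Hz-0.6B-Base, Whisper
Analysis
Phase 1: Patch-style approach (bottom-up)
1.1 First attempt
Create an hf_config.py module that tries to control HuggingFace Hub behavior through environment variables: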
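import os
import ssl
# Enable offline mode
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
# Use the mainland China mirror
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Disable SSL verification for code that uses the standard library's default HTTPS context (process-wide, insecure)
ssl._create_default_https_context = ssl._create_unverified_context
Result: partly effective, but the qwen_tts library still attempted network access internally.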
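One detail matters for this step: `huggingface_hub` reads these variables when it is first imported, so the configuration has to run before anything imports `huggingface_hub`, `transformers`, or `qwen_tts`. The sketch below shows one way such a module could be organized. The function name `configure_hf_env` and its parameters are illustrative assumptions, not the project's actual hf_config.py.

```python
import os
from typing import Dict, Optional


def configure_hf_env(offline: bool, mirror: Optional[str] = None) -> Dict[str, str]:
    """Set HuggingFace-related environment variables and return what was set.

    Hypothetical helper: call it before importing huggingface_hub or transformers.
    """
    applied: Dict[str, str] = {}
    if offline:
        applied["HF_HUB_OFFLINE"] = "1"
        applied["TRANSFORMERS_OFFLINE"] = "1"
    else:
        # Drop stale offline flags so a later download is not blocked.
        os.environ.pop("HF_HUB_OFFLINE", None)
        os.environ.pop("TRANSFORMERS_OFFLINE", None)
    if mirror:
        applied["HF_ENDPOINT"] = mirror
    os.environ.update(applied)
    return applied
```

Calling `configure_hf_env(offline=False, mirror="https://hf-mirror.com")` at the very top of the entry module would then cover the download case, and `offline=True` the cached case.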
1.2 Problems encountered
Problem 1: FutureWarning
FutureWarning: You are using `torch.load` with `weights_only=False`
Fix: pass weights_only=False explicitly and suppress the warning
Problem 2: SSL warning
InsecureRequestWarning: Unverified HTTPS request
Fix: suppress the SSL warning
Problem 3: Offline-mode error
Error: Cannot reach `https://hf-mirror.com/api/models/...`
offline mode is enabled.
Cause: the qwen_tts library requires a network check even when offline mode is set
1.3 The problem in the qwen_tts library
Reading the qwen_tts source shows:
# qwen_tts/inference/qwen3_tts_model.py
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True,)
# Problem: the processor is loaded without **kwargs, so local_files_only never reaches it
Attempted fix: create a qwen_tts_patch.py patch that makes sure local_files_only=True is passed to the processor
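The idea behind such a patch can be sketched as a wrapper that forces `local_files_only=True` into every `from_pretrained` call. The example uses a stand-in class (`FakeProcessor`) so that it runs without `transformers` installed. `force_local_files_only` is a hypothetical helper name, and the project's real qwen_tts_patch.py is not reproduced in this article.

```python
import functools


def force_local_files_only(loader):
    """Wrap a from_pretrained-style callable so local_files_only defaults to True."""
    @functools.wraps(loader)
    def wrapper(*args, **kwargs):
        # setdefault keeps an explicit local_files_only=False from the caller intact.
        kwargs.setdefault("local_files_only", True)
        return loader(*args, **kwargs)
    return wrapper


class FakeProcessor:
    """Stand-in for AutoProcessor; it only echoes the keyword arguments it received."""

    @classmethod
    def from_pretrained(cls, name, **kwargs):
        return {"name": name, **kwargs}


# Replace the class attribute, the way a monkeypatch would be applied to the real class.
FakeProcessor.from_pretrained = staticmethod(
    force_local_files_only(FakeProcessor.from_pretrained)
)

result = FakeProcessor.from_pretrained("Qwen/Qwen3-TTS-12Hz-1.7B-Base", fix_mistral_regex=True)
```

The call site stays unchanged, which is what makes this kind of patch attractive as a quick fix and also what makes it fragile: it depends on the library's internals staying the same.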
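1.4 The .no_exist directory
The .no_exist directory records which files were looked up and found not to exist in the remote repository:
models--Qwen--Qwen3-TTS-12Hz-1.7B-Base/
└── .no_exist/
    └── fd4b254389122332181a7c3db7f27e918eec64e3/
        ├── processor_config.json (0 bytes)
        ├── tokenizer.json (0 bytes)
        └── ...
Attempted fix: add a clean_no_exist_cache() function that clears this directory
Result: network access still could not be fully prevented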
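A cleanup function of this kind could look like the sketch below. The article does not show the project's own `clean_no_exist_cache()`, so the signature (taking the repo cache directory as a `Path`) and the boolean return value are assumptions.

```python
import shutil
from pathlib import Path


def clean_no_exist_cache(repo_cache: Path) -> bool:
    """Delete the .no_exist directory of one cached repo.

    Returns True if a directory was removed, False if there was nothing to remove.
    """
    no_exist = repo_cache / ".no_exist"
    if not no_exist.is_dir():
        return False
    shutil.rmtree(no_exist)
    return True
```

Removing the markers is harmless for the cached model itself, since they are empty placeholder files rather than model data.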
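Phase 2: From principles to the problem (top-down)
2.1 Analyzing the HuggingFace cache mechanism
Inspect the cache directory structure from the command line:
# Inspect the blobs directory
Get-ChildItem "...\models--Qwen--Qwen3-TTS-12Hz-1.7B-Base\blobs"
# Found a large file: 38fc7fc51c5e... (3678.72 MB), the actual model weights
# Inspect the snapshots directory
Get-ChildItem "...\snapshots\fd4b254389122332181a7c3db7f27e918eec64e3"
# Found symbolic links: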
# config.json → ../../blobs/xxx
# model.safetensors → ../../blobs/xxx
2.2 Understanding the cache structure
HuggingFace cache structure:
models--Qwen--Qwen3-TTS-12Hz-1.7B-Base/
├── blobs/                      # actual file contents (named by content hash; SHA256 for large LFS files)
│   ├── 38fc7fc51c5e...         # 3.6 GB model weights
│   ├── 836b7b357f5e...         # 650 MB, other files
│   └── ...
├── snapshots/                  # symlink directory (one folder per commit hash)
│   └── fd4b254389122332...     # snapshot directory
│       ├── config.json → ../../blobs/xxx   # symbolic link
│       ├── model.safetensors → ../../blobs/xxx
│       ├── preprocessor_config.json → ../../blobs/xxx
│       └── ...
├── refs/                       # references (main, etc.)
└── .no_exist/                  # records files that do not exist remotely
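The layout above can be explored with a few lines of standard-library code. This is a sketch under the assumption that the cache root is passed in explicitly; `repo_cache_dir` and `list_snapshots` are hypothetical helper names, not part of `huggingface_hub`.

```python
from pathlib import Path
from typing import List


def repo_cache_dir(cache_root: Path, repo_id: str) -> Path:
    """Map a Hub ID such as 'Qwen/Qwen3-TTS-12Hz-1.7B-Base' to its cache folder."""
    return cache_root / ("models--" + repo_id.replace("/", "--"))


def list_snapshots(repo_dir: Path) -> List[str]:
    """Return the commit hashes that have a snapshot folder, sorted by name."""
    snapshots = repo_dir / "snapshots"
    if not snapshots.is_dir():
        return []
    return sorted(p.name for p in snapshots.iterdir() if p.is_dir())
```

With the cache path from the environment section, `repo_cache_dir(Path.home() / ".cache" / "huggingface" / "hub", "Qwen/Qwen3-TTS-12Hz-1.7B-Base")` would point at the directory shown in the tree.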
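2.3 Key insight
Findings:
- When loading with a HuggingFace Hub ID (such as Qwen/Qwen3-TTS-12Hz-1.7B-Base), from_pretrained() will:
  - resolve the Hub ID
  - check the local cache
  - try to verify the state of the remote repository (observed here even with local_files_only=True set)
  - possibly trigger network requests
- When loading with a local path (such as C:\Users\xxx\.cache\...\snapshots\xxx):
  - files are read directly from disk
  - the HuggingFace Hub network verification is bypassed entirely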
2.4 Final solution
Modify the model loading logic:
# Methods of the backend class in backend/backends/pytorch_backend.py
from pathlib import Path
from typing import Optional

def _get_local_snapshot_path(self, model_size: str) -> Optional[Path]:
    """
    Return the local snapshot path if the model is fully cached, else None.
    """
    try:
        from huggingface_hub import constants as hf_constants
        model_id = self._get_model_path(model_size)
        repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--"))
        if not repo_cache.exists():
            return None
        # An .incomplete file means a download was interrupted
        blobs_dir = repo_cache / "blobs"
        if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
            return None
        # Pick the most recently modified snapshot
        snapshots_dir = repo_cache / "snapshots"
        if not snapshots_dir.exists():
            return None
        snapshot_dirs = sorted(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)
        if not snapshot_dirs:
            return None
        latest_snapshot = snapshot_dirs[0]
        # Check that model weights exist
        has_weights = (
            any(latest_snapshot.rglob("*.safetensors")) or
            any(latest_snapshot.rglob("*.bin"))
        )
        if not has_weights:
            return None
        # Check that config.json exists
        if not (latest_snapshot / "config.json").exists():
            return None
        return latest_snapshot
    except Exception as e:
        print(f"[_get_local_snapshot_path] Error: {e}")
        return None

def _load_model_sync(self, model_size: str):
    """Load the model synchronously."""
    # Look for a complete local snapshot
    local_snapshot_path = self._get_local_snapshot_path(model_size)
    is_cached = local_snapshot_path is not None
    # HuggingFace Hub ID (used for downloading)
    model_id = self._get_model_path(model_size)
    # Decide what to load from:
    # if cached, use the local snapshot path and avoid any network access;
    # if not cached, use the Hub ID so the model is downloaded
    load_path = str(local_snapshot_path) if is_cached else model_id
    if is_cached:
        print(f"Loading model {model_size} from local cache: {load_path}")
    else:
        print(f"Model {model_size} not cached, will download from HuggingFace Hub")
        setup_huggingface_for_online()
    # Load the model
    self.model = Qwen3TTSModel.from_pretrained(load_path, ...)
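Picking the snapshot by modification time works well when only one revision is cached. As an alternative sketch (not what the project code above does), the revision that `main` points to can be read from `refs/main`, which stores the commit hash as plain text. `resolve_main_snapshot` is a hypothetical helper name.

```python
from pathlib import Path
from typing import Optional


def resolve_main_snapshot(repo_cache: Path) -> Optional[Path]:
    """Return the snapshot folder referenced by refs/main, or None if unavailable."""
    ref_file = repo_cache / "refs" / "main"
    if not ref_file.is_file():
        return None
    commit = ref_file.read_text(encoding="utf-8").strip()
    snapshot = repo_cache / "snapshots" / commit
    # An empty ref would resolve to the snapshots folder itself, so reject it.
    return snapshot if commit and snapshot.is_dir() else None
```

If several revisions of the same model end up in the cache, this variant keeps following the `main` reference rather than whichever folder was touched last.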
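Comparison of approaches
| Approach | Pros | Cons | Result |
|---|---|---|---|
| Set HF_HUB_OFFLINE=1 | Simple | Some libraries still attempt network verification | ❌ Failed |
| Pass local_files_only=True | Standard practice | qwen_tts does not pass it to the processor | ❌ Failed |
| Patch qwen_tts | Targeted | High maintenance cost; may affect other libraries | ⚠️ Partly effective |
| Clear the .no_exist directory | Fewer network requests | Cannot fully prevent network access | ⚠️ Partly effective |
| Load the local snapshot path directly | Fully offline, zero network access | Requires checking cache completeness | ✅ Succeeded |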
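A claim like "zero network access" can be checked rather than assumed. The sketch below is one way to do that with the standard library: a context manager that makes connection attempts through Python's `socket` module raise while a model is being loaded. `no_network` is a hypothetical helper, and it only covers connections opened via `socket.socket.connect` and `socket.create_connection`, not DNS lookups or native code that bypasses Python sockets.

```python
import socket
from contextlib import contextmanager
from unittest import mock


@contextmanager
def no_network():
    """Raise RuntimeError on any socket connection attempt inside the with-block."""
    def blocked(*args, **kwargs):
        raise RuntimeError("network access attempted during offline check")

    # mock.patch restores the original attributes when the block exits.
    with mock.patch.object(socket.socket, "connect", blocked), \
         mock.patch.object(socket, "create_connection", blocked):
        yield
```

Wrapping the model load in `with no_network():` during a test run would turn any hidden request into an immediate, visible error instead of a silent download.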
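Final implementation
Files modified
- backend/backends/pytorch_backend.py
  - Added the _get_local_snapshot_path() method
  - Modified the _load_model_sync() method (TTS and STT)
- backend/backends/mlx_backend.py
  - The same changes applied to the MLX backend
- backend/utils/hf_config.py
  - Kept the mainland China mirror and SSL configuration
  - Removed the offline-mode settings that are no longer needed
Files deleted
- backend/utils/qwen_tts_patch.py
  - The patch is no longer needed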
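Methodology summary
Bottom-up vs. top-down
| Approach | When it fits | How it showed up in this session |
|---|---|---|
| Bottom-up (fix the symptom in front of you) | Quick fixes, temporary workarounds, clearly bounded problems | Setting environment variables, patching, clearing the cache |
| Top-down (from principles to the problem) | Complex systems, root-cause problems, cases that need deep understanding | Analyzing the HuggingFace cache mechanism, understanding what the load path really means |
Takeaways from this session
- Understanding how the system works is the key to solving complex problems
  - HuggingFace's cache mechanism (blobs + snapshots + symbolic links)
  - How from_pretrained() behaves differently for a Hub ID vs. a local path
- Limits of the patch-style approach
  - It only addresses surface symptoms
  - It can introduce new complexity
  - It is hard to maintain
- Advantages of starting from principles
  - It finds the root-cause fix
  - The code is simpler and more reliable
  - It is easier to maintain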
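Best practices
Checklist for offline model loading
- Check cache completeness
  - The blobs directory exists and contains the large files
  - The snapshots directory exists and contains symbolic links
  - There are no .incomplete files
  - config.json exists
- Load from a local path
  # Recommended: use the local snapshot path
  load_path = str(local_snapshot_path)  # e.g. C:\Users\xxx\.cache\...\snapshots\xxx
  # Not recommended: use the Hub ID (may trigger network access)
  load_path = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
- Environment configuration (optional)
  # Mainland China mirror (for downloads)
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
  # Disable SSL verification (if needed)
  import ssl
  ssl._create_default_https_context = ssl._create_unverified_context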
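The completeness items in the checklist can be gathered into one function that reports what is wrong instead of returning a bare yes or no. This is a minimal sketch; `snapshot_problems` and its message strings are illustrative, and the broken-link check is an extra step that the project code shown earlier does not perform.

```python
from pathlib import Path
from typing import List


def snapshot_problems(repo_cache: Path, snapshot: Path) -> List[str]:
    """Return a list of human-readable problems; an empty list means the cache looks complete."""
    problems: List[str] = []
    blobs = repo_cache / "blobs"
    if not blobs.is_dir():
        problems.append("blobs directory missing")
    elif any(blobs.glob("*.incomplete")):
        problems.append("unfinished download (.incomplete) in blobs")
    if not (snapshot / "config.json").exists():
        problems.append("config.json missing")
    if not any(snapshot.rglob("*.safetensors")) and not any(snapshot.rglob("*.bin")):
        problems.append("no model weights found")
    # A symlink whose target blob was deleted is still listed, but exists() is False.
    broken = [p.name for p in snapshot.rglob("*") if p.is_symlink() and not p.exists()]
    if broken:
        problems.append("broken links: " + ", ".join(sorted(broken)))
    return problems
```

Logging the returned list before falling back to a download makes it clear why a model that appears to be on disk is being fetched again.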
Document created: 2026-03-10
Status: resolved ✅
Part 2: Code files as finally modified through the process above
1. backend/main.py
"""
FastAPI application for voicebox backend.
Handles voice cloning, generation history, and server mode.
"""
# Suppress warnings at the very start, before any other imports
from .utils.warning_suppressor import suppress_common_warnings, suppress_ssl_warnings
suppress_common_warnings()
suppress_ssl_warnings()
from fastapi import FastAPI, Depends, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy.orm import Session
from typing import List, Optional
from datetime import datetime
import asyncio
import uvicorn
import argparse
import torch
import tempfile
import io
from pathlib import Path
import uuid
import signal
import os
from urllib.parse import quote
def _safe_content_disposition(disposition_type: str, filename: str) -> str:
"""Build a Content-Disposition header that is safe for non-ASCII filenames.
Uses RFC 5987 ``filename*`` parameter so that browsers can decode
UTF-8 filenames while the ``filename`` fallback stays ASCII-only.
"""
ascii_name = "".join(
c for c in filename if c.isascii() and (c.isalnum() or c in " -_.")
).strip() or "download"
utf8_name = quote(filename, safe="")
return (
f'{disposition_type}; filename="{ascii_name}"; '
f"filename*=UTF-8''{utf8_name}"
)
from . import database, models, profiles, history, tts, transcribe, config, export_import, channels, stories, __version__
from .database import get_db, Generation as DBGeneration, VoiceProfile as DBVoiceProfile
from .utils.progress import get_progress_manager
from .utils.tasks import get_task_manager
from .utils.cache import clear_voice_prompt_cache
from .platform_detect import get_backend_type
app = FastAPI(
title="voicebox API",
description="Production-quality Qwen3-TTS voice cloning API",
version=__version__,
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Configure appropriately for production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ============================================
# ROOT & HEALTH ENDPOINTS
# ============================================
@app.get("/")
async def root():
"""Root endpoint."""
return {"message": "voicebox API", "version": __version__}
@app.post("/shutdown")
async def shutdown():
"""Gracefully shutdown the server."""
async def shutdown_async():
await asyncio.sleep(0.1) # Give response time to send
os.kill(os.getpid(), signal.SIGTERM)
asyncio.create_task(shutdown_async())
return {"message": "Shutting down..."}
@app.get("/health", response_model=models.HealthResponse)
async def health():
"""Health check endpoint."""
from huggingface_hub import constants as hf_constants
tts_model = tts.get_tts_model()
backend_type = get_backend_type()
# Check for GPU availability (CUDA, MPS, Intel Arc XPU, or DirectML)
has_cuda = torch.cuda.is_available()
has_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
# Intel Arc / Intel Xe via intel-extension-for-pytorch (IPEX)
has_xpu = False
xpu_name = None
try:
import intel_extension_for_pytorch as ipex # noqa: F401
if hasattr(torch, 'xpu') and torch.xpu.is_available():
has_xpu = True
try:
xpu_name = torch.xpu.get_device_name(0)
except Exception:
xpu_name = "Intel GPU"
except ImportError:
pass
# DirectML backend (torch-directml) for any Windows GPU
has_directml = False
directml_name = None
try:
import torch_directml
if torch_directml.device_count() > 0:
has_directml = True
try:
directml_name = torch_directml.device_name(0)
except Exception:
directml_name = "DirectML GPU"
except ImportError:
pass
gpu_available = has_cuda or has_mps or has_xpu or has_directml or backend_type == "mlx"
gpu_type = None
if has_cuda:
gpu_type = f"CUDA ({torch.cuda.get_device_name(0)})"
elif has_mps:
gpu_type = "MPS (Apple Silicon)"
elif backend_type == "mlx":
gpu_type = "Metal (Apple Silicon via MLX)"
elif has_xpu:
gpu_type = f"XPU ({xpu_name})"
elif has_directml:
gpu_type = f"DirectML ({directml_name})"
vram_used = None
if has_cuda:
vram_used = torch.cuda.memory_allocated() / 1024 / 1024 # MB
# Check if model is loaded - use the same logic as model status endpoint
model_loaded = False
model_size = None
try:
# Use the same check as model status endpoint
if tts_model.is_loaded():
model_loaded = True
# Get the actual loaded model size
# Check _current_model_size first (more reliable for actually loaded models)
model_size = getattr(tts_model, '_current_model_size', None)
if not model_size:
# Fallback to model_size attribute (which should be set when model loads)
model_size = getattr(tts_model, 'model_size', None)
except Exception:
# If there's an error checking, assume not loaded
model_loaded = False
model_size = None
# Check if default model is downloaded (cached)
model_downloaded = None
try:
# Check if the default model (1.7B) is cached
# Use different model IDs based on backend
if backend_type == "mlx":
default_model_id = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
else:
default_model_id = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
# Method 1: Try scan_cache_dir if available
try:
from huggingface_hub import scan_cache_dir
cache_info = scan_cache_dir()
for repo in cache_info.repos:
if repo.repo_id == default_model_id:
model_downloaded = True
break
except (ImportError, Exception):
# Method 2: Check cache directory (using HuggingFace's OS-specific cache location)
cache_dir = hf_constants.HF_HUB_CACHE
repo_cache = Path(cache_dir) / ("models--" + default_model_id.replace("/", "--"))
if repo_cache.exists():
has_model_files = (
any(repo_cache.rglob("*.bin")) or
any(repo_cache.rglob("*.safetensors")) or
any(repo_cache.rglob("*.pt")) or
any(repo_cache.rglob("*.pth")) or
any(repo_cache.rglob("*.npz")) # MLX models may use npz
)
model_downloaded = has_model_files
except Exception:
pass
return models.HealthResponse(
status="healthy",
model_loaded=model_loaded,
model_downloaded=model_downloaded,
model_size=model_size,
gpu_available=gpu_available,
gpu_type=gpu_type,
vram_used_mb=vram_used,
backend_type=backend_type,
)
# ============================================
# VOICE PROFILE ENDPOINTS
# ============================================
@app.post("/profiles", response_model=models.VoiceProfileResponse)
async def create_profile(
data: models.VoiceProfileCreate,
db: Session = Depends(get_db),
):
"""Create a new voice profile."""
try:
return await profiles.create_profile(data, db)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/profiles", response_model=List[models.VoiceProfileResponse])
async def list_profiles(db: Session = Depends(get_db)):
"""List all voice profiles."""
return await profiles.list_profiles(db)
@app.post("/profiles/import", response_model=models.VoiceProfileResponse)
async def import_profile(
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""Import a voice profile from a ZIP archive."""
# Validate file size (max 100MB)
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
# Read file content
content = await file.read()
if len(content) > MAX_FILE_SIZE:
raise HTTPException(
status_code=400,
detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024 * 1024)}MB"
)
try:
profile = await export_import.import_profile_from_zip(content, db)
return profile
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
async def get_profile(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get a voice profile by ID."""
profile = await profiles.get_profile(profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
return profile
@app.put("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
async def update_profile(
profile_id: str,
data: models.VoiceProfileCreate,
db: Session = Depends(get_db),
):
"""Update a voice profile."""
profile = await profiles.update_profile(profile_id, data, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
return profile
@app.delete("/profiles/{profile_id}")
async def delete_profile(
profile_id: str,
db: Session = Depends(get_db),
):
"""Delete a voice profile."""
success = await profiles.delete_profile(profile_id, db)
if not success:
raise HTTPException(status_code=404, detail="Profile not found")
return {"message": "Profile deleted successfully"}
@app.post("/profiles/{profile_id}/samples", response_model=models.ProfileSampleResponse)
async def add_profile_sample(
profile_id: str,
file: UploadFile = File(...),
reference_text: str = Form(...),
db: Session = Depends(get_db),
):
"""Add a sample to a voice profile."""
# Preserve the uploaded file's extension so librosa can detect format correctly.
# Defaulting to .wav was causing soundfile to reject MP3/WebM content as invalid WAV.
_allowed_audio_exts = {'.wav', '.mp3', '.m4a', '.ogg', '.flac', '.aac', '.webm', '.opus'}
_uploaded_ext = Path(file.filename or '').suffix.lower()
file_suffix = _uploaded_ext if _uploaded_ext in _allowed_audio_exts else '.wav'
with tempfile.NamedTemporaryFile(suffix=file_suffix, delete=False) as tmp:
content = await file.read()
tmp.write(content)
tmp_path = tmp.name
try:
sample = await profiles.add_profile_sample(
profile_id,
tmp_path,
reference_text,
db,
)
return sample
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to process audio file: {str(e)}")
finally:
# Clean up temp file
Path(tmp_path).unlink(missing_ok=True)
@app.get("/profiles/{profile_id}/samples", response_model=List[models.ProfileSampleResponse])
async def get_profile_samples(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get all samples for a profile."""
return await profiles.get_profile_samples(profile_id, db)
@app.delete("/profiles/samples/{sample_id}")
async def delete_profile_sample(
sample_id: str,
db: Session = Depends(get_db),
):
"""Delete a profile sample."""
success = await profiles.delete_profile_sample(sample_id, db)
if not success:
raise HTTPException(status_code=404, detail="Sample not found")
return {"message": "Sample deleted successfully"}
@app.put("/profiles/samples/{sample_id}", response_model=models.ProfileSampleResponse)
async def update_profile_sample(
sample_id: str,
data: models.ProfileSampleUpdate,
db: Session = Depends(get_db),
):
"""Update a profile sample's reference text."""
sample = await profiles.update_profile_sample(sample_id, data.reference_text, db)
if not sample:
raise HTTPException(status_code=404, detail="Sample not found")
return sample
@app.post("/profiles/{profile_id}/avatar", response_model=models.VoiceProfileResponse)
async def upload_profile_avatar(
profile_id: str,
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""Upload or update avatar image for a profile."""
# Save uploaded file to temp location
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename or "").suffix) as tmp:
content = await file.read()
tmp.write(content)
tmp_path = tmp.name
try:
profile = await profiles.upload_avatar(profile_id, tmp_path, db)
return profile
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
finally:
# Clean up temp file
Path(tmp_path).unlink(missing_ok=True)
@app.get("/profiles/{profile_id}/avatar")
async def get_profile_avatar(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get avatar image for a profile."""
profile = await profiles.get_profile(profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
if not profile.avatar_path:
raise HTTPException(status_code=404, detail="No avatar found for this profile")
avatar_path = Path(profile.avatar_path)
if not avatar_path.exists():
raise HTTPException(status_code=404, detail="Avatar file not found")
return FileResponse(avatar_path)
@app.delete("/profiles/{profile_id}/avatar")
async def delete_profile_avatar(
profile_id: str,
db: Session = Depends(get_db),
):
"""Delete avatar image for a profile."""
success = await profiles.delete_avatar(profile_id, db)
if not success:
raise HTTPException(status_code=404, detail="Profile not found or no avatar to delete")
return {"message": "Avatar deleted successfully"}
@app.get("/profiles/{profile_id}/export")
async def export_profile(
profile_id: str,
db: Session = Depends(get_db),
):
"""Export a voice profile as a ZIP archive."""
try:
# Get profile to get name for filename
profile = await profiles.get_profile(profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
# Export to ZIP
zip_bytes = export_import.export_profile_to_zip(profile_id, db)
# Create safe filename
safe_name = "".join(c for c in profile.name if c.isalnum() or c in (' ', '-', '_')).strip()
if not safe_name:
safe_name = "profile"
filename = f"profile-{safe_name}.voicebox.zip"
# Return as streaming response
return StreamingResponse(
io.BytesIO(zip_bytes),
media_type="application/zip",
headers={
"Content-Disposition": _safe_content_disposition("attachment", filename)
}
)
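except HTTPException:
# Re-raise deliberate HTTP errors (such as the 404 above) instead of converting them to 500
raise
except ValueError as e: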
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# ============================================
# AUDIO CHANNEL ENDPOINTS
# ============================================
@app.get("/channels", response_model=List[models.AudioChannelResponse])
async def list_channels(db: Session = Depends(get_db)):
"""List all audio channels."""
return await channels.list_channels(db)
@app.post("/channels", response_model=models.AudioChannelResponse)
async def create_channel(
data: models.AudioChannelCreate,
db: Session = Depends(get_db),
):
"""Create a new audio channel."""
try:
return await channels.create_channel(data, db)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/channels/{channel_id}", response_model=models.AudioChannelResponse)
async def get_channel(
channel_id: str,
db: Session = Depends(get_db),
):
"""Get an audio channel by ID."""
channel = await channels.get_channel(channel_id, db)
if not channel:
raise HTTPException(status_code=404, detail="Channel not found")
return channel
@app.put("/channels/{channel_id}", response_model=models.AudioChannelResponse)
async def update_channel(
channel_id: str,
data: models.AudioChannelUpdate,
db: Session = Depends(get_db),
):
"""Update an audio channel."""
try:
channel = await channels.update_channel(channel_id, data, db)
if not channel:
raise HTTPException(status_code=404, detail="Channel not found")
return channel
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.delete("/channels/{channel_id}")
async def delete_channel(
channel_id: str,
db: Session = Depends(get_db),
):
"""Delete an audio channel."""
try:
success = await channels.delete_channel(channel_id, db)
if not success:
raise HTTPException(status_code=404, detail="Channel not found")
return {"message": "Channel deleted successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/channels/{channel_id}/voices")
async def get_channel_voices(
channel_id: str,
db: Session = Depends(get_db),
):
"""Get list of profile IDs assigned to a channel."""
try:
profile_ids = await channels.get_channel_voices(channel_id, db)
return {"profile_ids": profile_ids}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.put("/channels/{channel_id}/voices")
async def set_channel_voices(
channel_id: str,
data: models.ChannelVoiceAssignment,
db: Session = Depends(get_db),
):
"""Set which voices are assigned to a channel."""
try:
await channels.set_channel_voices(channel_id, data, db)
return {"message": "Channel voices updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/profiles/{profile_id}/channels")
async def get_profile_channels(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get list of channel IDs assigned to a profile."""
try:
channel_ids = await channels.get_profile_channels(profile_id, db)
return {"channel_ids": channel_ids}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.put("/profiles/{profile_id}/channels")
async def set_profile_channels(
profile_id: str,
data: models.ProfileChannelAssignment,
db: Session = Depends(get_db),
):
"""Set which channels a profile is assigned to."""
try:
await channels.set_profile_channels(profile_id, data, db)
return {"message": "Profile channels updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# ============================================
# GENERATION ENDPOINTS
# ============================================
@app.post("/generate", response_model=models.GenerationResponse)
async def generate_speech(
data: models.GenerationRequest,
db: Session = Depends(get_db),
):
"""Generate speech from text using a voice profile."""
task_manager = get_task_manager()
generation_id = str(uuid.uuid4())
try:
# Start tracking generation
task_manager.start_generation(
task_id=generation_id,
profile_id=data.profile_id,
text=data.text,
)
# Get profile
profile = await profiles.get_profile(data.profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
# Generate audio
# Resolve model size and load the correct model FIRST.
# This must happen before create_voice_prompt_for_profile because that
# function calls load_model_async(None), which falls back to self.model_size.
# If the model is already loaded with the right size at that point, it
# returns immediately and the voice prompt is created by the correct model.
tts_model = tts.get_tts_model()
model_size = data.model_size or "1.7B"
# Check if model needs to be downloaded first
model_path = tts_model._get_model_path(model_size)
if not tts_model._is_model_cached(model_size):
# Model is not fully cached — kick off a background download and tell
# the client to retry once it's ready.
model_name = f"qwen-tts-{model_size}"
async def download_model_background():
try:
await tts_model.load_model_async(model_size)
except Exception as e:
task_manager.error_download(model_name, str(e))
task_manager.start_download(model_name)
asyncio.create_task(download_model_background())
raise HTTPException(
status_code=202,
detail={
"message": f"Model {model_size} is being downloaded. Please wait and try again.",
"model_name": model_name,
"downloading": True,
},
)
# Load (or switch to) the requested model before building the voice prompt
await tts_model.load_model_async(model_size)
# Create voice prompt from profile (model is already loaded with correct size)
voice_prompt = await profiles.create_voice_prompt_for_profile(
data.profile_id,
db,
)
audio, sample_rate = await tts_model.generate(
data.text,
voice_prompt,
data.language,
data.seed,
data.instruct,
)
# Calculate duration
duration = len(audio) / sample_rate
# Save audio
audio_path = config.get_generations_dir() / f"{generation_id}.wav"
from .utils.audio import save_audio
save_audio(audio, str(audio_path), sample_rate)
# Create history entry
generation = await history.create_generation(
profile_id=data.profile_id,
text=data.text,
language=data.language,
audio_path=str(audio_path),
duration=duration,
seed=data.seed,
db=db,
instruct=data.instruct,
)
# Mark generation as complete
task_manager.complete_generation(generation_id)
return generation
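except HTTPException:
# Pass through deliberate HTTP responses (404, 202 "downloading") instead of converting them to 500
task_manager.complete_generation(generation_id)
raise
except ValueError as e: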
task_manager.complete_generation(generation_id)
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
task_manager.complete_generation(generation_id)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate/stream")
async def stream_speech(
data: models.GenerationRequest,
db: Session = Depends(get_db),
):
"""
Generate speech and stream the WAV audio directly without saving to disk.
Returns raw WAV bytes via a StreamingResponse so the client can start
playing audio before the entire file has been received. This endpoint
does NOT create a history entry — use /generate for that.
"""
profile = await profiles.get_profile(data.profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
tts_model = tts.get_tts_model()
model_size = data.model_size or "1.7B"
if not tts_model._is_model_cached(model_size):
raise HTTPException(
status_code=400,
detail=f"Model {model_size} is not downloaded yet. Use /generate to trigger a download.",
)
# Load the correct model before building the voice prompt (fixes issue #96)
await tts_model.load_model_async(model_size)
voice_prompt = await profiles.create_voice_prompt_for_profile(data.profile_id, db)
audio, sample_rate = await tts_model.generate(
data.text,
voice_prompt,
data.language,
data.seed,
data.instruct,
)
wav_bytes = tts.audio_to_wav_bytes(audio, sample_rate)
async def _wav_stream():
# Yield in chunks so large responses don't block the event loop
chunk_size = 64 * 1024 # 64 KB
for i in range(0, len(wav_bytes), chunk_size):
yield wav_bytes[i : i + chunk_size]
return StreamingResponse(
_wav_stream(),
media_type="audio/wav",
headers={"Content-Disposition": 'attachment; filename="speech.wav"'},
)
# ============================================
# HISTORY ENDPOINTS
# ============================================
@app.get("/history", response_model=models.HistoryListResponse)
async def list_history(
profile_id: Optional[str] = None,
search: Optional[str] = None,
limit: int = 50,
offset: int = 0,
db: Session = Depends(get_db),
):
"""List generation history with optional filters."""
query = models.HistoryQuery(
profile_id=profile_id,
search=search,
limit=limit,
offset=offset,
)
return await history.list_generations(query, db)
@app.get("/history/stats")
async def get_stats(db: Session = Depends(get_db)):
"""Get generation statistics."""
return await history.get_generation_stats(db)
@app.post("/history/import")
async def import_generation(
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""Import a generation from a ZIP archive."""
# Validate file size (max 50MB)
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
# Read file content
content = await file.read()
if len(content) > MAX_FILE_SIZE:
raise HTTPException(
status_code=400,
detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024 * 1024)}MB"
)
try:
result = await export_import.import_generation_from_zip(content, db)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/history/{generation_id}", response_model=models.HistoryResponse)
async def get_generation(
generation_id: str,
db: Session = Depends(get_db),
):
"""Get a generation by ID."""
# Get generation with profile name
result = db.query(
DBGeneration,
DBVoiceProfile.name.label('profile_name')
).join(
DBVoiceProfile,
DBGeneration.profile_id == DBVoiceProfile.id
).filter(
DBGeneration.id == generation_id
).first()
if not result:
raise HTTPException(status_code=404, detail="Generation not found")
gen, profile_name = result
return models.HistoryResponse(
id=gen.id,
profile_id=gen.profile_id,
profile_name=profile_name,
text=gen.text,
language=gen.language,
audio_path=gen.audio_path,
duration=gen.duration,
seed=gen.seed,
instruct=gen.instruct,
created_at=gen.created_at,
)
@app.delete("/history/{generation_id}")
async def delete_generation(
generation_id: str,
db: Session = Depends(get_db),
):
"""Delete a generation."""
success = await history.delete_generation(generation_id, db)
if not success:
raise HTTPException(status_code=404, detail="Generation not found")
return {"message": "Generation deleted successfully"}
@app.get("/history/{generation_id}/export")
async def export_generation(
generation_id: str,
db: Session = Depends(get_db),
):
"""Export a generation as a ZIP archive."""
try:
# Get generation to create filename
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
raise HTTPException(status_code=404, detail="Generation not found")
# Export to ZIP
zip_bytes = export_import.export_generation_to_zip(generation_id, db)
# Create safe filename from text
safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
if not safe_text:
safe_text = "generation"
filename = f"generation-{safe_text}.voicebox.zip"
# Return as streaming response
return StreamingResponse(
io.BytesIO(zip_bytes),
media_type="application/zip",
headers={
"Content-Disposition": _safe_content_disposition("attachment", filename)
}
)
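except HTTPException:
# Re-raise deliberate HTTP errors (such as the 404 above) instead of converting them to 500
raise
except ValueError as e: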
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/history/{generation_id}/export-audio")
async def export_generation_audio(
generation_id: str,
db: Session = Depends(get_db),
):
"""Export only the audio file from a generation."""
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
raise HTTPException(status_code=404, detail="Generation not found")
audio_path = Path(generation.audio_path)
if not audio_path.exists():
raise HTTPException(status_code=404, detail="Audio file not found")
# Create safe filename from text
safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
if not safe_text:
safe_text = "generation"
filename = f"{safe_text}.wav"
return FileResponse(
audio_path,
media_type="audio/wav",
headers={
"Content-Disposition": _safe_content_disposition("attachment", filename)
}
)
# ============================================
# TRANSCRIPTION ENDPOINTS
# ============================================
@app.post("/transcribe", response_model=models.TranscriptionResponse)
async def transcribe_audio(
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
):
    """Transcribe audio file to text."""
    # Save uploaded file to temporary location
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name
    try:
        # Get audio duration
        from .utils.audio import load_audio
        audio, sr = load_audio(tmp_path)
        duration = len(audio) / sr
        # Transcribe
        whisper_model = transcribe.get_whisper_model()
        # Check if Whisper model is downloaded (uses default size "base")
        model_size = whisper_model.model_size
        model_name = f"openai/whisper-{model_size}"
        # Check if model is cached
        from huggingface_hub import constants as hf_constants
        repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_name.replace("/", "--"))
        if not repo_cache.exists():
            # Start download in background
            progress_model_name = f"whisper-{model_size}"
            async def download_whisper_background():
                try:
                    await whisper_model.load_model_async(model_size)
                except Exception as e:
                    get_task_manager().error_download(progress_model_name, str(e))
            get_task_manager().start_download(progress_model_name)
            asyncio.create_task(download_whisper_background())
            # Return 202 Accepted
            raise HTTPException(
                status_code=202,
                detail={
                    "message": f"Whisper model {model_size} is being downloaded. Please wait and try again.",
                    "model_name": progress_model_name,
                    "downloading": True
                }
            )
        text = await whisper_model.transcribe(tmp_path, language)
        return models.TranscriptionResponse(
            text=text,
            duration=duration,
        )
    except HTTPException:
        # Let the 202 response above pass through instead of turning it into a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Clean up temp file
        Path(tmp_path).unlink(missing_ok=True)
# ============================================
# STORY ENDPOINTS
# ============================================
@app.get("/stories", response_model=List[models.StoryResponse])
async def list_stories(db: Session = Depends(get_db)):
"""List all stories."""
return await stories.list_stories(db)
@app.post("/stories", response_model=models.StoryResponse)
async def create_story(
data: models.StoryCreate,
db: Session = Depends(get_db),
):
"""Create a new story."""
try:
return await stories.create_story(data, db)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/stories/{story_id}", response_model=models.StoryDetailResponse)
async def get_story(
story_id: str,
db: Session = Depends(get_db),
):
"""Get a story with all its items."""
story = await stories.get_story(story_id, db)
if not story:
raise HTTPException(status_code=404, detail="Story not found")
return story
@app.put("/stories/{story_id}", response_model=models.StoryResponse)
async def update_story(
story_id: str,
data: models.StoryCreate,
db: Session = Depends(get_db),
):
"""Update a story."""
story = await stories.update_story(story_id, data, db)
if not story:
raise HTTPException(status_code=404, detail="Story not found")
return story
@app.delete("/stories/{story_id}")
async def delete_story(
story_id: str,
db: Session = Depends(get_db),
):
"""Delete a story."""
success = await stories.delete_story(story_id, db)
if not success:
raise HTTPException(status_code=404, detail="Story not found")
return {"message": "Story deleted successfully"}
@app.post("/stories/{story_id}/items", response_model=models.StoryItemDetail)
async def add_story_item(
story_id: str,
data: models.StoryItemCreate,
db: Session = Depends(get_db),
):
"""Add a generation to a story."""
item = await stories.add_item_to_story(story_id, data, db)
if not item:
raise HTTPException(status_code=404, detail="Story or generation not found")
return item
@app.delete("/stories/{story_id}/items/{item_id}")
async def remove_story_item(
story_id: str,
item_id: str,
db: Session = Depends(get_db),
):
"""Remove a story item from a story."""
success = await stories.remove_item_from_story(story_id, item_id, db)
if not success:
raise HTTPException(status_code=404, detail="Story item not found")
return {"message": "Item removed successfully"}
@app.put("/stories/{story_id}/items/times")
async def update_story_item_times(
story_id: str,
data: models.StoryItemBatchUpdate,
db: Session = Depends(get_db),
):
"""Update story item timecodes."""
success = await stories.update_story_item_times(story_id, data, db)
if not success:
raise HTTPException(status_code=400, detail="Invalid timecode update request")
return {"message": "Item timecodes updated successfully"}
@app.put("/stories/{story_id}/items/reorder", response_model=List[models.StoryItemDetail])
async def reorder_story_items(
story_id: str,
data: models.StoryItemReorder,
db: Session = Depends(get_db),
):
"""Reorder story items and recalculate timecodes."""
items = await stories.reorder_story_items(story_id, data.generation_ids, db)
if items is None:
raise HTTPException(status_code=400, detail="Invalid reorder request - ensure all generation IDs belong to this story")
return items
@app.put("/stories/{story_id}/items/{item_id}/move", response_model=models.StoryItemDetail)
async def move_story_item(
story_id: str,
item_id: str,
data: models.StoryItemMove,
db: Session = Depends(get_db),
):
"""Move a story item (update position and/or track)."""
item = await stories.move_story_item(story_id, item_id, data, db)
if item is None:
raise HTTPException(status_code=404, detail="Story item not found")
return item
@app.put("/stories/{story_id}/items/{item_id}/trim", response_model=models.StoryItemDetail)
async def trim_story_item(
story_id: str,
item_id: str,
data: models.StoryItemTrim,
db: Session = Depends(get_db),
):
"""Trim a story item (update trim_start_ms and trim_end_ms)."""
item = await stories.trim_story_item(story_id, item_id, data, db)
if item is None:
raise HTTPException(status_code=404, detail="Story item not found or invalid trim values")
return item
@app.post("/stories/{story_id}/items/{item_id}/split", response_model=List[models.StoryItemDetail])
async def split_story_item(
story_id: str,
item_id: str,
data: models.StoryItemSplit,
db: Session = Depends(get_db),
):
"""Split a story item at a given time, creating two clips."""
items = await stories.split_story_item(story_id, item_id, data, db)
if items is None:
raise HTTPException(status_code=404, detail="Story item not found or invalid split point")
return items
@app.post("/stories/{story_id}/items/{item_id}/duplicate", response_model=models.StoryItemDetail)
async def duplicate_story_item(
story_id: str,
item_id: str,
db: Session = Depends(get_db),
):
"""Duplicate a story item, creating a copy with all properties."""
item = await stories.duplicate_story_item(story_id, item_id, db)
if item is None:
raise HTTPException(status_code=404, detail="Story item not found")
return item
@app.get("/stories/{story_id}/export-audio")
async def export_story_audio(
story_id: str,
db: Session = Depends(get_db),
):
"""Export story as single mixed audio file with timecode-based mixing."""
try:
# Get story to create filename
story = db.query(database.Story).filter_by(id=story_id).first()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
# Export audio
audio_bytes = await stories.export_story_audio(story_id, db)
if not audio_bytes:
raise HTTPException(status_code=400, detail="Story has no audio items")
# Create safe filename
safe_name = "".join(c for c in story.name if c.isalnum() or c in (' ', '-', '_')).strip()
if not safe_name:
safe_name = "story"
filename = f"{safe_name}.wav"
# Return as streaming response
return StreamingResponse(
io.BytesIO(audio_bytes),
media_type="audio/wav",
headers={
"Content-Disposition": _safe_content_disposition("attachment", filename)
}
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# ============================================
# FILE SERVING
# ============================================
@app.get("/audio/{generation_id}")
async def get_audio(generation_id: str, db: Session = Depends(get_db)):
"""Serve generated audio file."""
generation = await history.get_generation(generation_id, db)
if not generation:
raise HTTPException(status_code=404, detail="Generation not found")
audio_path = Path(generation.audio_path)
if not audio_path.exists():
raise HTTPException(status_code=404, detail="Audio file not found")
return FileResponse(
audio_path,
media_type="audio/wav",
filename=f"generation_{generation_id}.wav",
)
@app.get("/samples/{sample_id}")
async def get_sample_audio(sample_id: str, db: Session = Depends(get_db)):
"""Serve profile sample audio file."""
from .database import ProfileSample as DBProfileSample
sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
if not sample:
raise HTTPException(status_code=404, detail="Sample not found")
audio_path = Path(sample.audio_path)
if not audio_path.exists():
raise HTTPException(status_code=404, detail="Audio file not found")
return FileResponse(
audio_path,
media_type="audio/wav",
filename=f"sample_{sample_id}.wav",
)
# ============================================
# MODEL MANAGEMENT
# ============================================
@app.post("/models/load")
async def load_model(model_size: str = "1.7B"):
"""Manually load TTS model."""
try:
tts_model = tts.get_tts_model()
await tts_model.load_model_async(model_size)
return {"message": f"Model {model_size} loaded successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/models/unload")
async def unload_model():
"""Unload TTS model to free memory."""
try:
tts.unload_tts_model()
return {"message": "Model unloaded successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/models/progress/{model_name}")
async def get_model_progress(model_name: str):
"""Get model download progress via Server-Sent Events."""
from fastapi.responses import StreamingResponse
progress_manager = get_progress_manager()
async def event_generator():
"""Generate SSE events for progress updates."""
async for event in progress_manager.subscribe(model_name):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@app.get("/models/status", response_model=models.ModelStatusListResponse)
async def get_model_status():
"""Get status of all available models."""
from huggingface_hub import constants as hf_constants
from pathlib import Path
backend_type = get_backend_type()
task_manager = get_task_manager()
# Get set of currently downloading model names
active_download_names = {task.model_name for task in task_manager.get_active_downloads()}
# Try to import scan_cache_dir (might not be available in older versions)
try:
from huggingface_hub import scan_cache_dir
use_scan_cache = True
except ImportError:
use_scan_cache = False
def check_tts_loaded(model_size: str):
"""Check if TTS model is loaded with specific size."""
try:
tts_model = tts.get_tts_model()
return tts_model.is_loaded() and getattr(tts_model, 'model_size', None) == model_size
except Exception:
return False
def check_whisper_loaded(model_size: str):
"""Check if Whisper model is loaded with specific size."""
try:
whisper_model = transcribe.get_whisper_model()
return whisper_model.is_loaded() and getattr(whisper_model, 'model_size', None) == model_size
except Exception:
return False
# Use backend-specific model IDs
if backend_type == "mlx":
tts_1_7b_id = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
tts_0_6b_id = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16" # Fallback to 1.7B
# MLX backend uses openai/whisper-* models, not mlx-community
whisper_base_id = "openai/whisper-base"
whisper_small_id = "openai/whisper-small"
whisper_medium_id = "openai/whisper-medium"
whisper_large_id = "openai/whisper-large"
else:
tts_1_7b_id = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
tts_0_6b_id = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
whisper_base_id = "openai/whisper-base"
whisper_small_id = "openai/whisper-small"
whisper_medium_id = "openai/whisper-medium"
whisper_large_id = "openai/whisper-large"
model_configs = [
{
"model_name": "qwen-tts-1.7B",
"display_name": "Qwen TTS 1.7B",
"hf_repo_id": tts_1_7b_id,
"model_size": "1.7B",
"check_loaded": lambda: check_tts_loaded("1.7B"),
},
{
"model_name": "qwen-tts-0.6B",
"display_name": "Qwen TTS 0.6B",
"hf_repo_id": tts_0_6b_id,
"model_size": "0.6B",
"check_loaded": lambda: check_tts_loaded("0.6B"),
},
{
"model_name": "whisper-base",
"display_name": "Whisper Base",
"hf_repo_id": whisper_base_id,
"model_size": "base",
"check_loaded": lambda: check_whisper_loaded("base"),
},
{
"model_name": "whisper-small",
"display_name": "Whisper Small",
"hf_repo_id": whisper_small_id,
"model_size": "small",
"check_loaded": lambda: check_whisper_loaded("small"),
},
{
"model_name": "whisper-medium",
"display_name": "Whisper Medium",
"hf_repo_id": whisper_medium_id,
"model_size": "medium",
"check_loaded": lambda: check_whisper_loaded("medium"),
},
{
"model_name": "whisper-large",
"display_name": "Whisper Large",
"hf_repo_id": whisper_large_id,
"model_size": "large",
"check_loaded": lambda: check_whisper_loaded("large"),
},
]
# Build a mapping of model_name -> hf_repo_id so we can check if shared repos are downloading
model_to_repo = {cfg["model_name"]: cfg["hf_repo_id"] for cfg in model_configs}
# Get the set of hf_repo_ids that are currently being downloaded
# This handles the case where multiple models share the same repo (e.g., 0.6B and 1.7B on MLX)
active_download_repos = {model_to_repo.get(name) for name in active_download_names if name in model_to_repo}
# Get HuggingFace cache info (if available)
cache_info = None
if use_scan_cache:
try:
cache_info = scan_cache_dir()
except Exception:
# Function failed, continue without it
pass
statuses = []
for config in model_configs:
try:
downloaded = False
size_mb = None
loaded = False
# Method 1: Try using scan_cache_dir if available
if cache_info:
repo_id = config["hf_repo_id"]
for repo in cache_info.repos:
if repo.repo_id == repo_id:
# Check if actual model weight files exist (not just config files)
# scan_cache_dir only shows completed files, so check if any are model weights
has_model_weights = False
for rev in repo.revisions:
for f in rev.files:
fname = f.file_name.lower()
if fname.endswith(('.safetensors', '.bin', '.pt', '.pth', '.npz')):
has_model_weights = True
break
if has_model_weights:
break
# Also check for .incomplete files in blobs directory (downloads in progress)
has_incomplete = False
try:
cache_dir = hf_constants.HF_HUB_CACHE
blobs_dir = Path(cache_dir) / ("models--" + repo_id.replace("/", "--")) / "blobs"
if blobs_dir.exists():
has_incomplete = any(blobs_dir.glob("*.incomplete"))
except Exception:
pass
# Only mark as downloaded if we have model weights AND no incomplete files
if has_model_weights and not has_incomplete:
downloaded = True
# Calculate size from cache info
try:
total_size = sum(revision.size_on_disk for revision in repo.revisions)
size_mb = total_size / (1024 * 1024)
except Exception:
pass
break
# Method 2: Fallback to checking cache directory directly (using HuggingFace's OS-specific cache location)
if not downloaded:
try:
cache_dir = hf_constants.HF_HUB_CACHE
repo_cache = Path(cache_dir) / ("models--" + config["hf_repo_id"].replace("/", "--"))
if repo_cache.exists():
# Check for .incomplete files - if any exist, download is still in progress
blobs_dir = repo_cache / "blobs"
has_incomplete = blobs_dir.exists() and any(blobs_dir.glob("*.incomplete"))
if not has_incomplete:
# Check for actual model weight files (not just index files)
# in the snapshots directory (symlinks to completed blobs)
snapshots_dir = repo_cache / "snapshots"
has_model_files = False
if snapshots_dir.exists():
has_model_files = (
any(snapshots_dir.rglob("*.bin")) or
any(snapshots_dir.rglob("*.safetensors")) or
any(snapshots_dir.rglob("*.pt")) or
any(snapshots_dir.rglob("*.pth")) or
any(snapshots_dir.rglob("*.npz"))
)
if has_model_files:
downloaded = True
# Calculate size (exclude .incomplete files)
try:
total_size = sum(
f.stat().st_size for f in repo_cache.rglob("*")
if f.is_file() and not f.name.endswith('.incomplete')
)
size_mb = total_size / (1024 * 1024)
except Exception:
pass
except Exception:
pass
# Method 3 removed - checking for config.json is too lenient
# Methods 1 and 2 properly verify that model weight files exist
# Check if loaded in memory
try:
loaded = config["check_loaded"]()
except Exception:
loaded = False
# Check if this model (or its shared repo) is currently being downloaded
is_downloading = config["hf_repo_id"] in active_download_repos
# If downloading, don't report as downloaded (partial files exist)
if is_downloading:
downloaded = False
size_mb = None # Don't show partial size during download
statuses.append(models.ModelStatus(
model_name=config["model_name"],
display_name=config["display_name"],
downloaded=downloaded,
downloading=is_downloading,
size_mb=size_mb,
loaded=loaded,
))
except Exception as e:
# If check fails, try to at least check if loaded
try:
loaded = config["check_loaded"]()
except Exception:
loaded = False
# Check if this model (or its shared repo) is currently being downloaded
is_downloading = config["hf_repo_id"] in active_download_repos
statuses.append(models.ModelStatus(
model_name=config["model_name"],
display_name=config["display_name"],
downloaded=False, # Assume not downloaded if check failed
downloading=is_downloading,
size_mb=None,
loaded=loaded,
))
return models.ModelStatusListResponse(models=statuses)
@app.post("/models/download")
async def trigger_model_download(request: models.ModelDownloadRequest):
"""Trigger download of a specific model."""
import asyncio
task_manager = get_task_manager()
progress_manager = get_progress_manager()
model_configs = {
"qwen-tts-1.7B": {
"model_size": "1.7B",
"load_func": lambda: tts.get_tts_model().load_model("1.7B"),
},
"qwen-tts-0.6B": {
"model_size": "0.6B",
"load_func": lambda: tts.get_tts_model().load_model("0.6B"),
},
"whisper-base": {
"model_size": "base",
"load_func": lambda: transcribe.get_whisper_model().load_model("base"),
},
"whisper-small": {
"model_size": "small",
"load_func": lambda: transcribe.get_whisper_model().load_model("small"),
},
"whisper-medium": {
"model_size": "medium",
"load_func": lambda: transcribe.get_whisper_model().load_model("medium"),
},
"whisper-large": {
"model_size": "large",
"load_func": lambda: transcribe.get_whisper_model().load_model("large"),
},
}
if request.model_name not in model_configs:
raise HTTPException(status_code=400, detail=f"Unknown model: {request.model_name}")
config = model_configs[request.model_name]
async def download_in_background():
"""Download model in background without blocking the HTTP request."""
try:
# Call the load function (which may be async)
result = config["load_func"]()
# If it's a coroutine, await it
if asyncio.iscoroutine(result):
await result
task_manager.complete_download(request.model_name)
except Exception as e:
task_manager.error_download(request.model_name, str(e))
# Start tracking download
task_manager.start_download(request.model_name)
# Initialize progress state so SSE endpoint has initial data to send.
# This fixes a race condition where the frontend connects to SSE before
# any progress callbacks have fired (especially for large models like Qwen
# where huggingface_hub takes time to fetch metadata for all files).
progress_manager.update_progress(
model_name=request.model_name,
current=0,
total=0, # Will be updated once actual total is known
filename="Connecting to HuggingFace...",
status="downloading",
)
# Start download in background task (don't await)
asyncio.create_task(download_in_background())
# Return immediately - frontend should poll progress endpoint
return {"message": f"Model {request.model_name} download started"}
@app.delete("/models/{model_name}")
async def delete_model(model_name: str):
"""Delete a downloaded model from the HuggingFace cache."""
import shutil
import os
from huggingface_hub import constants as hf_constants
# Map model names to HuggingFace repo IDs
model_configs = {
"qwen-tts-1.7B": {
"hf_repo_id": "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
"model_size": "1.7B",
"model_type": "tts",
},
"qwen-tts-0.6B": {
"hf_repo_id": "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
"model_size": "0.6B",
"model_type": "tts",
},
"whisper-base": {
"hf_repo_id": "openai/whisper-base",
"model_size": "base",
"model_type": "whisper",
},
"whisper-small": {
"hf_repo_id": "openai/whisper-small",
"model_size": "small",
"model_type": "whisper",
},
"whisper-medium": {
"hf_repo_id": "openai/whisper-medium",
"model_size": "medium",
"model_type": "whisper",
},
"whisper-large": {
"hf_repo_id": "openai/whisper-large",
"model_size": "large",
"model_type": "whisper",
},
}
if model_name not in model_configs:
raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
config = model_configs[model_name]
hf_repo_id = config["hf_repo_id"]
try:
# Check if model is loaded and unload it first
if config["model_type"] == "tts":
tts_model = tts.get_tts_model()
if tts_model.is_loaded() and tts_model.model_size == config["model_size"]:
tts.unload_tts_model()
elif config["model_type"] == "whisper":
whisper_model = transcribe.get_whisper_model()
if whisper_model.is_loaded() and whisper_model.model_size == config["model_size"]:
transcribe.unload_whisper_model()
# Find and delete the cache directory (using HuggingFace's OS-specific cache location)
cache_dir = hf_constants.HF_HUB_CACHE
repo_cache_dir = Path(cache_dir) / ("models--" + hf_repo_id.replace("/", "--"))
# Check if the cache directory exists
if not repo_cache_dir.exists():
raise HTTPException(status_code=404, detail=f"Model {model_name} not found in cache")
# Delete the entire cache directory for this model
try:
shutil.rmtree(repo_cache_dir)
except OSError as e:
raise HTTPException(
status_code=500,
detail=f"Failed to delete model cache directory: {str(e)}"
)
return {"message": f"Model {model_name} deleted successfully"}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
@app.post("/cache/clear")
async def clear_cache():
    """Clear all voice prompt caches (memory and disk)."""
    try:
        deleted_count = clear_voice_prompt_cache()
        return {
            "message": "Voice prompt cache cleared successfully",
            "files_deleted": deleted_count,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")
# ============================================
# TASK MANAGEMENT
# ============================================
@app.get("/tasks/active", response_model=models.ActiveTasksResponse)
async def get_active_tasks():
"""Return all currently active downloads and generations."""
task_manager = get_task_manager()
progress_manager = get_progress_manager()
# Get active downloads from both task manager and progress manager
# Task manager tracks which downloads are active
# Progress manager has the actual progress data
active_downloads = []
task_manager_downloads = task_manager.get_active_downloads()
progress_active = progress_manager.get_all_active()
# Combine data from both sources
download_map = {task.model_name: task for task in task_manager_downloads}
progress_map = {p["model_name"]: p for p in progress_active}
# Create unified list
all_model_names = set(download_map.keys()) | set(progress_map.keys())
for model_name in all_model_names:
task = download_map.get(model_name)
progress = progress_map.get(model_name)
if task:
active_downloads.append(models.ActiveDownloadTask(
model_name=model_name,
status=task.status,
started_at=task.started_at,
))
elif progress:
# Progress exists but no task - create from progress data
timestamp_str = progress.get("timestamp")
if timestamp_str:
try:
started_at = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
except (ValueError, AttributeError):
started_at = datetime.utcnow()
else:
started_at = datetime.utcnow()
active_downloads.append(models.ActiveDownloadTask(
model_name=model_name,
status=progress.get("status", "downloading"),
started_at=started_at,
))
# Get active generations
active_generations = []
for gen_task in task_manager.get_active_generations():
active_generations.append(models.ActiveGenerationTask(
task_id=gen_task.task_id,
profile_id=gen_task.profile_id,
text_preview=gen_task.text_preview,
started_at=gen_task.started_at,
))
return models.ActiveTasksResponse(
downloads=active_downloads,
generations=active_generations,
)
# ============================================
# STARTUP & SHUTDOWN
# ============================================
def _get_gpu_status() -> str:
"""Get GPU availability status."""
backend_type = get_backend_type()
if torch.cuda.is_available():
return f"CUDA ({torch.cuda.get_device_name(0)})"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return "MPS (Apple Silicon)"
elif backend_type == "mlx":
return "Metal (Apple Silicon via MLX)"
return "None (CPU only)"
@app.on_event("startup")
async def startup_event():
"""Run on application startup."""
print("voicebox API starting up...")
database.init_db()
print(f"Database initialized at {database._db_path}")
backend_type = get_backend_type()
print(f"Backend: {backend_type.upper()}")
print(f"GPU available: {_get_gpu_status()}")
# Initialize progress manager with main event loop for thread-safe operations
try:
progress_manager = get_progress_manager()
progress_manager._set_main_loop(asyncio.get_running_loop())
print("Progress manager initialized with event loop")
except Exception as e:
print(f"Warning: Could not initialize progress manager event loop: {e}")
# Ensure HuggingFace cache directory exists
try:
from huggingface_hub import constants as hf_constants
cache_dir = Path(hf_constants.HF_HUB_CACHE)
cache_dir.mkdir(parents=True, exist_ok=True)
print(f"HuggingFace cache directory: {cache_dir}")
except Exception as e:
print(f"Warning: Could not create HuggingFace cache directory: {e}")
print("Model downloads may fail. Please ensure the directory exists and has write permissions.")
@app.on_event("shutdown")
async def shutdown_event():
"""Run on application shutdown."""
print("voicebox API shutting down...")
# Unload models to free memory
tts.unload_tts_model()
transcribe.unload_whisper_model()
# ============================================
# MAIN
# ============================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="voicebox backend server")
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Host to bind to (use 0.0.0.0 for remote access)",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port to bind to",
)
parser.add_argument(
"--data-dir",
type=str,
default=None,
help="Data directory for database, profiles, and generated audio",
)
args = parser.parse_args()
# Set data directory if provided
if args.data_dir:
config.set_data_dir(args.data_dir)
# Initialize database after data directory is set
database.init_db()
uvicorn.run(
"backend.main:app",
host=args.host,
port=args.port,
reload=False, # Disable reload in production
)
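The directory-based fallback that `get_model_status` uses to decide whether a model is downloaded can be tried in isolation. The sketch below is a minimal illustration under stated assumptions, not the project's code: `repo_cache_dir`, `looks_downloaded`, and `demo` are hypothetical helper names, and the fake cache is built in a temporary directory rather than the real HuggingFace cache. It mirrors the same two rules as the listing: a leftover `*.incomplete` blob means the download is still in progress, and at least one weight file must exist under `snapshots/`.

```python
import tempfile
from pathlib import Path

# Same weight-file extensions that the listing above checks for.
WEIGHT_SUFFIXES = (".safetensors", ".bin", ".pt", ".pth", ".npz")


def repo_cache_dir(cache_root: Path, repo_id: str) -> Path:
    """Map a repo ID such as 'openai/whisper-base' to its hub cache folder."""
    return cache_root / ("models--" + repo_id.replace("/", "--"))


def looks_downloaded(repo_cache: Path) -> bool:
    """True if weights exist under snapshots/ and no *.incomplete blob remains."""
    blobs = repo_cache / "blobs"
    if blobs.exists() and any(blobs.glob("*.incomplete")):
        return False  # a partial blob means a download is still running
    snapshots = repo_cache / "snapshots"
    if not snapshots.exists():
        return False
    return any(
        p.suffix in WEIGHT_SUFFIXES for p in snapshots.rglob("*") if p.is_file()
    )


def demo() -> tuple[bool, bool]:
    """Build a fake cache and check it while a blob is partial and afterwards."""
    with tempfile.TemporaryDirectory() as tmp:
        repo = repo_cache_dir(Path(tmp), "openai/whisper-base")
        (repo / "blobs").mkdir(parents=True)
        snapshot = repo / "snapshots" / "abc123"  # hypothetical revision name
        snapshot.mkdir(parents=True)
        (snapshot / "model.safetensors").write_bytes(b"\0")
        partial = repo / "blobs" / "deadbeef.incomplete"
        partial.write_bytes(b"\0")
        during = looks_downloaded(repo)
        partial.unlink()
        after = looks_downloaded(repo)
        return during, after


if __name__ == "__main__":
    print(demo())
```

Checking for weight files rather than `config.json` is the stricter choice: small metadata files arrive first, so their presence alone says little about whether the model can actually be loaded offline.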
2. backend/utils/cache.py
"""
Voice prompt caching utilities.
"""
import hashlib
import torch
import warnings
from pathlib import Path
from typing import Optional, Union, Dict, Any
from .. import config
def _get_cache_dir() -> Path:
"""Get cache directory from config."""
return config.get_cache_dir()
# In-memory cache - can store dict (voice prompt) or tensor (legacy)
_memory_cache: dict[str, Union[torch.Tensor, Dict[str, Any]]] = {}
def get_cache_key(audio_path: str, reference_text: str) -> str:
"""
Generate cache key from audio file and reference text.
Args:
audio_path: Path to audio file
reference_text: Reference text
Returns:
Cache key (MD5 hash)
"""
# Read audio file
with open(audio_path, "rb") as f:
audio_bytes = f.read()
# Combine audio bytes and text
combined = audio_bytes + reference_text.encode("utf-8")
# Generate hash
return hashlib.md5(combined).hexdigest()
def get_cached_voice_prompt(
    cache_key: str,
) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
    """
    Get cached voice prompt if available.
    Args:
        cache_key: Cache key
    Returns:
        Cached voice prompt (dict or tensor) or None
    """
    # Check in-memory cache
    if cache_key in _memory_cache:
        return _memory_cache[cache_key]
    # Check disk cache
    cache_file = _get_cache_dir() / f"{cache_key}.prompt"
    if cache_file.exists():
        try:
            # Suppress the FutureWarning and load with weights_only=False.
            # The cached voice prompts are generated by this application, so they are trusted.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=FutureWarning)
                prompt = torch.load(cache_file, weights_only=False)
            _memory_cache[cache_key] = prompt
            return prompt
        except Exception:
            # Cache file corrupted, delete it
            cache_file.unlink()
    return None
def cache_voice_prompt(
cache_key: str,
voice_prompt: Union[torch.Tensor, Dict[str, Any]],
) -> None:
"""
Cache voice prompt to memory and disk.
Args:
cache_key: Cache key
voice_prompt: Voice prompt (dict or tensor)
"""
# Store in memory
_memory_cache[cache_key] = voice_prompt
# Store on disk (torch.save can handle both dicts and tensors)
cache_file = _get_cache_dir() / f"{cache_key}.prompt"
torch.save(voice_prompt, cache_file)
def clear_voice_prompt_cache() -> int:
"""
Clear all voice prompt caches (memory and disk).
Returns:
Number of cache files deleted
"""
# Clear memory cache
_memory_cache.clear()
# Clear disk cache
cache_dir = _get_cache_dir()
deleted_count = 0
if cache_dir.exists():
# Delete prompt cache files
for cache_file in cache_dir.glob("*.prompt"):
try:
cache_file.unlink()
deleted_count += 1
except Exception as e:
print(f"Failed to delete cache file {cache_file}: {e}")
# Delete combined audio files
for audio_file in cache_dir.glob("combined_*.wav"):
try:
audio_file.unlink()
deleted_count += 1
except Exception as e:
print(f"Failed to delete combined audio file {audio_file}: {e}")
return deleted_count
def clear_profile_cache(profile_id: str) -> int:
"""
Clear cache files for a specific profile.
Args:
profile_id: Profile ID
Returns:
Number of cache files deleted
"""
cache_dir = _get_cache_dir()
deleted_count = 0
if cache_dir.exists():
# Delete combined audio files for this profile
pattern = f"combined_{profile_id}_*.wav"
for audio_file in cache_dir.glob(pattern):
try:
audio_file.unlink()
deleted_count += 1
except Exception as e:
print(f"Failed to delete combined audio file {audio_file}: {e}")
return deleted_count
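`get_cache_key` above reads the whole audio file into memory before hashing. For long reference recordings, the same key can be computed by feeding the file to the hash in chunks, because an incremental MD5 over the audio bytes followed by the text bytes equals the MD5 of their concatenation. The following is a minimal sketch of that idea; `streamed_cache_key` and `demo` are hypothetical names and are not part of the project.

```python
import hashlib
import tempfile
from pathlib import Path


def streamed_cache_key(
    audio_path: Path, reference_text: str, chunk_size: int = 1 << 20
) -> str:
    """MD5 over the audio bytes followed by the UTF-8 text, read in chunks."""
    digest = hashlib.md5()
    with open(audio_path, "rb") as f:
        # iter() with a sentinel stops once read() returns an empty bytes object
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    digest.update(reference_text.encode("utf-8"))
    return digest.hexdigest()


def demo() -> bool:
    """Check that the streamed key matches hashing the concatenated bytes."""
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "sample.wav"  # placeholder bytes, not real audio
        data = bytes(range(256)) * 100
        path.write_bytes(data)
        expected = hashlib.md5(data + "hello".encode("utf-8")).hexdigest()
        return streamed_cache_key(path, "hello", chunk_size=1000) == expected


if __name__ == "__main__":
    print(demo())
```

Because the key depends on both inputs, changing either the audio or the reference text produces a different cache entry.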
3. backend/utils/hf_config.py
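"""
HuggingFace Hub configuration utilities.
Provides the following:
1. Disabling SSL certificate verification
2. Using a mirror in mainland China to speed up downloads
3. Optimizing local model loading to avoid repeated network access
"""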
import os
import ssl
import warnings
from pathlib import Path
from typing import Optional
# Mirror endpoints in mainland China
HF_MIRRORS = [
"https://hf-mirror.com",
"https://modelscope.cn/api/v1/models",
]
def configure_huggingface_hub(
    disable_ssl_verify: bool = True,
    mirror_url: Optional[str] = None,
    local_files_only: Optional[bool] = None,
):
    """
    Configure HuggingFace Hub settings.
    Args:
        disable_ssl_verify: Whether to disable SSL certificate verification
        mirror_url: Mirror URL; if None, the default mirror is used unless HF_ENDPOINT is already set
        local_files_only: Whether to use local files only; if None, HF_HUB_OFFLINE is left unchanged
    """
    # Disable SSL certificate verification
    if disable_ssl_verify:
        _disable_ssl_verification()
    # Set the mirror
    if mirror_url:
        os.environ["HF_ENDPOINT"] = mirror_url
    elif "HF_ENDPOINT" not in os.environ:
        # Use the default mirror
        os.environ["HF_ENDPOINT"] = HF_MIRRORS[0]
    # Set offline (local files only) mode
    if local_files_only is not None:
        os.environ["HF_HUB_OFFLINE"] = "1" if local_files_only else "0"
def _disable_ssl_verification():
    """Disable SSL certificate verification."""
    try:
        # Make the standard library's default HTTPS context skip certificate verification
        ssl._create_default_https_context = ssl._create_unverified_context
        # Suppress SSL-related warnings
        warnings.filterwarnings("ignore", message="Unverified HTTPS request")
        warnings.filterwarnings("ignore", category=UserWarning, message=".*SSL.*")
        # Try to silence urllib3's warnings
        try:
            import urllib3
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
        except Exception:
            pass
    except Exception as e:
        print(f"[HF Config] Warning: Could not disable SSL verification: {e}")
def get_model_download_kwargs(
    model_name: str,
    is_cached: bool,
    force_download: bool = False,
) -> dict:
    """
    Build the keyword arguments used when loading or downloading a model.
    Args:
        model_name: Model name
        is_cached: Whether the model is already cached
        force_download: Whether to force a fresh download
    Returns:
        Dictionary of download arguments
    """
    kwargs = {}
    # If the model is cached and no download is forced, prefer local files
    if is_cached and not force_download:
        kwargs["local_files_only"] = True
        print(f"[HF Config] Using local files for {model_name}")
    else:
        # A download is needed: allow network access and trust remote code
        kwargs["local_files_only"] = False
        kwargs["trust_remote_code"] = True
        print(f"[HF Config] Will download {model_name} from HuggingFace Hub")
    return kwargs
def is_model_fully_cached(
    model_id: str,
    cache_dir: Optional[str] = None,
) -> bool:
    """
    Check whether a model is fully cached.
    Args:
        model_id: HuggingFace model ID
        cache_dir: Cache directory; if None, the default directory is used
    Returns:
        True if the model appears to be fully cached
    """
    try:
        from huggingface_hub import constants as hf_constants
        # Resolve the cache directory
        if cache_dir is None:
            cache_dir = hf_constants.HF_HUB_CACHE
        # Build the model's cache path
        repo_cache = Path(cache_dir) / ("models--" + model_id.replace("/", "--"))
        if not repo_cache.exists():
            return False
        # Unfinished downloads leave *.incomplete files in blobs/
        blobs_dir = repo_cache / "blobs"
        if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
            print(f"[HF Config] Found incomplete downloads for {model_id}")
            return False
        # A cached model must have a snapshots/ directory that contains weight files.
        # (The original returned True when snapshots/ was missing.)
        snapshots_dir = repo_cache / "snapshots"
        if not snapshots_dir.exists():
            print(f"[HF Config] No snapshots directory found for {model_id}")
            return False
        has_weights = (
            any(snapshots_dir.rglob("*.safetensors")) or
            any(snapshots_dir.rglob("*.bin")) or
            any(snapshots_dir.rglob("*.pt")) or
            any(snapshots_dir.rglob("*.pth")) or
            any(snapshots_dir.rglob("*.npz"))
        )
        if not has_weights:
            print(f"[HF Config] No model weights found for {model_id}")
            return False
        return True
    except Exception as e:
        print(f"[HF Config] Error checking cache for {model_id}: {e}")
        return False
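To make the directory checks above concrete, the sketch below builds a throwaway directory that mimics the Hub cache layout (`models--{org}--{name}` containing `blobs/` and `snapshots/{revision}/`) and runs a simplified version of the same checks against it. The names `has_complete_snapshot` and `make_fake_cache`, and the revision name `0000fake`, are illustrative only.

```python
import tempfile
from pathlib import Path

WEIGHT_PATTERNS = ("*.safetensors", "*.bin", "*.pt", "*.pth", "*.npz")


def has_complete_snapshot(cache_dir: Path, model_id: str) -> bool:
    """Simplified cache check: no partial blobs, and at least one weight file."""
    repo = cache_dir / ("models--" + model_id.replace("/", "--"))
    blobs = repo / "blobs"
    if blobs.exists() and any(blobs.glob("*.incomplete")):
        return False
    snapshots = repo / "snapshots"
    if not snapshots.is_dir():
        return False
    return any(any(snapshots.rglob(pattern)) for pattern in WEIGHT_PATTERNS)


def make_fake_cache(root: Path, model_id: str) -> Path:
    """Create an empty stand-in for a cached repository (test fixture only)."""
    repo = root / ("models--" + model_id.replace("/", "--"))
    (repo / "blobs").mkdir(parents=True)
    snapshot = repo / "snapshots" / "0000fake"
    snapshot.mkdir(parents=True)
    (snapshot / "config.json").write_text("{}")
    (snapshot / "model.safetensors").write_bytes(b"")
    return repo


if __name__ == "__main__":
    model_id = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        repo = make_fake_cache(root, model_id)
        print("complete:", has_complete_snapshot(root, model_id))
        # Simulate an interrupted download
        (repo / "blobs" / "abc.incomplete").touch()
        print("after partial blob:", has_complete_snapshot(root, model_id))
```

A check of this kind only confirms that weight files are present; it does not verify that every file the model needs has been downloaded.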
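def setup_huggingface_for_offline():
    """
    Put HuggingFace into offline mode so that only local files are used.
    Call this once the model has been downloaded to avoid network access.
    For the environment variables to take effect, call it before transformers or
    huggingface_hub is first imported; the module-level flags below are patched
    as a fallback for the case where those libraries are already loaded.
    """
    # Set the environment variables
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ["TRANSFORMERS_OFFLINE"] = "1"
    # Turn off download-count updates and telemetry
    os.environ["HF_UPDATE_DOWNLOAD_COUNTS"] = "0"
    os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
    # Patch the offline flag in transformers (private attribute, may vary by version)
    try:
        import transformers
        transformers.utils.hub._is_offline_mode = True
    except Exception:
        pass
    # Try to patch the offline flag in huggingface_hub
    try:
        import huggingface_hub
        huggingface_hub.constants.HF_HUB_OFFLINE = True
    except Exception:
        pass
    print("[HF Config] Set to offline mode - will only use local files")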
def setup_huggingface_for_online():
    """
    Put HuggingFace into online mode so that downloads from the network are allowed.
    """
    os.environ["HF_HUB_OFFLINE"] = "0"
    os.environ["TRANSFORMERS_OFFLINE"] = "0"
    os.environ["HF_UPDATE_DOWNLOAD_COUNTS"] = "1"
    # Patch the offline flag in transformers (private attribute, may vary by version)
    try:
        import transformers
        transformers.utils.hub._is_offline_mode = False
    except Exception:
        pass
    # Try to patch the offline flag in huggingface_hub
    try:
        import huggingface_hub
        huggingface_hub.constants.HF_HUB_OFFLINE = False
    except Exception:
        pass
    print("[HF Config] Set to online mode - will download from network if needed")
def get_huggingface_config() -> dict:
    """
    Return the current HuggingFace configuration.
    Returns:
        Dictionary of configuration values read from the environment
    """
    return {
        "HF_ENDPOINT": os.environ.get("HF_ENDPOINT"),
        "HF_HUB_OFFLINE": os.environ.get("HF_HUB_OFFLINE"),
        "TRANSFORMERS_OFFLINE": os.environ.get("TRANSFORMERS_OFFLINE"),
        "HF_HUB_CACHE": os.environ.get("HF_HUB_CACHE"),
    }
def clean_no_exist_cache(model_id: str, cache_dir: Optional[str] = None) -> int:
    """
    Remove a model's .no_exist directory.
    The .no_exist directory records which files do not exist in the remote repository.
    In offline mode these records may trigger unnecessary network access attempts.
    Removing the directory keeps HuggingFace Hub from trying to verify those files
    in offline mode.
    Args:
        model_id: HuggingFace model ID
        cache_dir: Cache directory; if None, the default directory is used
    Returns:
        Number of files deleted
    """
    try:
        from huggingface_hub import constants as hf_constants
        import shutil
        # Resolve the cache directory
        if cache_dir is None:
            cache_dir = hf_constants.HF_HUB_CACHE
        # Build the path of the .no_exist directory
        repo_cache = Path(cache_dir) / ("models--" + model_id.replace("/", "--"))
        no_exist_dir = repo_cache / ".no_exist"
        if not no_exist_dir.exists():
            print(f"[HF Config] No .no_exist directory found for {model_id}")
            return 0
        # Count the files before deleting
        file_count = sum(1 for p in no_exist_dir.rglob("*") if p.is_file())
        # Delete the .no_exist directory
        shutil.rmtree(no_exist_dir)
        print(f"[HF Config] Cleaned .no_exist directory for {model_id} ({file_count} files removed)")
        return file_count
    except Exception as e:
        print(f"[HF Config] Error cleaning .no_exist directory for {model_id}: {e}")
        return 0
def clean_all_no_exist_cache(cache_dir: Optional[str] = None) -> int:
    """
    Remove the .no_exist directory of every cached model.
    Args:
        cache_dir: Cache directory; if None, the default directory is used
    Returns:
        Total number of files deleted
    """
    try:
        from huggingface_hub import constants as hf_constants
        import shutil
        # Resolve the cache directory
        if cache_dir is None:
            cache_dir = hf_constants.HF_HUB_CACHE
        cache_path = Path(cache_dir)
        total_files = 0
        # Find the .no_exist directory of each model
        for model_dir in cache_path.glob("models--*"):
            no_exist_dir = model_dir / ".no_exist"
            if no_exist_dir.exists():
                file_count = sum(1 for p in no_exist_dir.rglob("*") if p.is_file())
                shutil.rmtree(no_exist_dir)
                print(f"[HF Config] Cleaned .no_exist for {model_dir.name} ({file_count} files)")
                total_files += file_count
        print(f"[HF Config] Total .no_exist files cleaned: {total_files}")
        return total_files
    except Exception as e:
        print(f"[HF Config] Error cleaning all .no_exist directories: {e}")
        return 0
# Apply the default configuration when this module is imported
configure_huggingface_hub(disable_ssl_verify=True)
4. backend/utils/warning_suppressor.py
"""
Warning suppression utilities.
Suppresses common warnings at application startup to keep the logs clean.
"""
import warnings
import os
def suppress_common_warnings():
    """
    Suppress common warnings.
    Covers:
    - the torch.load FutureWarning
    - transformers UserWarning messages
    - other known harmless warnings
    """
    # Suppress the torch.load FutureWarning
    warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.load.*")
    # Suppress the transformers Flash Attention warning
    warnings.filterwarnings("ignore", category=UserWarning, message=".*flash attention.*")
    # Suppress the transformers pad_token_id warning
    warnings.filterwarnings("ignore", category=UserWarning, message=".*pad_token_id.*")
    # Suppress UserWarning messages raised from transformers modules
    warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
    # Set the environment variable that silences transformers advisory warnings
    os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
    # Suppress every DeprecationWarning (note: this applies process-wide)
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    print("[Warning Suppressor] Common warnings suppressed")
def suppress_ssl_warnings():
    """Suppress SSL-related warnings."""
    import urllib3
    # Disable InsecureRequestWarning
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    print("[Warning Suppressor] SSL warnings suppressed")
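The filters above combine a warning category with a message pattern. In the standard `warnings` module the pattern is a regular expression matched case-insensitively against the start of the message, and a warning is ignored only when both the category and the pattern match. The small sketch below illustrates that behavior with the same `torch.load` filter; the helper name `visible_warnings` is made up for this example.

```python
import warnings
from typing import Iterable, List, Tuple, Type


def visible_warnings(events: Iterable[Tuple[str, Type[Warning]]]) -> List[str]:
    """Emit the given warnings under a torch.load filter and return the ones that survive."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # start from "show everything"
        # Same style of filter as in suppress_common_warnings()
        warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.load.*")
        for message, category in events:
            warnings.warn(message, category)
    return [str(item.message) for item in caught]


if __name__ == "__main__":
    shown = visible_warnings([
        ("You are using `torch.load` with `weights_only=False`", FutureWarning),
        ("some other future change", FutureWarning),
        ("torch.load mentioned in a UserWarning", UserWarning),
    ])
    for line in shown:
        print(line)
```

Only the first warning is filtered out: the second has a non-matching message and the third has a different category. A filter without a message pattern, such as the blanket `DeprecationWarning` filter above, hides every warning of that category.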
5. backend/backends/pytorch_backend.py
"""
PyTorch backend implementation for TTS and STT.
"""
from typing import Optional, List, Tuple
import asyncio
import torch
import numpy as np
from pathlib import Path
from . import TTSBackend, STTBackend
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
from ..utils.audio import normalize_audio, load_audio
from ..utils.progress import get_progress_manager
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
from ..utils.tasks import get_task_manager
from ..utils.hf_config import (
get_model_download_kwargs,
setup_huggingface_for_online,
)
class PyTorchTTSBackend:
"""PyTorch-based TTS backend using Qwen3-TTS."""
def __init__(self, model_size: str = "1.7B"):
self.model = None
self.model_size = model_size
self.device = self._get_device()
self._current_model_size = None
def _get_device(self) -> str:
"""Get the best available device."""
if torch.cuda.is_available():
return "cuda"
# Intel Arc / Intel Xe GPU via intel-extension-for-pytorch (IPEX)
try:
import intel_extension_for_pytorch # noqa: F401
if hasattr(torch, 'xpu') and torch.xpu.is_available():
return "xpu"
except ImportError:
pass
# Any GPU on Windows via DirectML (torch-directml)
try:
import torch_directml
if torch_directml.device_count() > 0:
return torch_directml.device(0)
except ImportError:
pass
# MPS (Apple Silicon) — kept for completeness but MLX backend is preferred
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return "cpu" # MPS disabled for stability; MLX backend handles Apple Silicon
return "cpu"
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self.model is not None
def _get_model_path(self, model_size: str) -> str:
"""
Get the HuggingFace Hub model ID.
Args:
model_size: Model size (1.7B or 0.6B)
Returns:
HuggingFace Hub model ID
"""
hf_model_map = {
"1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
"0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
}
if model_size not in hf_model_map:
raise ValueError(f"Unknown model size: {model_size}")
return hf_model_map[model_size]
def _get_local_snapshot_path(self, model_size: str) -> Optional[Path]:
"""
Get the local snapshot path if the model is fully cached.
Args:
model_size: Model size to check
Returns:
Path to local snapshot if fully cached, None otherwise
"""
try:
from huggingface_hub import constants as hf_constants
model_id = self._get_model_path(model_size)
repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--"))
if not repo_cache.exists():
return None
# Check for .incomplete files - if any exist, download is still in progress
blobs_dir = repo_cache / "blobs"
if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
return None
# Check that actual model weight files exist in snapshots
snapshots_dir = repo_cache / "snapshots"
if not snapshots_dir.exists():
return None
# Get the latest snapshot (by modification time)
snapshot_dirs = sorted(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)
if not snapshot_dirs:
return None
latest_snapshot = snapshot_dirs[0]
# Check for model weight entries (rglob also matches symlinks that point into blobs/)
has_weights = (
any(latest_snapshot.rglob("*.safetensors")) or
any(latest_snapshot.rglob("*.bin"))
)
if not has_weights:
return None
# Check for config.json
if not (latest_snapshot / "config.json").exists():
return None
return latest_snapshot
except Exception as e:
print(f"[_get_local_snapshot_path] Error: {e}")
return None
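Picking the newest snapshot by modification time, as the method above does, is a heuristic. The Hub cache also keeps a `refs/` directory whose files (for example `refs/main`) contain the commit hash of the revision they point to, which offers a more direct way to locate the snapshot. A minimal sketch under that assumption follows; the helper names `resolve_snapshot` and `build_demo_repo` are hypothetical, and the code falls back to the modification-time heuristic when no ref file is present.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def resolve_snapshot(repo_cache: Path, ref: str = "main") -> Optional[Path]:
    """Return the snapshot directory named by refs/<ref>, else the newest snapshot."""
    snapshots = repo_cache / "snapshots"
    if not snapshots.is_dir():
        return None
    ref_file = repo_cache / "refs" / ref
    if ref_file.is_file():
        candidate = snapshots / ref_file.read_text().strip()
        if candidate.is_dir():
            return candidate
    # Fallback: newest snapshot directory by modification time
    dirs = [entry for entry in snapshots.iterdir() if entry.is_dir()]
    if not dirs:
        return None
    return max(dirs, key=lambda entry: entry.stat().st_mtime)


def build_demo_repo(root: Path) -> Path:
    """Create a fake repo cache with two snapshots; refs/main points at the older one."""
    repo = root / "models--demo--model"
    for name, mtime in (("aaa", 1_000_000_000), ("bbb", 2_000_000_000)):
        snapshot = repo / "snapshots" / name
        snapshot.mkdir(parents=True)
        os.utime(snapshot, (mtime, mtime))
    (repo / "refs").mkdir()
    (repo / "refs" / "main").write_text("aaa\n")
    return repo


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        repo = build_demo_repo(Path(tmp))
        print("via refs/main:", resolve_snapshot(repo).name)
        (repo / "refs" / "main").unlink()
        print("via mtime fallback:", resolve_snapshot(repo).name)
```

Because the same lookup is repeated in the TTS, Whisper and MLX backends, a shared helper of this kind could also remove the duplication.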
def _is_model_cached(self, model_size: str) -> bool:
"""
Check if the model is already cached locally AND fully downloaded.
Args:
model_size: Model size to check
Returns:
True if model is fully cached, False if missing or incomplete
"""
local_path = self._get_local_snapshot_path(model_size)
if local_path:
print(f"[_is_model_cached] Model {model_size} is fully cached at {local_path}")
else:
print(f"[_is_model_cached] Model {model_size} is not cached")
return local_path is not None
async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the TTS model with automatic downloading from HuggingFace Hub.
Args:
model_size: Model size to load (1.7B or 0.6B)
"""
if model_size is None:
model_size = self.model_size
# If already loaded with correct size, return
if self.model is not None and self._current_model_size == model_size:
return
# Unload existing model if different size requested
if self.model is not None and self._current_model_size != model_size:
self.unload_model()
# Run blocking load in thread pool
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility
load_model = load_model_async
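`load_model_async` hands the blocking load to `asyncio.to_thread` so the event loop can keep serving other work, such as progress updates, while the model loads. The self-contained sketch below illustrates the effect with a stand-in for the load; `blocking_load` and `heartbeat` are invented for the example and simply sleep.

```python
import asyncio
import time
from typing import List, Tuple


def blocking_load(seconds: float) -> str:
    """Stand-in for a slow, synchronous model load."""
    time.sleep(seconds)
    return "loaded"


async def heartbeat(ticks: List[float], stop: asyncio.Event) -> None:
    """Record a timestamp on every loop iteration until asked to stop."""
    while not stop.is_set():
        ticks.append(time.monotonic())
        await asyncio.sleep(0.01)


async def main() -> Tuple[str, int]:
    ticks: List[float] = []
    stop = asyncio.Event()
    beat = asyncio.create_task(heartbeat(ticks, stop))
    # The blocking call runs in a worker thread; the heartbeat keeps ticking meanwhile
    result = await asyncio.to_thread(blocking_load, 0.2)
    stop.set()
    await beat
    return result, len(ticks)


if __name__ == "__main__":
    result, tick_count = asyncio.run(main())
    print(result, "heartbeat ticks during load:", tick_count)
```

Calling `blocking_load` directly inside the coroutine would instead freeze the heartbeat until the load finished. `asyncio.to_thread` requires Python 3.9 or later.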
def _load_model_sync(self, model_size: str):
"""Synchronous model loading."""
try:
progress_manager = get_progress_manager()
task_manager = get_task_manager()
model_name = f"qwen-tts-{model_size}"
# Get local snapshot path if model is cached
local_snapshot_path = self._get_local_snapshot_path(model_size)
is_cached = local_snapshot_path is not None
# Get model ID for HuggingFace Hub (used for downloading)
model_id = self._get_model_path(model_size)
# Determine the path to use for loading
# If cached, use local snapshot path directly to avoid any network access
# If not cached, use HuggingFace Hub ID to download
load_path = str(local_snapshot_path) if is_cached else model_id
if is_cached:
print(f"[TTS] Loading model {model_size} from local cache: {load_path}")
else:
print(f"[TTS] Model {model_size} not cached, will download from HuggingFace Hub")
setup_huggingface_for_online()
# Set up progress callback and tracker
# If cached: filter out non-download progress (like "Segment 1/1" during generation)
# If not cached: report all progress (we're actually downloading)
progress_callback = create_hf_progress_callback(model_name, progress_manager)
tracker = HFProgressTracker(progress_callback, filter_non_downloads=is_cached)
# Patch tqdm BEFORE importing qwen_tts
tracker_context = tracker.patch_download()
tracker_context.__enter__()
# Import qwen_tts
from qwen_tts import Qwen3TTSModel
print(f"Loading TTS model {model_size} on {self.device}...")
# Only track download progress if model is NOT cached
if not is_cached:
# Start tracking download task
task_manager.start_download(model_name)
# Initialize progress state so SSE endpoint has initial data to send
progress_manager.update_progress(
model_name=model_name,
current=0,
total=0, # Will be updated once actual total is known
filename="Connecting to HuggingFace...",
status="downloading",
)
# Load the model
try:
# When loading from local path, no need for download kwargs
# When loading from HuggingFace Hub, use download kwargs
if is_cached:
# Load directly from local path - no network access
if self.device == "cpu":
self.model = Qwen3TTSModel.from_pretrained(
load_path,
torch_dtype=torch.float32,
low_cpu_mem_usage=False,
)
else:
self.model = Qwen3TTSModel.from_pretrained(
load_path,
device_map=self.device,
torch_dtype=torch.bfloat16,
)
else:
# Load from HuggingFace Hub - will download
download_kwargs = get_model_download_kwargs(model_name, is_cached)
if self.device == "cpu":
self.model = Qwen3TTSModel.from_pretrained(
load_path,
torch_dtype=torch.float32,
low_cpu_mem_usage=False,
**download_kwargs
)
else:
self.model = Qwen3TTSModel.from_pretrained(
load_path,
device_map=self.device,
torch_dtype=torch.bfloat16,
**download_kwargs
)
finally:
# Exit the patch context
tracker_context.__exit__(None, None, None)
# Only mark download as complete if we were tracking it
if not is_cached:
progress_manager.mark_complete(model_name)
task_manager.complete_download(model_name)
self._current_model_size = model_size
self.model_size = model_size
print(f"TTS model {model_size} loaded successfully")
except ImportError as e:
print(f"Error: qwen_tts package not found. Install with: pip install git+https://github.com/QwenLM/Qwen3-TTS.git")
progress_manager = get_progress_manager()
task_manager = get_task_manager()
model_name = f"qwen-tts-{model_size}"
progress_manager.mark_error(model_name, str(e))
task_manager.error_download(model_name, str(e))
raise
except Exception as e:
error_msg = str(e)
print(f"Error loading TTS model: {error_msg}")
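# Detect offline-mode errors and print a friendlier hint
if "offline mode" in error_msg.lower() or "cannot reach" in error_msg.lower():
print("\n[Hint] The model files are cached, but the qwen_tts library needs a network connection to validate the model.")
print("[Hint] Try one of the following:")
print(" 1. Connect to the network and retry (recommended)")
print(" 2. Set the environment variable HF_HUB_OFFLINE=0 and retry")
print(" 3. Check that the model cache is complete")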
else:
print(f"Tip: The model will be automatically downloaded from HuggingFace Hub on first use.")
progress_manager = get_progress_manager()
task_manager = get_task_manager()
model_name = f"qwen-tts-{model_size}"
progress_manager.mark_error(model_name, error_msg)
task_manager.error_download(model_name, error_msg)
raise
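In `_load_model_sync` the tqdm patch is entered by calling `__enter__()` by hand and left in a later `finally` block. If anything raises between those two points, for example the `qwen_tts` import, the patch is never undone. A `with` statement closes that gap. The sketch below shows the pattern with a stand-in context manager, since the internals of `HFProgressTracker.patch_download()` are not shown here; `patched_attr` and the `SimpleNamespace` target are assumptions for illustration.

```python
import types
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def patched_attr(target: Any, name: str, replacement: Any) -> Iterator[None]:
    """Temporarily replace target.<name>, restoring it even if the body raises."""
    original = getattr(target, name)
    setattr(target, name, replacement)
    try:
        yield
    finally:
        setattr(target, name, original)


if __name__ == "__main__":
    # Stand-in for a module whose progress bar is being patched
    fake_module = types.SimpleNamespace(tqdm="original tqdm")
    try:
        with patched_attr(fake_module, "tqdm", "tracking tqdm"):
            print("inside:", fake_module.tqdm)
            raise ImportError("simulated failure while importing the TTS package")
    except ImportError as exc:
        print("caught:", exc)
    print("after:", fake_module.tqdm)
```

Applied to the loader, the import and the `from_pretrained` call would both sit inside one `with tracker.patch_download():` block.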
def unload_model(self):
"""Unload the model to free memory."""
if self.model is not None:
del self.model
self.model = None
self._current_model_size = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("TTS model unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
Args:
audio_path: Path to reference audio file
reference_text: Transcript of reference audio
use_cache: Whether to use cached prompt if available
Returns:
Tuple of (voice_prompt_dict, was_cached)
"""
await self.load_model_async(None)
# Check cache if enabled
if use_cache:
cache_key = get_cache_key(audio_path, reference_text)
cached_prompt = get_cached_voice_prompt(cache_key)
if cached_prompt is not None:
# Cache stores as torch.Tensor but actual prompt is dict
# Convert if needed
if isinstance(cached_prompt, dict):
# For PyTorch backend, the dict should contain tensors, not file paths
# So we can safely return it
return cached_prompt, True
elif isinstance(cached_prompt, torch.Tensor):
# Legacy cache format - convert to dict
# This shouldn't happen in practice, but handle it
return {"prompt": cached_prompt}, True
def _create_prompt_sync():
"""Run synchronous voice prompt creation in thread pool."""
return self.model.create_voice_clone_prompt(
ref_audio=str(audio_path),
ref_text=reference_text,
x_vector_only_mode=False,
)
# Run blocking operation in thread pool
voice_prompt_items = await asyncio.to_thread(_create_prompt_sync)
# Cache if enabled
if use_cache:
cache_key = get_cache_key(audio_path, reference_text)
cache_voice_prompt(cache_key, voice_prompt_items)
return voice_prompt_items, False
async def combine_voice_prompts(
self,
audio_paths: List[str],
reference_texts: List[str],
) -> Tuple[np.ndarray, str]:
"""
Combine multiple reference samples for better quality.
Args:
audio_paths: List of audio file paths
reference_texts: List of reference texts
Returns:
Tuple of (combined_audio, combined_text)
"""
combined_audio = []
for audio_path in audio_paths:
audio, sr = load_audio(audio_path)
audio = normalize_audio(audio)
combined_audio.append(audio)
# Concatenate audio
mixed = np.concatenate(combined_audio)
mixed = normalize_audio(mixed)
# Combine texts
combined_text = " ".join(reference_texts)
return mixed, combined_text
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio from text using voice prompt.
Args:
text: Text to synthesize
voice_prompt: Voice prompt dictionary from create_voice_prompt
language: Language code (en or zh)
seed: Random seed for reproducibility
instruct: Natural language instruction for speech delivery control
Returns:
Tuple of (audio_array, sample_rate)
"""
# Load model
await self.load_model_async(None)
def _generate_sync():
"""Run synchronous generation in thread pool."""
# Set seed if provided
if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# Generate audio - this is the blocking operation
wavs, sample_rate = self.model.generate_voice_clone(
text=text,
voice_clone_prompt=voice_prompt,
instruct=instruct,
)
return wavs[0], sample_rate
# Run blocking inference in thread pool to avoid blocking event loop
audio, sample_rate = await asyncio.to_thread(_generate_sync)
return audio, sample_rate
class PyTorchSTTBackend:
"""PyTorch-based STT backend using Whisper."""
def __init__(self, model_size: str = "base"):
self.model = None
self.processor = None
self.model_size = model_size
self.device = self._get_device()
def _get_device(self) -> str:
"""Get the best available device."""
if torch.cuda.is_available():
return "cuda"
# Intel Arc / Intel Xe GPU via intel-extension-for-pytorch (IPEX)
try:
import intel_extension_for_pytorch # noqa: F401
if hasattr(torch, 'xpu') and torch.xpu.is_available():
return "xpu"
except ImportError:
pass
# Any GPU on Windows via DirectML (torch-directml)
try:
import torch_directml
if torch_directml.device_count() > 0:
return torch_directml.device(0)
except ImportError:
pass
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return "cpu" # MPS disabled for stability
return "cpu"
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self.model is not None
def _get_local_snapshot_path(self, model_size: str) -> Optional[Path]:
"""
Get the local snapshot path if the Whisper model is fully cached.
Args:
model_size: Model size to check
Returns:
Path to local snapshot if fully cached, None otherwise
"""
try:
from huggingface_hub import constants as hf_constants
model_id = f"openai/whisper-{model_size}"
repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--"))
if not repo_cache.exists():
return None
# Check for .incomplete files - if any exist, download is still in progress
blobs_dir = repo_cache / "blobs"
if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
return None
# Check that actual model weight files exist in snapshots
snapshots_dir = repo_cache / "snapshots"
if not snapshots_dir.exists():
return None
# Get the latest snapshot (by modification time)
snapshot_dirs = sorted(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)
if not snapshot_dirs:
return None
latest_snapshot = snapshot_dirs[0]
# Check for model weight entries (rglob also matches symlinks that point into blobs/)
has_weights = (
any(latest_snapshot.rglob("*.safetensors")) or
any(latest_snapshot.rglob("*.bin"))
)
if not has_weights:
return None
# Check for config.json
if not (latest_snapshot / "config.json").exists():
return None
return latest_snapshot
except Exception as e:
print(f"[_get_local_snapshot_path] Error: {e}")
return None
def _is_model_cached(self, model_size: str) -> bool:
"""
Check if the Whisper model is already cached locally AND fully downloaded.
Args:
model_size: Model size to check
Returns:
True if model is fully cached, False if missing or incomplete
"""
local_path = self._get_local_snapshot_path(model_size)
if local_path:
print(f"[_is_model_cached] Whisper model {model_size} is fully cached at {local_path}")
else:
print(f"[_is_model_cached] Whisper model {model_size} is not cached")
return local_path is not None
async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the Whisper model.
Args:
model_size: Model size (tiny, base, small, medium, large)
"""
print(f"[DEBUG] load_model_async called with size: {model_size}")
if model_size is None:
model_size = self.model_size
print(f"[DEBUG] Model already loaded? {self.model is not None}, current size: {self.model_size}, requested: {model_size}")
if self.model is not None and self.model_size == model_size:
print(f"[DEBUG] Early return - model already loaded")
return
print(f"[DEBUG] Calling asyncio.to_thread for _load_model_sync")
# Run blocking load in thread pool
await asyncio.to_thread(self._load_model_sync, model_size)
print(f"[DEBUG] asyncio.to_thread completed")
# Alias for compatibility
load_model = load_model_async
def _load_model_sync(self, model_size: str):
"""Synchronous model loading."""
print(f"[DEBUG] _load_model_sync called for Whisper {model_size}")
try:
progress_manager = get_progress_manager()
task_manager = get_task_manager()
progress_model_name = f"whisper-{model_size}"
# Get local snapshot path if model is cached
local_snapshot_path = self._get_local_snapshot_path(model_size)
is_cached = local_snapshot_path is not None
# Get model ID for HuggingFace Hub (used for downloading)
model_id = f"openai/whisper-{model_size}"
# Determine the path to use for loading
# If cached, use local snapshot path directly to avoid any network access
# If not cached, use HuggingFace Hub ID to download
load_path = str(local_snapshot_path) if is_cached else model_id
if is_cached:
print(f"[Whisper] Loading model {model_size} from local cache: {load_path}")
else:
print(f"[Whisper] Model {model_size} not cached, will download from HuggingFace Hub")
setup_huggingface_for_online()
# Set up progress callback and tracker
# If cached: filter out non-download progress
# If not cached: report all progress (we're actually downloading)
progress_callback = create_hf_progress_callback(progress_model_name, progress_manager)
tracker = HFProgressTracker(progress_callback, filter_non_downloads=is_cached)
# Patch tqdm BEFORE importing transformers
print("[DEBUG] Starting tqdm patch BEFORE transformers import")
tracker_context = tracker.patch_download()
tracker_context.__enter__()
print("[DEBUG] tqdm patched, now importing transformers")
# Import transformers
from transformers import WhisperProcessor, WhisperForConditionalGeneration
print(f"[DEBUG] Model name: {model_id}")
print(f"Loading Whisper model {model_size} on {self.device}...")
# Only track download progress if model is NOT cached
if not is_cached:
# Start tracking download task
task_manager.start_download(progress_model_name)
# Initialize progress state so SSE endpoint has initial data to send
progress_manager.update_progress(
model_name=progress_model_name,
current=0,
total=0, # Will be updated once actual total is known
filename="Connecting to HuggingFace...",
status="downloading",
)
# Load models (tqdm is patched, but filters out non-download progress)
try:
# When loading from local path, no need for download kwargs
# When loading from HuggingFace Hub, use download kwargs
if is_cached:
# Load directly from local path - no network access
self.processor = WhisperProcessor.from_pretrained(load_path)
self.model = WhisperForConditionalGeneration.from_pretrained(load_path)
else:
# Load from HuggingFace Hub - will download
download_kwargs = get_model_download_kwargs(progress_model_name, is_cached)
self.processor = WhisperProcessor.from_pretrained(load_path, **download_kwargs)
self.model = WhisperForConditionalGeneration.from_pretrained(load_path, **download_kwargs)
finally:
# Exit the patch context
tracker_context.__exit__(None, None, None)
# Only mark download as complete if we were tracking it
if not is_cached:
progress_manager.mark_complete(progress_model_name)
task_manager.complete_download(progress_model_name)
self.model.to(self.device)
self.model_size = model_size
print(f"Whisper model {model_size} loaded successfully")
except Exception as e:
print(f"Error loading Whisper model: {e}")
progress_manager = get_progress_manager()
task_manager = get_task_manager()
progress_model_name = f"whisper-{model_size}"
progress_manager.mark_error(progress_model_name, str(e))
task_manager.error_download(progress_model_name, str(e))
raise
def unload_model(self):
"""Unload the model to free memory."""
if self.model is not None:
del self.model
del self.processor
self.model = None
self.processor = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("Whisper model unloaded")
async def transcribe(
self,
audio_path: str,
language: Optional[str] = None,
) -> str:
"""
Transcribe audio to text.
Args:
audio_path: Path to audio file
language: Optional language hint (en or zh)
Returns:
Transcribed text
"""
await self.load_model_async(None)
def _transcribe_sync():
"""Run synchronous transcription in thread pool."""
# Load audio
audio, sr = load_audio(audio_path, sample_rate=16000)
# Process audio
inputs = self.processor(
audio,
sampling_rate=16000,
return_tensors="pt",
)
inputs = inputs.to(self.device)
# Set language if provided
forced_decoder_ids = None
if language:
# Support all languages from frontend: en, zh, ja, ko, de, fr, ru, pt, es, it
# Whisper supports these and many more
forced_decoder_ids = self.processor.get_decoder_prompt_ids(
language=language,
task="transcribe",
)
# Generate transcription
with torch.no_grad():
predicted_ids = self.model.generate(
inputs["input_features"],
forced_decoder_ids=forced_decoder_ids,
)
# Decode
transcription = self.processor.batch_decode(
predicted_ids,
skip_special_tokens=True,
)[0]
return transcription.strip()
# Run blocking transcription in thread pool
return await asyncio.to_thread(_transcribe_sync)
6. backend/backends/mlx_backend.py
"""
MLX backend implementation for TTS and STT using mlx-audio.
"""
from typing import Optional, List, Tuple
import asyncio
import numpy as np
from pathlib import Path
from . import TTSBackend, STTBackend
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
from ..utils.audio import normalize_audio, load_audio
from ..utils.progress import get_progress_manager
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
from ..utils.tasks import get_task_manager
from ..utils.hf_config import (
get_model_download_kwargs,
setup_huggingface_for_online,
)
class MLXTTSBackend:
"""MLX-based TTS backend using mlx-audio."""
def __init__(self, model_size: str = "1.7B"):
self.model = None
self.model_size = model_size
self._current_model_size = None
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self.model is not None
def _get_model_path(self, model_size: str) -> str:
"""
Get the MLX model path.
Args:
model_size: Model size (1.7B or 0.6B)
Returns:
HuggingFace Hub model ID for MLX
"""
# MLX model mapping
mlx_model_map = {
"1.7B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16",
# 0.6B not yet converted to MLX format
"0.6B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16", # Fallback to 1.7B
}
if model_size not in mlx_model_map:
raise ValueError(f"Unknown model size: {model_size}")
hf_model_id = mlx_model_map[model_size]
print(f"Resolved MLX model ID on HuggingFace Hub: {hf_model_id}")
return hf_model_id
def _get_local_snapshot_path(self, model_size: str) -> Optional[Path]:
"""
Get the local snapshot path if the model is fully cached.
Args:
model_size: Model size to check
Returns:
Path to local snapshot if fully cached, None otherwise
"""
try:
from huggingface_hub import constants as hf_constants
model_id = self._get_model_path(model_size)
repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--"))
if not repo_cache.exists():
return None
# Check for .incomplete files - if any exist, download is still in progress
blobs_dir = repo_cache / "blobs"
if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
return None
# Check that actual model weight files exist in snapshots
snapshots_dir = repo_cache / "snapshots"
if not snapshots_dir.exists():
return None
# Get the latest snapshot (by modification time)
snapshot_dirs = sorted(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)
if not snapshot_dirs:
return None
latest_snapshot = snapshot_dirs[0]
# Check for model weight entries (rglob also matches symlinks that point into blobs/)
has_weights = (
any(latest_snapshot.rglob("*.safetensors")) or
any(latest_snapshot.rglob("*.bin")) or
any(latest_snapshot.rglob("*.npz"))
)
if not has_weights:
return None
# Check for config.json
if not (latest_snapshot / "config.json").exists():
return None
return latest_snapshot
except Exception as e:
print(f"[_get_local_snapshot_path] Error: {e}")
return None
def _is_model_cached(self, model_size: str) -> bool:
"""
Check if the model is already cached locally AND fully downloaded.
Args:
model_size: Model size to check
Returns:
True if model is fully cached, False if missing or incomplete
"""
local_path = self._get_local_snapshot_path(model_size)
if local_path:
print(f"[_is_model_cached] Model {model_size} is fully cached at {local_path}")
else:
print(f"[_is_model_cached] Model {model_size} is not cached")
return local_path is not None
async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the MLX TTS model.
Args:
model_size: Model size to load (1.7B or 0.6B)
"""
if model_size is None:
model_size = self.model_size
# If already loaded with correct size, return
if self.model is not None and self._current_model_size == model_size:
return
# Unload existing model if different size requested
if self.model is not None and self._current_model_size != model_size:
self.unload_model()
# Run blocking load in thread pool
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility
load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        # Resolve these before the try block so the error handlers can reuse them
        progress_manager = get_progress_manager()
        task_manager = get_task_manager()
        model_name = f"qwen-tts-{model_size}"
        try:
            # Local snapshot path, or None if the model is not fully cached
            local_snapshot_path = self._get_local_snapshot_path(model_size)
            is_cached = local_snapshot_path is not None

            # HuggingFace Hub model ID (only needed when downloading)
            model_id = self._get_model_path(model_size)

            # If cached, use the local snapshot path directly to avoid any network access.
            # If not cached, use the HuggingFace Hub ID so the model is downloaded.
            load_path = str(local_snapshot_path) if is_cached else model_id
            if is_cached:
                print(f"[MLX TTS] Loading model {model_size} from local cache: {load_path}")
            else:
                print(f"[MLX TTS] Model {model_size} not cached, will download from HuggingFace Hub")
                setup_huggingface_for_online()

            # If cached: filter out non-download progress
            # If not cached: report all progress (we're actually downloading)
            progress_callback = create_hf_progress_callback(model_name, progress_manager)
            tracker = HFProgressTracker(progress_callback, filter_non_downloads=is_cached)

            print(f"Loading MLX TTS model {model_size}...")

            # Only track download progress if model is NOT cached
            if not is_cached:
                task_manager.start_download(model_name)
                # Initialize progress state so the SSE endpoint has initial data to send
                # while HuggingFace fetches metadata
                progress_manager.update_progress(
                    model_name=model_name,
                    current=0,
                    total=0,  # Will be updated once actual total is known
                    filename="Connecting to HuggingFace...",
                    status="downloading",
                )

            # IMPORTANT: Patch tqdm BEFORE importing mlx_audio,
            # otherwise mlx_audio caches a reference to the original tqdm
            tracker_context = tracker.patch_download()
            tracker_context.__enter__()
            try:
                # Import inside the try so a failed import still undoes the patch
                from mlx_audio.tts import load

                # Same call in both cases: load_path is a local snapshot
                # directory when cached, a Hub ID otherwise
                self.model = load(load_path)
            finally:
                # Exit the patch context
                tracker_context.__exit__(None, None, None)

            # Only mark download as complete if we were tracking it
            if not is_cached:
                progress_manager.mark_complete(model_name)
                task_manager.complete_download(model_name)

            self._current_model_size = model_size
            self.model_size = model_size
            print(f"MLX TTS model {model_size} loaded successfully")
        except ImportError as e:
            print("Error: mlx_audio package not found. Install with: pip install mlx-audio")
            progress_manager.mark_error(model_name, str(e))
            task_manager.error_download(model_name, str(e))
            raise
        except Exception as e:
            print(f"Error loading MLX TTS model: {e}")
            progress_manager.mark_error(model_name, str(e))
            task_manager.error_download(model_name, str(e))
            raise

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._current_model_size = None
            print("MLX TTS model unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        MLX backend stores voice prompt as a dict with audio path and text.
        The actual voice prompt processing happens during generation.

        Args:
            audio_path: Path to reference audio file
            reference_text: Transcript of reference audio
            use_cache: Whether to use cached prompt if available

        Returns:
            Tuple of (voice_prompt_dict, was_cached)
        """
        await self.load_model_async(None)

        # Compute the cache key once; it is reused for both lookup and store
        cache_key = get_cache_key(audio_path, reference_text) if use_cache else None

        # Check cache if enabled
        if use_cache:
            cached_prompt = get_cached_voice_prompt(cache_key)
            # A cached prompt is only usable if it is in dict format
            if cached_prompt is not None and isinstance(cached_prompt, dict):
                # Validate that the cached audio file still exists
                cached_audio_path = cached_prompt.get("ref_audio") or cached_prompt.get("ref_audio_path")
                if cached_audio_path and Path(cached_audio_path).exists():
                    return cached_prompt, True
                # Cached file no longer exists: fall through and rebuild the prompt
                print(f"Cached audio file not found: {cached_audio_path}, regenerating prompt")

        # MLX voice prompt format: store audio path and text.
        # The model will process this during generation.
        voice_prompt_items = {
            "ref_audio": str(audio_path),
            "ref_text": reference_text,
        }

        # Cache if enabled
        if use_cache:
            cache_voice_prompt(cache_key, voice_prompt_items)

        return voice_prompt_items, False

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        """
        Combine multiple reference samples for better quality.

        Args:
            audio_paths: List of audio file paths
            reference_texts: List of reference texts

        Returns:
            Tuple of (combined_audio, combined_text)
        """
        combined_audio = []
        for audio_path in audio_paths:
            audio, _sr = load_audio(audio_path)
            audio = normalize_audio(audio)
            combined_audio.append(audio)

        # Concatenate audio
        mixed = np.concatenate(combined_audio)
        mixed = normalize_audio(mixed)

        # Combine texts
        combined_text = " ".join(reference_texts)
        return mixed, combined_text

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using voice prompt.

        Args:
            text: Text to synthesize
            voice_prompt: Voice prompt dictionary with ref_audio and ref_text
            language: Language code (en or zh) - may not be fully supported by MLX
            seed: Random seed for reproducibility
            instruct: Natural language instruction (may not be supported by MLX)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model_async(None)
        print(f"Generating audio for text: {text}")

        def _generate_sync():
            """Run synchronous generation in thread pool."""
            # Seed both the NumPy and the MLX random number generators
            if seed is not None:
                import mlx.core as mx
                np.random.seed(seed)
                mx.random.seed(seed)

            # Extract voice prompt info
            ref_audio = voice_prompt.get("ref_audio") or voice_prompt.get("ref_audio_path")
            ref_text = voice_prompt.get("ref_text", "")

            # Validate that the audio file exists
            if ref_audio and not Path(ref_audio).exists():
                print(f"Warning: Audio file not found: {ref_audio}")
                print("This may be due to a cached voice prompt referencing a deleted temp file.")
                print("Regenerating without voice prompt.")
                ref_audio = None

            def _collect(**kwargs):
                """Drain model.generate() (a generator of result objects) into chunks."""
                chunks, rate = [], 24000
                for result in self.model.generate(text, **kwargs):
                    chunks.append(np.array(result.audio))
                    rate = result.sample_rate
                return chunks, rate

            try:
                # Pass ref_audio/ref_text only if this model's generate() accepts them
                clone_kwargs = {}
                if ref_audio:
                    import inspect
                    if "ref_audio" in inspect.signature(self.model.generate).parameters:
                        clone_kwargs = {"ref_audio": ref_audio, "ref_text": ref_text}
                audio_chunks, sample_rate = _collect(**clone_kwargs)
            except Exception as e:
                # If voice cloning fails, retry without it. _collect() starts from an
                # empty list, so partial chunks from the failed attempt are discarded.
                print(f"Warning: Voice cloning failed, generating without voice prompt: {e}")
                audio_chunks, sample_rate = _collect()

            # Concatenate all chunks
            if audio_chunks:
                audio = np.concatenate([np.asarray(chunk, dtype=np.float32) for chunk in audio_chunks])
            else:
                # Fallback: empty audio
                audio = np.array([], dtype=np.float32)
            return audio, sample_rate

        # Run blocking inference in thread pool
        audio, sample_rate = await asyncio.to_thread(_generate_sync)
        return audio, sample_rate


class MLXSTTBackend:
    """MLX-based STT backend using mlx-audio Whisper."""

    def __init__(self, model_size: str = "base"):
        self.model = None
        self.model_size = model_size

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    def _get_local_snapshot_path(self, model_size: str) -> Optional[Path]:
        """
        Get local snapshot path if Whisper model is fully cached.

        Args:
            model_size: Model size to check

        Returns:
            Path to local snapshot if fully cached, None otherwise
        """
        try:
            from huggingface_hub import constants as hf_constants

            model_id = f"openai/whisper-{model_size}"
            repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--"))
            if not repo_cache.exists():
                return None

            # Any .incomplete blob means a download is unfinished or was interrupted
            blobs_dir = repo_cache / "blobs"
            if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
                return None

            snapshots_dir = repo_cache / "snapshots"
            if not snapshots_dir.exists():
                return None

            # Get the latest snapshot (by modification time)
            snapshot_dirs = sorted(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime, reverse=True)
            if not snapshot_dirs:
                return None
            latest_snapshot = snapshot_dirs[0]

            # Check that at least one weight file is present in the snapshot
            has_weights = (
                any(latest_snapshot.rglob("*.safetensors")) or
                any(latest_snapshot.rglob("*.bin")) or
                any(latest_snapshot.rglob("*.npz"))
            )
            if not has_weights:
                return None

            # Check for config.json
            if not (latest_snapshot / "config.json").exists():
                return None

            return latest_snapshot
        except Exception as e:
            print(f"[_get_local_snapshot_path] Error: {e}")
            return None

    def _is_model_cached(self, model_size: str) -> bool:
        """
        Check if the Whisper model is already cached locally AND fully downloaded.

        Args:
            model_size: Model size to check

        Returns:
            True if model is fully cached, False if missing or incomplete
        """
        local_path = self._get_local_snapshot_path(model_size)
        if local_path:
            print(f"[_is_model_cached] Whisper model {model_size} is fully cached at {local_path}")
        else:
            print(f"[_is_model_cached] Whisper model {model_size} is not cached")
        return local_path is not None

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the MLX Whisper model.

        Args:
            model_size: Model size (tiny, base, small, medium, large)
        """
        if model_size is None:
            model_size = self.model_size

        if self.model is not None and self.model_size == model_size:
            return

        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility
    load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        # Resolve these before the try block so the error handlers can reuse them
        progress_manager = get_progress_manager()
        task_manager = get_task_manager()
        progress_model_name = f"whisper-{model_size}"
        try:
            # Local snapshot path, or None if the model is not fully cached
            local_snapshot_path = self._get_local_snapshot_path(model_size)
            is_cached = local_snapshot_path is not None

            # HuggingFace Hub model ID (only needed when downloading)
            model_id = f"openai/whisper-{model_size}"

            # If cached, use the local snapshot path directly to avoid any network access.
            # If not cached, use the HuggingFace Hub ID so the model is downloaded.
            load_path = str(local_snapshot_path) if is_cached else model_id
            if is_cached:
                print(f"[MLX Whisper] Loading model {model_size} from local cache: {load_path}")
            else:
                print(f"[MLX Whisper] Model {model_size} not cached, will download from HuggingFace Hub")
                setup_huggingface_for_online()

            # If cached: filter out non-download progress
            # If not cached: report all progress (we're actually downloading)
            progress_callback = create_hf_progress_callback(progress_model_name, progress_manager)
            tracker = HFProgressTracker(progress_callback, filter_non_downloads=is_cached)

            print(f"Loading MLX Whisper model {model_size}...")

            # Only track download progress if model is NOT cached
            if not is_cached:
                task_manager.start_download(progress_model_name)
                # Initialize progress state so the SSE endpoint has initial data to send
                progress_manager.update_progress(
                    model_name=progress_model_name,
                    current=0,
                    total=0,
                    filename="Connecting to HuggingFace...",
                    status="downloading",
                )

            # Patch tqdm BEFORE importing mlx_audio
            tracker_context = tracker.patch_download()
            tracker_context.__enter__()
            try:
                # Import inside the try so a failed import still undoes the patch
                from mlx_audio.stt import load

                # Same call in both cases: load_path is a local snapshot
                # directory when cached, a Hub ID otherwise
                self.model = load(load_path)
            finally:
                # Exit the patch context
                tracker_context.__exit__(None, None, None)

            # Only mark download as complete if we were tracking it
            if not is_cached:
                progress_manager.mark_complete(progress_model_name)
                task_manager.complete_download(progress_model_name)

            self.model_size = model_size
            print(f"MLX Whisper model {model_size} loaded successfully")
        except ImportError as e:
            print("Error: mlx_audio package not found. Install with: pip install mlx-audio")
            progress_manager.mark_error(progress_model_name, str(e))
            task_manager.error_download(progress_model_name, str(e))
            raise
        except Exception as e:
            print(f"Error loading MLX Whisper model: {e}")
            progress_manager.mark_error(progress_model_name, str(e))
            task_manager.error_download(progress_model_name, str(e))
            raise

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            print("MLX Whisper model unloaded")

    async def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
    ) -> str:
        """
        Transcribe audio to text.

        Args:
            audio_path: Path to audio file
            language: Optional language hint (en or zh)

        Returns:
            Transcribed text
        """
        await self.load_model_async(None)

        def _transcribe_sync():
            """Run synchronous transcription in thread pool."""
            # MLX Whisper transcription using generate method
            # The generate method accepts audio path directly
            decode_options = {}
            if language:
                decode_options["language"] = language
            result = self.model.generate(str(audio_path), **decode_options)

            # Extract text from result
            if isinstance(result, str):
                return result.strip()
            elif isinstance(result, dict):
                return result.get("text", "").strip()
            elif hasattr(result, "text"):
                return result.text.strip()
            else:
                return str(result).strip()

        # Run blocking transcription in thread pool
        return await asyncio.to_thread(_transcribe_sync)
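
Because the backend above could not be run without Apple hardware, it may help to exercise the cache-completeness check on its own. The following is a minimal sketch, not the project's code: `find_complete_snapshot` and `make_fake_cache` are hypothetical names introduced here, and the fake directory tree only imitates the `blobs/` and `snapshots/<revision>/` layout that the listing relies on. It needs nothing beyond the Python standard library, so it should run on Windows as well.

```python
import tempfile
from pathlib import Path
from typing import Optional

WEIGHT_PATTERNS = ("*.safetensors", "*.bin", "*.npz")


def find_complete_snapshot(repo_cache: Path) -> Optional[Path]:
    """Return the newest snapshot directory if it looks complete, else None.

    Hypothetical standalone helper that mirrors the checks in the listing above.
    """
    if not repo_cache.exists():
        return None
    # A leftover *.incomplete blob means a download is unfinished or was interrupted
    blobs_dir = repo_cache / "blobs"
    if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
        return None
    snapshots_dir = repo_cache / "snapshots"
    if not snapshots_dir.exists():
        return None
    # Newest snapshot first, judged by modification time
    snapshots = sorted(
        (p for p in snapshots_dir.iterdir() if p.is_dir()),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    if not snapshots:
        return None
    latest = snapshots[0]
    has_weights = any(any(latest.rglob(pattern)) for pattern in WEIGHT_PATTERNS)
    if not has_weights or not (latest / "config.json").exists():
        return None
    return latest


def make_fake_cache(root: Path, revision: str = "abc123") -> Path:
    """Build a throwaway tree shaped like one hub cache entry (test data only)."""
    repo = root / "models--openai--whisper-base"
    snapshot = repo / "snapshots" / revision
    snapshot.mkdir(parents=True)
    (repo / "blobs").mkdir()
    (snapshot / "config.json").write_text("{}")
    (snapshot / "model.safetensors").write_bytes(b"\0")
    return repo


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        repo = make_fake_cache(Path(tmp))
        print(find_complete_snapshot(repo))  # path ending in abc123
        (repo / "blobs" / "deadbeef.incomplete").touch()
        print(find_complete_snapshot(repo))  # None
```

Returning the snapshot path rather than a boolean is what lets the loader hand a plain directory to `load()` and skip the Hub lookup entirely when the model is already on disk.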
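III. Notes
1. The code in mlx_backend.py has not been tested, because I do not have a Mac.
2. When developing and releasing an application, network access that the user cannot control makes users uneasy. It should be considered from the user's point of view and handled with care.
3. The code adds support for the https://hf-mirror.com mirror.
4. The code disables SSL verification. This was a deliberate choice to work around SSL verification failures during model download, but it is a security risk, and readers who reuse the code above should keep this in mind.
5. The Voicebox project ships prebuilt executables for several operating systems. I downloaded the Windows build, Voicebox_0.1.13_x64-setup.exe. After installing and running it, SSL certificate verification errors prevented the downloads of Qwen/Qwen3-TTS-12Hz-1.7B-Base and Qwen/Qwen3-TTS-12Hz-0.6B-Base from ever completing, so the prebuilt program could not be used normally. That is what prompted this article.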
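
The mirror and SSL points in the notes come down to a small decision: go offline when the model is cached, use the mirror only when a download is really needed, and treat disabled certificate checks as something the user has to ask for. A minimal sketch of that decision follows. `hf_env_for`, `insecure_ssl_requested` and the `VOICEBOX_INSECURE_SSL` variable are hypothetical names introduced for illustration; only `HF_HUB_OFFLINE`, `TRANSFORMERS_OFFLINE` and `HF_ENDPOINT` are the variables already used earlier in this article. As far as I know, huggingface_hub reads these variables when it is imported, so they would need to be applied before that import.

```python
import os
from typing import Dict, Mapping

MIRROR_ENDPOINT = "https://hf-mirror.com"


def hf_env_for(cached: bool, use_mirror: bool = True) -> Dict[str, str]:
    """Return the environment variables to set for the current load (hypothetical helper)."""
    if cached:
        # Everything is on disk: forbid network access outright
        return {"HF_HUB_OFFLINE": "1", "TRANSFORMERS_OFFLINE": "1"}
    # A download is needed: optionally route it through the mirror
    return {"HF_ENDPOINT": MIRROR_ENDPOINT} if use_mirror else {}


def insecure_ssl_requested(environ: Mapping[str, str]) -> bool:
    """True only if the user explicitly opted in via a (hypothetical) variable."""
    return environ.get("VOICEBOX_INSECURE_SSL", "0") == "1"


if __name__ == "__main__":
    # Apply the settings before importing huggingface_hub or transformers
    os.environ.update(hf_env_for(cached=True))
    print(os.environ["HF_HUB_OFFLINE"])
    print(insecure_ssl_requested(os.environ))
```

Keeping the SSL bypass behind an explicit opt-in means a machine with a working certificate chain never runs with verification turned off.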