feat: 支持小米语音模型
This commit is contained in:
parent
2c171a6269
commit
376e635fbe
@ -138,37 +138,53 @@ def _load_json_field(raw: object) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
def load_tts_settings(conn, from_wx_id: str) -> tuple[bool, dict]:
|
||||
global_row = _query_one(conn, "SELECT tts_enabled, tts_settings FROM global_settings LIMIT 1")
|
||||
def load_tts_settings(conn, from_wx_id: str) -> tuple[bool, str, dict, str, str]:
|
||||
global_row = _query_one(
|
||||
conn,
|
||||
"SELECT tts_enabled, tts_model, tts_settings, chat_base_url, chat_api_key FROM global_settings LIMIT 1",
|
||||
)
|
||||
enabled = False
|
||||
tts_model: str = "doubao"
|
||||
settings_json: dict = {}
|
||||
fallback_base_url: str = ""
|
||||
fallback_api_key: str = ""
|
||||
|
||||
if global_row:
|
||||
if global_row.get("tts_enabled") is not None:
|
||||
enabled = bool(global_row["tts_enabled"])
|
||||
if global_row.get("tts_model"):
|
||||
tts_model = str(global_row["tts_model"]).strip() or "doubao"
|
||||
settings_json = _load_json_field(global_row.get("tts_settings"))
|
||||
fallback_base_url = str(global_row.get("chat_base_url") or "").strip()
|
||||
fallback_api_key = str(global_row.get("chat_api_key") or "").strip()
|
||||
|
||||
if from_wx_id.endswith("@chatroom"):
|
||||
override = _query_one(
|
||||
conn,
|
||||
"SELECT tts_enabled, tts_settings FROM chat_room_settings WHERE chat_room_id = %s LIMIT 1",
|
||||
"SELECT tts_enabled, tts_model, tts_settings, chat_base_url, chat_api_key FROM chat_room_settings WHERE chat_room_id = %s LIMIT 1",
|
||||
(from_wx_id,),
|
||||
)
|
||||
else:
|
||||
override = _query_one(
|
||||
conn,
|
||||
"SELECT tts_enabled, tts_settings FROM friend_settings WHERE wechat_id = %s LIMIT 1",
|
||||
"SELECT tts_enabled, tts_model, tts_settings, chat_base_url, chat_api_key FROM friend_settings WHERE wechat_id = %s LIMIT 1",
|
||||
(from_wx_id,),
|
||||
)
|
||||
|
||||
if override:
|
||||
if override.get("tts_enabled") is not None:
|
||||
enabled = bool(override["tts_enabled"])
|
||||
if override.get("tts_model"):
|
||||
tts_model = str(override["tts_model"]).strip() or tts_model
|
||||
override_settings = _load_json_field(override.get("tts_settings"))
|
||||
if override_settings:
|
||||
settings_json = override_settings
|
||||
if str(override.get("chat_base_url") or "").strip():
|
||||
fallback_base_url = str(override["chat_base_url"]).strip()
|
||||
if str(override.get("chat_api_key") or "").strip():
|
||||
fallback_api_key = str(override["chat_api_key"]).strip()
|
||||
|
||||
return enabled, settings_json
|
||||
return enabled, tts_model, settings_json, fallback_base_url, fallback_api_key
|
||||
|
||||
|
||||
def _normalize_emotion(emotion: str) -> str:
|
||||
@ -322,6 +338,98 @@ def synthesize_audio(config: dict, content: str, emotion: str, context_texts: li
|
||||
return bytes(audio_chunks), audio_format
|
||||
|
||||
|
||||
def _pcm16le_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1) -> bytes:
|
||||
import struct
|
||||
|
||||
data_size = len(pcm_data)
|
||||
byte_rate = sample_rate * channels * 2
|
||||
block_align = channels * 2
|
||||
header = struct.pack(
|
||||
"<4sI4s4sIHHIIHH4sI",
|
||||
b"RIFF",
|
||||
36 + data_size,
|
||||
b"WAVE",
|
||||
b"fmt ",
|
||||
16,
|
||||
1,
|
||||
channels,
|
||||
sample_rate,
|
||||
byte_rate,
|
||||
block_align,
|
||||
16,
|
||||
b"data",
|
||||
data_size,
|
||||
)
|
||||
return header + pcm_data
|
||||
|
||||
|
||||
def synthesize_audio_mimo(config: dict, content: str, voice: str) -> tuple[bytes, str]:
|
||||
api_key = str(config.get("api_key") or "").strip()
|
||||
base_url = str(config.get("base_url") or "https://api.xiaomimimo.com/v1").strip().rstrip("/")
|
||||
model = str(config.get("model") or "mimo-v2.5-tts").strip()
|
||||
if not voice:
|
||||
voice = str(config.get("voice") or "mimo_default").strip()
|
||||
if not api_key:
|
||||
raise RuntimeError("mimo api_key 不能为空")
|
||||
|
||||
url = f"{base_url}/chat/completions"
|
||||
payload = json.dumps({
|
||||
"model": model,
|
||||
"messages": [{"role": "assistant", "content": content}],
|
||||
"audio": {"format": "pcm16", "voice": voice},
|
||||
"stream": True,
|
||||
}).encode("utf-8")
|
||||
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
data=payload,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"api-key": api_key,
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
pcm_chunks = bytearray()
|
||||
try:
|
||||
response = urllib.request.urlopen(req, timeout=300)
|
||||
except urllib.error.HTTPError as exc:
|
||||
error_body = exc.read().decode("utf-8", errors="replace")
|
||||
raise RuntimeError(f"mimo API请求失败,状态码 {exc.code}: {error_body}") from exc
|
||||
except urllib.error.URLError as exc:
|
||||
raise RuntimeError(f"mimo 发送请求失败: {exc}") from exc
|
||||
|
||||
with response:
|
||||
for raw_line in response:
|
||||
line = raw_line.decode("utf-8", errors="replace").strip()
|
||||
if not line or not line.startswith("data:"):
|
||||
continue
|
||||
data_str = line[5:].strip()
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
choices = chunk.get("choices") or []
|
||||
if not choices:
|
||||
continue
|
||||
delta = choices[0].get("delta") or {}
|
||||
audio = delta.get("audio") or {}
|
||||
audio_data_b64 = audio.get("data") if isinstance(audio, dict) else None
|
||||
if audio_data_b64:
|
||||
try:
|
||||
pcm_chunks.extend(base64.b64decode(audio_data_b64))
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"解码 mimo 音频数据失败: {exc}") from exc
|
||||
|
||||
if not pcm_chunks:
|
||||
raise RuntimeError("mimo 未接收到音频数据")
|
||||
|
||||
wav_data = _pcm16le_to_wav(bytes(pcm_chunks))
|
||||
return wav_data, "wav"
|
||||
|
||||
|
||||
def _guess_mime_type(audio_format: str) -> str:
|
||||
fmt = audio_format.lower()
|
||||
if fmt == "mp3":
|
||||
@ -445,7 +553,7 @@ def main() -> int:
|
||||
return 1
|
||||
|
||||
try:
|
||||
enabled, tts_settings = load_tts_settings(conn, from_wx_id)
|
||||
enabled, tts_model, tts_settings, fallback_base_url, fallback_api_key = load_tts_settings(conn, from_wx_id)
|
||||
except Exception as exc:
|
||||
sys.stdout.write(f"加载文本转语音配置失败: {exc}\n")
|
||||
return 1
|
||||
@ -463,8 +571,25 @@ def main() -> int:
|
||||
sys.stdout.write("未找到文本转语音配置\n")
|
||||
return 1
|
||||
|
||||
model_config = tts_settings.get(tts_model)
|
||||
if not isinstance(model_config, dict) or not model_config:
|
||||
sys.stdout.write(f"未找到 {tts_model} 的文本转语音配置\n")
|
||||
return 1
|
||||
|
||||
try:
|
||||
audio_data, audio_format = synthesize_audio(tts_settings, content, emotion, context_texts)
|
||||
if tts_model == "doubao":
|
||||
audio_data, audio_format = synthesize_audio(model_config, content, emotion, context_texts)
|
||||
elif tts_model == "mimo":
|
||||
if not str(model_config.get("api_key") or "").strip() and fallback_api_key:
|
||||
model_config = dict(model_config)
|
||||
model_config["api_key"] = fallback_api_key
|
||||
if not str(model_config.get("base_url") or "").strip() and fallback_base_url:
|
||||
model_config = dict(model_config)
|
||||
model_config["base_url"] = fallback_base_url
|
||||
audio_data, audio_format = synthesize_audio_mimo(model_config, content, "")
|
||||
else:
|
||||
sys.stdout.write(f"未知的 TTS 模型: {tts_model}\n")
|
||||
return 1
|
||||
except Exception as exc:
|
||||
sys.stdout.write(f"语音合成失败: {exc}\n")
|
||||
return 1
|
||||
|
||||
Loading…
Reference in New Issue
Block a user