198 lines
No EOL
6.7 KiB
Python
198 lines
No EOL
6.7 KiB
Python
from __future__ import annotations

import json as _json
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Union
from urllib.parse import quote

import aiohttp
import requests
|
|
|
|
|
|
# --------- модели данных ---------
|
|
|
|
@dataclass(frozen=True)
class SynthesizeParams:
    """Request fields for ``POST /synthesize``.

    Instances are immutable (``frozen=True``). The field order below defines
    the positional constructor signature. The clients serialize every field
    into the JSON payload, including fields left as ``None``.
    """

    # Text to synthesize.
    text: str
    # Model name; presumably one of the names returned by ``list_models``.
    model: str
    # Voice id; presumably one of the ids returned by ``list_voices`` -- confirm.
    speaker_id: Optional[int] = None
    # NOTE(review): unit not visible here (sample rate in Hz?) -- confirm with the server.
    rate: Optional[int] = None
    # Synthesis tuning knobs, forwarded to the server unchanged; their meaning
    # and valid ranges are defined server-side.
    noise_level: Optional[float] = None
    speech_rate: Optional[float] = None
    duration_noise_level: Optional[float] = None
    scale: Optional[float] = None
    # Output container flag: the clients' ``synthesize_wav`` sets it to True and
    # ``synthesize_pcm`` sets it to False (presumably WAV vs. raw PCM bytes).
    as_wav: bool = False
|
|
|
|
|
|
# --------- ошибки ---------
|
|
|
|
class TTSApiError(RuntimeError):
    """Raised by the TTS clients when the service answers with a non-2xx status.

    The clients build the message as ``"<status code>: <detail>"``. *detail* is
    the server's ``"detail"`` field when the body is a JSON object carrying
    one. Otherwise it is the raw body text, or ``"unknown error"`` if that is
    empty.
    """
|
|
|
|
|
|
# --------- утилиты ---------
|
|
|
|
def _safe_json_sync(r: requests.Response) -> Any:
    """Decode a successful ``requests`` response body as JSON.

    Returns:
        The decoded JSON value, or ``{"text": <raw body text>}`` when the body
        is not valid JSON (for example an empty or plain-text 2xx response).

    Only JSON decoding failures trigger the fallback. ``requests`` reports
    those as ``ValueError`` subclasses (``requests.exceptions.JSONDecodeError``
    in current versions). Any other exception indicates a real bug and is
    allowed to propagate instead of being disguised as a text body.
    """
    try:
        return r.json()
    except ValueError:
        # Not JSON, but the status was 2xx: hand back the raw text.
        return {"text": r.text}
|
|
|
|
def _extract_error_detail_sync(r: requests.Response) -> str:
    """Build a readable error message from a failed ``requests`` response.

    Returns the ``"detail"`` value (converted with ``str``) when the body is a
    JSON object containing that key. Otherwise returns the raw body text, or
    ``"unknown error"`` when the body is empty.

    Parsing failures are swallowed on purpose. This runs on the error path,
    and raising here would replace the HTTP error the caller is reporting.
    """
    detail: Optional[str] = None
    try:
        parsed = r.json()
        if isinstance(parsed, dict) and "detail" in parsed:
            detail = str(parsed["detail"])
    except Exception:
        # Body is not JSON (HTML error page, plain text, empty): use raw text.
        pass
    if detail is not None:
        return detail
    return r.text or "unknown error"
|
|
|
|
async def _safe_json_async(r: aiohttp.ClientResponse) -> Any:
    """Decode a successful ``aiohttp`` response body as JSON.

    Returns:
        The decoded JSON value, or ``{"text": <raw body text>}`` when the body
        cannot be decoded as JSON.

    The fallback is triggered only by the "not JSON" errors that ``aiohttp``
    raises:

    * ``ContentTypeError`` when the Content-Type is not JSON.
    * ``ValueError`` subclasses (``json.JSONDecodeError``,
      ``UnicodeDecodeError``) for a malformed body.

    The fallback text is decoded with ``errors="replace"``, so undecodable
    bytes cannot make this helper raise.
    """
    try:
        return await r.json()
    except ValueError:
        # Malformed JSON, or bytes that do not decode in the response charset.
        pass
    except aiohttp.ContentTypeError:
        # Server did not label the body as JSON.
        pass
    return {"text": await r.text(errors="replace")}
|
|
|
|
def _extract_error_detail_from_text(text: str) -> str:
    """Build a readable error message from an already-read response body.

    Returns the ``"detail"`` value (converted with ``str``) when *text* is a
    JSON object containing that key. Otherwise returns *text* itself, or
    ``"unknown error"`` when *text* is empty.

    Parsing failures are swallowed on purpose. This runs on the error path,
    and raising here would replace the HTTP error the caller is reporting.
    """
    detail: Optional[str] = None
    try:
        parsed = _json.loads(text)
        if isinstance(parsed, dict) and "detail" in parsed:
            detail = str(parsed["detail"])
    except Exception:
        # Not JSON (plain text, HTML error page, empty body): keep the text.
        pass
    if detail is None:
        detail = text or "unknown error"
    return detail
|
|
|
|
# --------- синхронный клиент ---------
|
|
|
|
class TTSClient:
    """Synchronous client for the TTS HTTP API.

    Every call goes through ``_request``; ``_get`` and ``_post`` are thin
    wrappers around it. No session is kept: each request is a standalone
    ``requests.request`` call, so connections are not reused between calls.
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = 30.0,
        default_headers: Optional[Dict[str, str]] = None,
        api_key: Optional[str] = None,
    ):
        """Create a client.

        Args:
            base_url: Root URL of the service. A trailing ``/`` is removed so
                paths such as ``"/models"`` can be appended directly.
            timeout: Timeout in seconds passed to ``requests`` for each call.
            default_headers: Headers sent with every request. The mapping is
                copied, so the caller's dict is never modified.
            api_key: When truthy, sent as the ``X-API-Key`` header.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Copy so the X-API-Key assignment below never writes into the caller's dict.
        self.default_headers = dict(default_headers or {})
        if api_key:
            self.default_headers["X-API-Key"] = api_key

    # ---- public API ----

    def list_models(self) -> List[str]:
        """Return the list stored under ``"models"`` in the ``GET /models`` response."""
        data = self._get("/models")
        return data["models"]

    def list_voices(self, model: str) -> List[int]:
        """Return the list stored under ``"voices"`` in ``GET /models/<model>/voices``.

        *model* is percent-encoded as a single path segment, so names that
        contain ``/``, ``?``, ``#``, ``%`` or spaces cannot alter the URL.
        """
        data = self._get(f"/models/{quote(model, safe='')}/voices")
        return data["voices"]

    def synthesize(self, params: Union[SynthesizeParams, dict]) -> bytes:
        """POST *params* to ``/synthesize`` and return the raw response body.

        Args:
            params: A ``SynthesizeParams`` instance or a dict with the same
                keys. Every field is sent, including those left as ``None``
                (serialized as JSON ``null``).

        Raises:
            TTSApiError: on a non-2xx response.
        """
        # asdict() builds a fresh dict instead of exposing the instance's __dict__.
        payload = asdict(params) if isinstance(params, SynthesizeParams) else dict(params)
        return self._post("/synthesize", json=payload, return_bytes=True)

    def synthesize_pcm(self, **kwargs) -> bytes:
        """Call ``synthesize`` with *kwargs* as the payload and ``as_wav`` forced to False."""
        kwargs["as_wav"] = False
        return self.synthesize(kwargs)

    def synthesize_wav(self, **kwargs) -> bytes:
        """Call ``synthesize`` with *kwargs* as the payload and ``as_wav`` forced to True."""
        kwargs["as_wav"] = True
        return self.synthesize(kwargs)

    # ---- single entry points ----

    def _get(self, path: str, *, headers: Optional[Dict[str, str]] = None) -> Any:
        """Send a GET via ``_request`` and return the decoded JSON body."""
        return self._request("GET", path, headers=headers)

    def _post(
        self,
        path: str,
        *,
        json: Optional[dict] = None,
        headers: Optional[Dict[str, str]] = None,
        return_bytes: bool = False,
    ) -> Any:
        """Send a POST via ``_request``; see ``_request`` for the arguments."""
        return self._request("POST", path, json=json, headers=headers, return_bytes=return_bytes)

    def _request(
        self,
        method: str,
        path: str,
        *,
        json: Optional[dict] = None,
        headers: Optional[Dict[str, str]] = None,
        return_bytes: bool = False,
    ) -> Any:
        """Send one HTTP request and decode its response.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            path: Path appended to ``base_url``; expected to start with ``/``.
            json: Optional request body, serialized to JSON by ``requests``.
            headers: Extra headers; they override ``default_headers`` on conflict.
            return_bytes: Return the raw body bytes instead of decoded JSON.

        Returns:
            The raw body if *return_bytes* is true, otherwise the decoded JSON
            (or ``{"text": ...}`` for a non-JSON 2xx body).

        Raises:
            TTSApiError: for any non-2xx status, as ``"<status>: <detail>"``.
            requests.RequestException: transport errors are not wrapped.
        """
        url = f"{self.base_url}{path}"
        hdrs = {**self.default_headers, **(headers or {})}

        r = requests.request(method=method, url=url, json=json, headers=hdrs, timeout=self.timeout)

        if 200 <= r.status_code < 300:
            return r.content if return_bytes else _safe_json_sync(r)

        # Error response.
        detail = _extract_error_detail_sync(r)
        raise TTSApiError(f"{r.status_code}: {detail}")
|
|
|
|
|
|
# --------- асинхронный клиент ---------
|
|
|
|
class TTSAioClient:
    """Asynchronous (``aiohttp``) client for the TTS HTTP API.

    Every call goes through ``_request``; ``_get`` and ``_post`` are thin
    wrappers around it. A new ``aiohttp.ClientSession`` is opened and closed
    for each request, so connections are not reused between calls.
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = 30.0,
        default_headers: Optional[Dict[str, str]] = None,
        api_key: Optional[str] = None,
    ):
        """Create a client.

        Args:
            base_url: Root URL of the service. A trailing ``/`` is removed so
                paths such as ``"/models"`` can be appended directly.
            timeout: Total per-request timeout in seconds (``aiohttp.ClientTimeout``).
            default_headers: Headers sent with every request. The mapping is
                copied, so the caller's dict is never modified.
            api_key: When truthy, sent as the ``X-API-Key`` header.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Copy so the X-API-Key assignment below never writes into the caller's dict.
        self.default_headers = dict(default_headers or {})
        if api_key:
            self.default_headers["X-API-Key"] = api_key

    # ---- public API ----

    async def list_models(self) -> List[str]:
        """Return the list stored under ``"models"`` in the ``GET /models`` response."""
        data = await self._get("/models")
        return data["models"]

    async def list_voices(self, model: str) -> List[int]:
        """Return the list stored under ``"voices"`` in ``GET /models/<model>/voices``.

        *model* is percent-encoded as a single path segment, so names that
        contain ``/``, ``?``, ``#``, ``%`` or spaces cannot alter the URL.
        """
        data = await self._get(f"/models/{quote(model, safe='')}/voices")
        return data["voices"]

    async def synthesize(self, params: Union[SynthesizeParams, dict]) -> bytes:
        """POST *params* to ``/synthesize`` and return the raw response body.

        Args:
            params: A ``SynthesizeParams`` instance or a dict with the same
                keys. Every field is sent, including those left as ``None``
                (serialized as JSON ``null``).

        Raises:
            TTSApiError: on a non-2xx response.
        """
        # asdict() builds a fresh dict instead of exposing the instance's __dict__.
        payload = asdict(params) if isinstance(params, SynthesizeParams) else dict(params)
        return await self._post("/synthesize", json=payload, return_bytes=True)

    async def synthesize_pcm(self, **kwargs) -> bytes:
        """Call ``synthesize`` with *kwargs* as the payload and ``as_wav`` forced to False."""
        kwargs["as_wav"] = False
        return await self.synthesize(kwargs)

    async def synthesize_wav(self, **kwargs) -> bytes:
        """Call ``synthesize`` with *kwargs* as the payload and ``as_wav`` forced to True."""
        kwargs["as_wav"] = True
        return await self.synthesize(kwargs)

    # ---- single entry points ----

    async def _get(self, path: str, *, headers: Optional[Dict[str, str]] = None) -> Any:
        """Send a GET via ``_request`` and return the decoded JSON body."""
        return await self._request("GET", path, headers=headers)

    async def _post(
        self,
        path: str,
        *,
        json: Optional[dict] = None,
        headers: Optional[Dict[str, str]] = None,
        return_bytes: bool = False,
    ) -> Any:
        """Send a POST via ``_request``; see ``_request`` for the arguments."""
        return await self._request("POST", path, json=json, headers=headers, return_bytes=return_bytes)

    async def _request(
        self,
        method: str,
        path: str,
        *,
        json: Optional[dict] = None,
        headers: Optional[Dict[str, str]] = None,
        return_bytes: bool = False,
    ) -> Any:
        """Send one HTTP request in a fresh session and decode its response.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            path: Path appended to ``base_url``; expected to start with ``/``.
            json: Optional request body, serialized to JSON by ``aiohttp``.
            headers: Extra headers; they override ``default_headers`` on conflict.
            return_bytes: Return the raw body bytes instead of decoded JSON.

        Returns:
            The raw body if *return_bytes* is true, otherwise the decoded JSON
            (or ``{"text": ...}`` for a non-JSON 2xx body).

        Raises:
            TTSApiError: for any non-2xx status, as ``"<status>: <detail>"``.
            Transport errors and timeouts raised by ``aiohttp`` are not wrapped.
        """
        url = f"{self.base_url}{path}"
        hdrs = {**self.default_headers, **(headers or {})}
        timeout = aiohttp.ClientTimeout(total=self.timeout)

        async with aiohttp.ClientSession(timeout=timeout, headers=hdrs) as s:
            async with s.request(method, url, json=json) as r:
                if 200 <= r.status < 300:
                    if return_bytes:
                        return await r.read()
                    return await _safe_json_async(r)

                # Error response. Decode leniently so that a non-text error
                # body still yields TTSApiError rather than UnicodeDecodeError.
                status = r.status
                text = await r.text(errors="replace")
                detail = _extract_error_detail_from_text(text)
                raise TTSApiError(f"{status}: {detail}")