TTS_API_LIBRARY/tts_api_library/client.py

198 lines
No EOL
6.7 KiB
Python

from __future__ import annotations

import json as _json
from dataclasses import dataclass
from typing import List, Optional, Union, Any, Dict
from urllib.parse import quote

import aiohttp
import requests
# --------- модели данных ---------
@dataclass(frozen=True)
class SynthesizeParams:
    """
    Immutable set of arguments for a ``/synthesize`` request.

    Only ``text`` and ``model`` are required. Every other field is optional and
    defaults to ``None``, except ``as_wav``, which defaults to ``False``.
    """

    # --- required ---
    text: str
    model: str

    # --- optional synthesis settings ---
    # Their meaning and valid ranges are defined by the server; they are not
    # validated here.
    speaker_id: Optional[int] = None
    rate: Optional[int] = None
    noise_level: Optional[float] = None
    speech_rate: Optional[float] = None
    duration_noise_level: Optional[float] = None
    scale: Optional[float] = None

    # --- output format flag ---
    as_wav: bool = False
# --------- ошибки ---------
class TTSApiError(RuntimeError):
    """
    Error reported by the TTS API client.

    Within this module it is raised for non-2xx HTTP responses, with a message
    of the form ``"<status>: <detail>"``.
    """
# --------- утилиты ---------
def _safe_json_sync(r: requests.Response) -> Any:
try:
return r.json()
except Exception:
# если пришёл не JSON, но статус 2xx — вернём сырой текст
return {"text": r.text}
def _extract_error_detail_sync(r: requests.Response) -> str:
try:
j = r.json()
if isinstance(j, dict) and "detail" in j:
return str(j["detail"])
except Exception:
pass
return r.text or "unknown error"
async def _safe_json_async(r: aiohttp.ClientResponse) -> Any:
try:
return await r.json()
except Exception:
txt = await r.text()
return {"text": txt}
def _extract_error_detail_from_text(text: str) -> str:
try:
j = _json.loads(text)
if isinstance(j, dict) and "detail" in j:
return str(j["detail"])
except Exception:
pass
return text or "unknown error"
# --------- синхронный клиент ---------
class TTSClient:
    """
    Synchronous client for the TTS HTTP API.

    Every call goes through a single ``_get``/``_post`` -> ``_request`` path.
    No ``requests.Session`` is kept between calls: each request is sent with
    the module-level ``requests.request`` helper.
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = 30.0,
        default_headers: Optional[Dict[str, str]] = None,
        api_key: Optional[str] = None,
    ):
        """
        :param base_url: root URL of the API; trailing slashes are stripped.
        :param timeout: timeout in seconds passed to ``requests`` for each call.
        :param default_headers: headers sent with every request. The mapping is
            copied, so the caller's dict is never modified.
        :param api_key: when truthy, sent as the ``X-API-Key`` header.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Copy: the X-API-Key insertion below must not leak into a dict the
        # caller owns (and may share between several clients).
        self.default_headers = dict(default_headers or {})
        if api_key:
            self.default_headers["X-API-Key"] = api_key

    # ---- public API ----
    def list_models(self) -> List[str]:
        """Return the ``"models"`` list from ``GET /models``."""
        data = self._get("/models")
        return data["models"]

    def list_voices(self, model: str) -> List[int]:
        """Return the ``"voices"`` list from ``GET /models/<model>/voices``."""
        # Percent-encode the whole name (including "/") so it always stays a
        # single path segment; "/", "?" or "#" would otherwise change the URL.
        segment = quote(model, safe="")
        data = self._get(f"/models/{segment}/voices")
        return data["voices"]

    def synthesize(self, params: Union[SynthesizeParams, dict]) -> bytes:
        """
        POST ``/synthesize`` and return the raw response body.

        :param params: a ``SynthesizeParams`` (all of its fields are sent,
            including ``None`` ones) or a mapping that is copied and sent as-is.
        :returns: the response body as bytes.
        :raises TTSApiError: on a non-2xx response.
        """
        payload = params.__dict__ if isinstance(params, SynthesizeParams) else dict(params)
        return self._post("/synthesize", json=payload, return_bytes=True)

    def synthesize_pcm(self, **kwargs) -> bytes:
        """
        Synthesize with ``as_wav`` forced to ``False``.

        The keyword arguments form the request payload as-is; they are not
        validated against ``SynthesizeParams``.
        """
        kwargs["as_wav"] = False
        return self.synthesize(kwargs)

    def synthesize_wav(self, **kwargs) -> bytes:
        """
        Synthesize with ``as_wav`` forced to ``True``.

        The keyword arguments form the request payload as-is; they are not
        validated against ``SynthesizeParams``.
        """
        kwargs["as_wav"] = True
        return self.synthesize(kwargs)

    # ---- request plumbing ----
    def _get(self, path: str, *, headers: Optional[Dict[str, str]] = None) -> Any:
        """Send a GET request to *path*; see ``_request``."""
        return self._request("GET", path, headers=headers)

    def _post(
        self,
        path: str,
        *,
        json: Optional[dict] = None,
        headers: Optional[Dict[str, str]] = None,
        return_bytes: bool = False,
    ) -> Any:
        """Send a POST request with a JSON body to *path*; see ``_request``."""
        return self._request("POST", path, json=json, headers=headers, return_bytes=return_bytes)

    def _request(
        self,
        method: str,
        path: str,
        *,
        json: Optional[dict] = None,
        headers: Optional[Dict[str, str]] = None,
        return_bytes: bool = False,
    ) -> Any:
        """
        Send one HTTP request and decode the reply.

        :param method: HTTP method, e.g. ``"GET"``.
        :param path: path appended to ``base_url`` (expected to start with "/").
        :param json: optional request body, serialized as JSON by ``requests``.
        :param headers: per-call headers; they override ``default_headers``.
        :param return_bytes: return the raw body bytes instead of decoded JSON.
        :returns: the body bytes, or the decoded JSON (``{"text": ...}`` when a
            2xx body is not JSON).
        :raises TTSApiError: for any non-2xx status. Transport errors raised by
            ``requests`` are not caught here and propagate unchanged.
        """
        url = f"{self.base_url}{path}"
        hdrs = {**self.default_headers, **(headers or {})}
        r = requests.request(method=method, url=url, json=json, headers=hdrs, timeout=self.timeout)
        if 200 <= r.status_code < 300:
            return r.content if return_bytes else _safe_json_sync(r)
        # Error: prefer the server's JSON "detail" field, else the raw body.
        detail = _extract_error_detail_sync(r)
        raise TTSApiError(f"{r.status_code}: {detail}")
# --------- асинхронный клиент ---------
class TTSAioClient:
    """
    Asynchronous (aiohttp) client for the TTS HTTP API.

    Every call goes through a single ``_get``/``_post`` -> ``_request`` path.
    A fresh ``aiohttp.ClientSession`` is opened and closed for each request,
    so no connections are reused between calls.
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = 30.0,
        default_headers: Optional[Dict[str, str]] = None,
        api_key: Optional[str] = None,
    ):
        """
        :param base_url: root URL of the API; trailing slashes are stripped.
        :param timeout: total per-request timeout in seconds, used as
            ``aiohttp.ClientTimeout(total=timeout)``.
        :param default_headers: headers sent with every request. The mapping is
            copied, so the caller's dict is never modified.
        :param api_key: when truthy, sent as the ``X-API-Key`` header.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Copy: the X-API-Key insertion below must not leak into a dict the
        # caller owns (and may share between several clients).
        self.default_headers = dict(default_headers or {})
        if api_key:
            self.default_headers["X-API-Key"] = api_key

    # ---- public API ----
    async def list_models(self) -> List[str]:
        """Return the ``"models"`` list from ``GET /models``."""
        data = await self._get("/models")
        return data["models"]

    async def list_voices(self, model: str) -> List[int]:
        """Return the ``"voices"`` list from ``GET /models/<model>/voices``."""
        # Percent-encode the whole name (including "/") so it always stays a
        # single path segment; "/", "?" or "#" would otherwise change the URL.
        segment = quote(model, safe="")
        data = await self._get(f"/models/{segment}/voices")
        return data["voices"]

    async def synthesize(self, params: Union[SynthesizeParams, dict]) -> bytes:
        """
        POST ``/synthesize`` and return the raw response body.

        :param params: a ``SynthesizeParams`` (all of its fields are sent,
            including ``None`` ones) or a mapping that is copied and sent as-is.
        :returns: the response body as bytes.
        :raises TTSApiError: on a non-2xx response.
        """
        payload = params.__dict__ if isinstance(params, SynthesizeParams) else dict(params)
        return await self._post("/synthesize", json=payload, return_bytes=True)

    async def synthesize_pcm(self, **kwargs) -> bytes:
        """
        Synthesize with ``as_wav`` forced to ``False``.

        The keyword arguments form the request payload as-is; they are not
        validated against ``SynthesizeParams``.
        """
        kwargs["as_wav"] = False
        return await self.synthesize(kwargs)

    async def synthesize_wav(self, **kwargs) -> bytes:
        """
        Synthesize with ``as_wav`` forced to ``True``.

        The keyword arguments form the request payload as-is; they are not
        validated against ``SynthesizeParams``.
        """
        kwargs["as_wav"] = True
        return await self.synthesize(kwargs)

    # ---- request plumbing ----
    async def _get(self, path: str, *, headers: Optional[Dict[str, str]] = None) -> Any:
        """Send a GET request to *path*; see ``_request``."""
        return await self._request("GET", path, headers=headers)

    async def _post(
        self,
        path: str,
        *,
        json: Optional[dict] = None,
        headers: Optional[Dict[str, str]] = None,
        return_bytes: bool = False,
    ) -> Any:
        """Send a POST request with a JSON body to *path*; see ``_request``."""
        return await self._request("POST", path, json=json, headers=headers, return_bytes=return_bytes)

    async def _request(
        self,
        method: str,
        path: str,
        *,
        json: Optional[dict] = None,
        headers: Optional[Dict[str, str]] = None,
        return_bytes: bool = False,
    ) -> Any:
        """
        Send one HTTP request in a fresh session and decode the reply.

        :param method: HTTP method, e.g. ``"GET"``.
        :param path: path appended to ``base_url`` (expected to start with "/").
        :param json: optional request body, serialized as JSON by aiohttp.
        :param headers: per-call headers; they override ``default_headers``.
        :param return_bytes: return the raw body bytes instead of decoded JSON.
        :returns: the body bytes, or the decoded JSON (``{"text": ...}`` when a
            2xx body cannot be decoded).
        :raises TTSApiError: for any non-2xx status. aiohttp client errors and
            timeouts are not caught here and propagate unchanged.
        """
        url = f"{self.base_url}{path}"
        hdrs = {**self.default_headers, **(headers or {})}
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(timeout=timeout, headers=hdrs) as s:
            async with s.request(method, url, json=json) as r:
                if 200 <= r.status < 300:
                    if return_bytes:
                        return await r.read()
                    return await _safe_json_async(r)
                # Error: prefer the server's JSON "detail" field, else the raw body.
                status = r.status
                text = await r.text()
                detail = _extract_error_detail_from_text(text)
                raise TTSApiError(f"{status}: {detail}")