[INIT]
This commit is contained in:
parent
ebfd5311fa
commit
dc8ef9b9e8
3 changed files with 226 additions and 0 deletions
18
pyproject.toml
Normal file
18
pyproject.toml
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
[project]
|
||||
name = "tts_api_library"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = [
|
||||
{name = "Evgeny (Krymmy) Momotov",email = "evgeny.momotov@gmail.com"}
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11, <3.13"
|
||||
dependencies = [
|
||||
"requests (>=2.32.5,<3.0.0)",
|
||||
"aiohttp (>=3.12.15,<4.0.0)"
|
||||
]
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=2.0.0,<3.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
13
tts_api_library/__init__.py
Normal file
13
tts_api_library/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from .client import (
|
||||
SynthesizeParams,
|
||||
TTSApiError,
|
||||
TTSClient,
|
||||
TTSAioClient,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SynthesizeParams",
|
||||
"TTSApiError",
|
||||
"TTSClient",
|
||||
"TTSAioClient",
|
||||
]
|
||||
195
tts_api_library/client.py
Normal file
195
tts_api_library/client.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union, Any, Dict
|
||||
import json as _json
|
||||
import requests
|
||||
import aiohttp
|
||||
|
||||
|
||||
# --------- модели данных ---------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SynthesizeParams:
|
||||
text: str
|
||||
model: str
|
||||
speaker_id: Optional[int] = None
|
||||
rate: Optional[int] = None
|
||||
noise_level: Optional[float] = None
|
||||
speech_rate: Optional[float] = None
|
||||
duration_noise_level: Optional[float] = None
|
||||
scale: Optional[float] = None
|
||||
as_wav: bool = False
|
||||
|
||||
|
||||
# --------- ошибки ---------
|
||||
|
||||
class TTSApiError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
# --------- синхронный клиент ---------
|
||||
|
||||
class TTSClient:
|
||||
"""
|
||||
Синхронный клиент. Единые точки GET/POST/REQUEST.
|
||||
Сессии создаются на каждый запрос.
|
||||
"""
|
||||
def __init__(self, base_url: str, timeout: float = 30.0, default_headers: Optional[Dict[str, str]] = None):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
self.default_headers = default_headers or {}
|
||||
|
||||
# ---- публичное API ----
|
||||
|
||||
def list_models(self) -> List[str]:
|
||||
data = self._get("/models")
|
||||
return data["models"]
|
||||
|
||||
def list_voices(self, model: str) -> List[int]:
|
||||
data = self._get(f"/models/{model}/voices")
|
||||
return data["voices"]
|
||||
|
||||
def synthesize(self, params: Union[SynthesizeParams, dict]) -> bytes:
|
||||
payload = params.__dict__ if isinstance(params, SynthesizeParams) else dict(params)
|
||||
return self._post("/synthesize", json=payload, return_bytes=True)
|
||||
|
||||
def synthesize_pcm(self, **kwargs) -> bytes:
|
||||
kwargs["as_wav"] = False
|
||||
return self.synthesize(kwargs)
|
||||
|
||||
def synthesize_wav(self, **kwargs) -> bytes:
|
||||
kwargs["as_wav"] = True
|
||||
return self.synthesize(kwargs)
|
||||
|
||||
# ---- единые точки ----
|
||||
|
||||
def _get(self, path: str, *, headers: Optional[Dict[str, str]] = None) -> Any:
|
||||
return self._request("GET", path, headers=headers)
|
||||
|
||||
def _post(self, path: str, *, json: Optional[dict] = None, headers: Optional[Dict[str, str]] = None, return_bytes: bool = False) -> Any:
|
||||
return self._request("POST", path, json=json, headers=headers, return_bytes=return_bytes)
|
||||
|
||||
def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
json: Optional[dict] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
return_bytes: bool = False,
|
||||
) -> Any:
|
||||
url = f"{self.base_url}{path}"
|
||||
hdrs = {**self.default_headers, **(headers or {})}
|
||||
|
||||
r = requests.request(method=method, url=url, json=json, headers=hdrs, timeout=self.timeout)
|
||||
|
||||
if 200 <= r.status_code < 300:
|
||||
return r.content if return_bytes else _safe_json_sync(r)
|
||||
|
||||
# ошибка
|
||||
detail = _extract_error_detail_sync(r)
|
||||
raise TTSApiError(f"{r.status_code}: {detail}")
|
||||
|
||||
|
||||
# --------- асинхронный клиент ---------
|
||||
|
||||
class TTSAioClient:
|
||||
"""
|
||||
Асинхронный клиент. Единые точки GET/POST/REQUEST.
|
||||
Сессии создаются на каждый запрос.
|
||||
"""
|
||||
def __init__(self, base_url: str, timeout: float = 30.0, default_headers: Optional[Dict[str, str]] = None):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
self.default_headers = default_headers or {}
|
||||
|
||||
# ---- публичное API ----
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
data = await self._get("/models")
|
||||
return data["models"]
|
||||
|
||||
async def list_voices(self, model: str) -> List[int]:
|
||||
data = await self._get(f"/models/{model}/voices")
|
||||
return data["voices"]
|
||||
|
||||
async def synthesize(self, params: Union[SynthesizeParams, dict]) -> bytes:
|
||||
payload = params.__dict__ if isinstance(params, SynthesizeParams) else dict(params)
|
||||
return await self._post("/synthesize", json=payload, return_bytes=True)
|
||||
|
||||
async def synthesize_pcm(self, **kwargs) -> bytes:
|
||||
kwargs["as_wav"] = False
|
||||
return await self.synthesize(kwargs)
|
||||
|
||||
async def synthesize_wav(self, **kwargs) -> bytes:
|
||||
kwargs["as_wav"] = True
|
||||
return await self.synthesize(kwargs)
|
||||
|
||||
# ---- единые точки ----
|
||||
|
||||
async def _get(self, path: str, *, headers: Optional[Dict[str, str]] = None) -> Any:
|
||||
return await self._request("GET", path, headers=headers)
|
||||
|
||||
async def _post(self, path: str, *, json: Optional[dict] = None, headers: Optional[Dict[str, str]] = None, return_bytes: bool = False) -> Any:
|
||||
return await self._request("POST", path, json=json, headers=headers, return_bytes=return_bytes)
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
json: Optional[dict] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
return_bytes: bool = False,
|
||||
) -> Any:
|
||||
url = f"{self.base_url}{path}"
|
||||
hdrs = {**self.default_headers, **(headers or {})}
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout, headers=hdrs) as s:
|
||||
async with s.request(method, url, json=json) as r:
|
||||
if 200 <= r.status < 300:
|
||||
if return_bytes:
|
||||
return await r.read()
|
||||
return await _safe_json_async(r)
|
||||
|
||||
# ошибка
|
||||
status = r.status
|
||||
text = await r.text()
|
||||
detail = _extract_error_detail_from_text(text)
|
||||
raise TTSApiError(f"{status}: {detail}")
|
||||
|
||||
|
||||
# --------- утилиты ---------
|
||||
|
||||
def _safe_json_sync(r: requests.Response) -> Any:
|
||||
try:
|
||||
return r.json()
|
||||
except Exception:
|
||||
# если пришёл не JSON, но статус 2xx — вернём сырой текст
|
||||
return {"text": r.text}
|
||||
|
||||
def _extract_error_detail_sync(r: requests.Response) -> str:
|
||||
try:
|
||||
j = r.json()
|
||||
if isinstance(j, dict) and "detail" in j:
|
||||
return str(j["detail"])
|
||||
except Exception:
|
||||
pass
|
||||
return r.text or "unknown error"
|
||||
|
||||
async def _safe_json_async(r: aiohttp.ClientResponse) -> Any:
|
||||
try:
|
||||
return await r.json()
|
||||
except Exception:
|
||||
txt = await r.text()
|
||||
return {"text": txt}
|
||||
|
||||
def _extract_error_detail_from_text(text: str) -> str:
|
||||
try:
|
||||
j = _json.loads(text)
|
||||
if isinstance(j, dict) and "detail" in j:
|
||||
return str(j["detail"])
|
||||
except Exception:
|
||||
pass
|
||||
return text or "unknown error"
|
||||
Loading…
Add table
Reference in a new issue