from __future__ import annotations
import asyncio
import time
from functools import partial
from os import environ
from typing import TYPE_CHECKING, Any, cast
from warnings import warn
import aiohttp
from tenacity import AsyncRetrying
from zyte_api._x402 import _x402Handler
from zyte_api.apikey import NoApiKey
from ._errors import RequestError
from ._retry import zyte_api_retrying
from ._utils import _AIO_API_TIMEOUT, create_session
from .constants import API_URL
from .constants import ENV_VARIABLE as API_KEY_ENV_VAR
from .stats import AggStats, ResponseStats
from .utils import USER_AGENT, _process_query
if TYPE_CHECKING:
    # Names below are needed only by annotations and string-based ``cast``
    # calls. ``from __future__ import annotations`` keeps annotations
    # unevaluated at runtime, so these imports are skipped outside of type
    # checking.
    from collections.abc import Awaitable, Callable, Iterator
    from contextlib import AbstractAsyncContextManager

    from eth_account.signers.local import LocalAccount

    # typing.Self requires Python 3.11
    from typing_extensions import Self

    # Awaitable resolving to a response dict; used as the item type of the
    # iterators returned by the ``iter`` methods in this module. It has to be
    # defined under TYPE_CHECKING because ``Awaitable`` is only imported here.
    _ResponseFuture = Awaitable[dict[str, Any]]
def _post_func(
session: aiohttp.ClientSession | None,
) -> Callable[..., AbstractAsyncContextManager[aiohttp.ClientResponse]]:
"""Return a function to send a POST request"""
if session is None:
return partial(aiohttp.request, method="POST", timeout=_AIO_API_TIMEOUT)
return session.post
class _AsyncSession:
def __init__(self, client: AsyncZyteAPI, **session_kwargs: Any):
self._client: AsyncZyteAPI = client
session_kwargs.setdefault("trust_env", client.trust_env)
self._session: aiohttp.ClientSession = create_session(
client.n_conn, **session_kwargs
)
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, *exc_info: object) -> None:
await self._session.close()
async def close(self) -> None:
await self._session.close()
async def get(
self,
query: dict[str, Any],
*,
endpoint: str = "extract",
handle_retries: bool = True,
retrying: AsyncRetrying | None = None,
) -> dict[str, Any]:
return await self._client.get(
query=query,
endpoint=endpoint,
handle_retries=handle_retries,
retrying=retrying,
session=self._session,
)
def iter(
self,
queries: list[dict[str, Any]],
*,
endpoint: str = "extract",
handle_retries: bool = True,
retrying: AsyncRetrying | None = None,
) -> Iterator[_ResponseFuture]:
return self._client.iter(
queries=queries,
endpoint=endpoint,
session=self._session,
handle_retries=handle_retries,
retrying=retrying,
)
class AuthInfo:
    """Read-only description of the credential a client authenticates with.

    The wrapped value is either a plain string (a Zyte API key) or an
    ``_x402Handler`` built from an Ethereum private key.
    """

    def __init__(self, *, _auth: str | _x402Handler):
        self._auth: str | _x402Handler = _auth

    @property
    def key(self) -> str:
        """The API key itself, or the hex form of the handler's account key."""
        credential = self._auth
        if not isinstance(credential, str):
            account = cast("LocalAccount", credential.client.account)
            return account.key.hex()
        return credential

    @property
    def type(self) -> str:
        """``"zyte"`` for a string API key, ``"eth"`` otherwise."""
        return "zyte" if isinstance(self._auth, str) else "eth"
# [docs]  (appears to be a Sphinx source-link marker left over from HTML extraction; not code)
class AsyncZyteAPI:
    """:ref:`Asynchronous Zyte API client <asyncio_api>`.

    Parameters work the same as for :class:`ZyteAPI`.

    The credential is resolved once, at construction time, in this order:
    the *api_key* argument, the *eth_key* argument, the API key environment
    variable, and finally the ``ZYTE_API_ETH_KEY`` environment variable.
    ``NoApiKey`` is raised when none of them is set.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_url: str | None = None,
        n_conn: int = 15,
        retrying: AsyncRetrying | None = None,
        user_agent: str | None = None,
        eth_key: str | None = None,
        trust_env: bool = False,
    ):
        if retrying is not None and not isinstance(retrying, AsyncRetrying):
            raise ValueError(
                "The retrying parameter, if defined, must be an instance of "
                "AsyncRetrying."
            )
        self.n_conn = n_conn
        self.agg_stats = AggStats()
        self.retrying = retrying or zyte_api_retrying
        self.user_agent = user_agent or USER_AGENT
        self.trust_env = trust_env
        # One semaphore per client: it guards every request attempt made in
        # ``get`` and is also handed to the x402 handler, when one is built.
        self._semaphore = asyncio.Semaphore(n_conn)
        # The three attributes below are assigned by ``_load_auth``.
        self._auth: str | _x402Handler
        self.auth: AuthInfo
        self.api_url: str
        self._load_auth(api_key, eth_key, api_url)

    def _load_auth(
        self, api_key: str | None, eth_key: str | None, api_url: str | None
    ) -> None:
        """Resolve the credential and the base URL of the API.

        Explicit arguments win over environment variables and, within each
        source, an API key wins over an Ethereum private key.

        Raises:
            NoApiKey: if no credential is found in the arguments or in the
                environment.
        """
        if api_key:
            self._auth = api_key
        elif eth_key:
            self._auth = _x402Handler(eth_key, self._semaphore, self.agg_stats)
        elif api_key := environ.get(API_KEY_ENV_VAR):
            self._auth = api_key
        elif eth_key := environ.get("ZYTE_API_ETH_KEY"):
            self._auth = _x402Handler(eth_key, self._semaphore, self.agg_stats)
        else:
            raise NoApiKey(
                "You must provide either a Zyte API key or an Ethereum "
                "private key. For the latter, you must also install "
                "zyte-api as zyte-api[x402]."
            )
        self.auth = AuthInfo(_auth=self._auth)
        # An explicit URL always wins; otherwise the default depends on the
        # kind of credential in use.
        if api_url is not None:
            self.api_url = api_url
        elif self.auth.type == "eth":
            self.api_url = "https://api-x402.zyte.com/v1/"
        else:
            self.api_url = API_URL

    @property
    def api_key(self) -> str:
        """Deprecated accessor for the API key; use ``auth.key`` instead.

        Raises:
            NotImplementedError: if the client authenticates with an
                Ethereum private key rather than an API key.
        """
        if isinstance(self._auth, str):
            warn(
                "The api_key property is deprecated, use auth.key instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return self._auth
        raise NotImplementedError(
            "api_key is not available when using an Ethereum private key, use auth.key instead."
        )

    async def get(
        self,
        query: dict[str, Any],
        *,
        endpoint: str = "extract",
        session: aiohttp.ClientSession | None = None,
        handle_retries: bool = True,
        retrying: AsyncRetrying | None = None,
    ) -> dict[str, Any]:
        """Asynchronous equivalent to :meth:`ZyteAPI.get`.

        Args:
            query: Request parameters; passed through ``_process_query``
                before being sent as the JSON body.
            endpoint: Path appended verbatim to ``api_url``.
            session: aiohttp session to reuse. When omitted, a session is
                created for this call and closed before returning or raising.
            handle_retries: Whether to wrap the request with the retry
                policy.
            retrying: Retry policy for this call; defaults to the client's
                ``retrying`` attribute.

        Returns:
            The decoded JSON body of the response.

        Raises:
            RequestError: if the response status is 400 or higher (subject
                to the retry policy when *handle_retries* is true).
        """
        retrying = retrying or self.retrying
        owned_session: aiohttp.ClientSession | None = None
        if session is None:
            owned_session = create_session(self.n_conn, trust_env=self.trust_env)
            session = owned_session

        # Everything after the session is created runs inside this
        # try/finally so that a session owned by this call is closed even if
        # query processing or x402 header negotiation fails before the
        # request itself is attempted.
        try:
            post = _post_func(session)
            url = self.api_url + endpoint
            query = _process_query(query)
            headers = {"User-Agent": self.user_agent, "Accept-Encoding": "br"}
            if isinstance(self._auth, str):
                if hasattr(aiohttp, "encode_basic_auth"):  # aiohttp 3.14+
                    headers["Authorization"] = aiohttp.encode_basic_auth(
                        self._auth, ""
                    )
                else:
                    headers["Authorization"] = aiohttp.BasicAuth(self._auth).encode()
            else:
                x402_headers = await self._auth.get_headers(url, query, headers, post)
                headers.update(x402_headers)

            post_kwargs = {
                "url": url,
                "json": query,
                "headers": headers,
            }

            # One ResponseStats entry is appended per attempt; the list is
            # not read anywhere else in this method.
            response_stats: list[ResponseStats] = []
            start_global = time.perf_counter()

            async def request() -> dict[str, Any]:
                stats = ResponseStats.create(start_global)
                self.agg_stats.n_attempts += 1
                try:
                    async with self._semaphore, post(**post_kwargs) as resp:
                        stats.record_connected(resp.status, self.agg_stats)
                        if (
                            resp.status == 402
                            and isinstance(self._auth, _x402Handler)
                            and "X-Payment" in post_kwargs["headers"]
                        ):
                            # Let the x402 handler update the shared request
                            # kwargs from the 402 body (presumably so that a
                            # later attempt sends refreshed payment data).
                            self._auth.refresh_post_kwargs(
                                post_kwargs, await resp.json()
                            )
                        if resp.status >= 400:
                            content = await resp.read()
                            resp.release()
                            stats.record_read()
                            stats.record_request_error(content, self.agg_stats)
                            raise RequestError(
                                request_info=resp.request_info,
                                history=resp.history,
                                status=resp.status,
                                message=resp.reason,
                                headers=resp.headers,
                                response_content=content,
                                query=query,
                            )
                        response = cast("dict[str, Any]", await resp.json())
                        stats.record_read(self.agg_stats)
                        return response
                except Exception as e:
                    # RequestError was already recorded above through
                    # record_request_error; anything else is counted here.
                    if not isinstance(e, RequestError):
                        self.agg_stats.n_errors += 1
                        stats.record_exception(e, agg_stats=self.agg_stats)
                    raise
                finally:
                    response_stats.append(stats)

            if handle_retries:
                request = retrying.wraps(request)

            try:
                # Try to make a request
                result = await request()
                self.agg_stats.n_success += 1
            except Exception:
                self.agg_stats.n_fatal_errors += 1
                raise
            return result
        finally:
            if owned_session is not None:
                await owned_session.close()

    def iter(
        self,
        queries: list[dict[str, Any]],
        *,
        endpoint: str = "extract",
        session: aiohttp.ClientSession | None = None,
        handle_retries: bool = True,
        retrying: AsyncRetrying | None = None,
    ) -> Iterator[_ResponseFuture]:
        """Asynchronous equivalent to :meth:`ZyteAPI.iter`.

        One :meth:`get` call is prepared per query and the resulting
        awaitables are handed to :func:`asyncio.as_completed`.

        .. note:: Yielded futures, when awaited, do raise their exceptions,
                  instead of only returning them.
        """

        def _request(query: dict[str, Any]) -> _ResponseFuture:
            return self.get(
                query,
                endpoint=endpoint,
                session=session,
                handle_retries=handle_retries,
                retrying=retrying,
            )

        return asyncio.as_completed([_request(query) for query in queries])

    def session(self, **kwargs: Any) -> _AsyncSession:
        """Asynchronous equivalent to :meth:`ZyteAPI.session`.

        You do not need to use :meth:`~AsyncZyteAPI.session` as an async
        context manager as long as you await ``close()`` on the object it
        returns when you are done:

        .. code-block:: python

            session = client.session()
            try:
                ...
            finally:
                await session.close()
        """
        return _AsyncSession(client=self, **kwargs)