diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index fd7343ea09..98d30b4971 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -97,7 +97,7 @@
 
 if TYPE_CHECKING:
     import numpy as np
-    from aiohttp import ClientSession
+    from aiohttp import ClientResponse, ClientSession
     from PIL.Image import Image
 
 logger = logging.getLogger(__name__)
@@ -190,7 +190,7 @@ def __init__(
         self.base_url = base_url
 
         # Keep track of the sessions to close them properly
-        self._sessions: Set["ClientSession"] = set()
+        self._sessions: Dict["ClientSession", Set["ClientResponse"]] = dict()
 
     def __repr__(self):
         return f""
@@ -358,7 +358,7 @@ async def close(self):
 
         Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`).
         """
-        await asyncio.gather(*[session.close() for session in self._sessions])
+        await asyncio.gather(*[session.close() for session in self._sessions.keys()])
 
     async def audio_classification(
         self,
@@ -2648,14 +2648,28 @@ def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession"
         )
 
         # Keep track of sessions to close them later
-        self._sessions.add(session)
+        self._sessions[session] = set()
 
-        # Override the 'close' method to deregister the session when closed
+        # Override the `._request` method to register responses to be closed
+        session._wrapped_request = session._request
+
+        async def _request(method, url, **kwargs):
+            response = await session._wrapped_request(method, url, **kwargs)
+            self._sessions[session].add(response)
+            return response
+
+        session._request = _request
+
+        # Override the 'close' method to
+        # 1. close ongoing responses (skipped if the session was already deregistered)
+        # 2. deregister the session when closed
         session._close = session.close
 
         async def close_session():
+            for response in self._sessions.get(session, ()):
+                response.close()
             await session._close()
-            self._sessions.discard(session)
+            self._sessions.pop(session, None)
 
         session.close = close_session
         return session
diff --git a/tests/test_inference_async_client.py b/tests/test_inference_async_client.py
index 13922e4f89..8d8b90bbf2 100644
--- a/tests/test_inference_async_client.py
+++ b/tests/test_inference_async_client.py
@@ -451,6 +451,26 @@ async def test_use_async_with_inference_client():
     mock_close.assert_called_once()
 
 
+@pytest.mark.asyncio
+@patch("aiohttp.ClientSession._request")
+async def test_client_responses_correctly_closed(request_mock: Mock) -> None:
+    """
+    Regression test for #2521.
+    Async client must close the ClientResponse objects when exiting the async context manager.
+    Fixed by closing the response objects when the session is closed.
+
+    See https://github.com/huggingface/huggingface_hub/issues/2521.
+    """
+    async with AsyncInferenceClient() as client:
+        session = client._get_client_session()
+        response1 = await session.get("http://this-is-a-fake-url.com")
+        response2 = await session.post("http://this-is-a-fake-url.com", json={})
+
+    # Response objects are closed when the AsyncInferenceClient is closed
+    response1.close.assert_called_once()
+    response2.close.assert_called_once()
+
+
 @pytest.mark.asyncio
 async def test_warns_if_client_deleted_with_opened_sessions():
     client = AsyncInferenceClient()
diff --git a/utils/generate_async_inference_client.py b/utils/generate_async_inference_client.py
index 4d607f7f8c..17818d1858 100644
--- a/utils/generate_async_inference_client.py
+++ b/utils/generate_async_inference_client.py
@@ -272,7 +272,7 @@ async def close(self):
 
         Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`).
         \"""
-        await asyncio.gather(*[session.close() for session in self._sessions])"""
+        await asyncio.gather(*[session.close() for session in self._sessions.keys()])"""
 
 
 def _make_post_async(code: str) -> str:
@@ -535,14 +535,28 @@ def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession"
         )
 
         # Keep track of sessions to close them later
-        self._sessions.add(session)
+        self._sessions[session] = set()
 
-        # Override the 'close' method to deregister the session when closed
+        # Override the `._request` method to register responses to be closed
+        session._wrapped_request = session._request
+
+        async def _request(method, url, **kwargs):
+            response = await session._wrapped_request(method, url, **kwargs)
+            self._sessions[session].add(response)
+            return response
+
+        session._request = _request
+
+        # Override the 'close' method to
+        # 1. close ongoing responses (skipped if the session was already deregistered)
+        # 2. deregister the session when closed
         session._close = session.close
 
         async def close_session():
+            for response in self._sessions.get(session, ()):
+                response.close()
             await session._close()
-            self._sessions.discard(session)
+            self._sessions.pop(session, None)
 
         session.close = close_session
         return session
@@ -554,7 +568,8 @@ async def close_session():
     code = _add_before(
         code,
         "\n    def __repr__(self):\n",
-        "\n        # Keep track of the sessions to close them properly\n        self._sessions: Set['ClientSession']= set()",
+        "\n        # Keep track of the sessions to close them properly"
+        "\n        self._sessions: Dict['ClientSession', Set['ClientResponse']] = dict()",
     )
 
     return code