Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unclosed aiohttp.ClientResponse objects #2528

Merged
merged 2 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@

if TYPE_CHECKING:
import numpy as np
from aiohttp import ClientSession
from aiohttp import ClientResponse, ClientSession
from PIL.Image import Image

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -190,7 +190,7 @@ def __init__(
self.base_url = base_url

# Keep track of the sessions to close them properly
self._sessions: Set["ClientSession"] = set()
self._sessions: Dict["ClientSession", Set["ClientResponse"]] = dict()

def __repr__(self):
return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
Expand Down Expand Up @@ -358,7 +358,7 @@ async def close(self):

Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`).
"""
await asyncio.gather(*[session.close() for session in self._sessions])
await asyncio.gather(*[session.close() for session in self._sessions.keys()])

async def audio_classification(
self,
Expand Down Expand Up @@ -2648,14 +2648,28 @@ def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession"
)

# Keep track of sessions to close them later
self._sessions.add(session)
self._sessions[session] = set()

# Override the 'close' method to deregister the session when closed
# Override the `._request` method to register responses to be closed
session._wrapped_request = session._request

async def _request(method, url, **kwargs):
response = await session._wrapped_request(method, url, **kwargs)
self._sessions[session].add(response)
return response

session._request = _request

# Override the 'close' method to
# 1. close ongoing responses
# 2. deregister the session when closed
session._close = session.close

async def close_session():
for response in self._sessions[session]:
response.close()
await session._close()
self._sessions.discard(session)
self._sessions.pop(session, None)

session.close = close_session
return session
Expand Down
20 changes: 20 additions & 0 deletions tests/test_inference_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,26 @@ async def test_use_async_with_inference_client():
mock_close.assert_called_once()


@pytest.mark.asyncio
# Replace aiohttp's low-level `ClientSession._request` with a mock (injected as `request_mock`),
# so the responses obtained below are mock objects whose `.close` calls can be asserted on.
@patch("aiohttp.ClientSession._request")
async def test_client_responses_correctly_closed(request_mock: Mock) -> None:
"""
Regression test for #2521.
Async client must close the ClientResponse objects when exiting the async context manager.
Fixed by closing the response objects when the session is closed.

The test opens a session through the client, issues one GET and one POST, leaves the
`async with` block, and then checks that `close()` was called exactly once on each response.

See https://github.com/huggingface/huggingface_hub/issues/2521.
"""
async with AsyncInferenceClient() as client:
# Sessions created through the client are the ones it tracks and closes on exit.
session = client._get_client_session()
# NOTE(review): presumably `get`/`post` delegate to the patched `_request`, so no real
# network call is made to the fake URL — confirm against the aiohttp version in use.
response1 = await session.get("http://this-is-a-fake-url.com")
response2 = await session.post("http://this-is-a-fake-url.com", json={})

# Response objects are closed when the AsyncInferenceClient is closed
response1.close.assert_called_once()
response2.close.assert_called_once()


@pytest.mark.asyncio
async def test_warns_if_client_deleted_with_opened_sessions():
client = AsyncInferenceClient()
Expand Down
25 changes: 20 additions & 5 deletions utils/generate_async_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ async def close(self):

Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`).
\"""
await asyncio.gather(*[session.close() for session in self._sessions])"""
await asyncio.gather(*[session.close() for session in self._sessions.keys()])"""


def _make_post_async(code: str) -> str:
Expand Down Expand Up @@ -535,14 +535,28 @@ def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession"
)

# Keep track of sessions to close them later
self._sessions.add(session)
self._sessions[session] = set()

# Override the 'close' method to deregister the session when closed
# Override the `._request` method to register responses to be closed
session._wrapped_request = session._request

async def _request(method, url, **kwargs):
response = await session._wrapped_request(method, url, **kwargs)
self._sessions[session].add(response)
return response

session._request = _request

# Override the 'close' method to
# 1. close ongoing responses
# 2. deregister the session when closed
session._close = session.close

async def close_session():
for response in self._sessions[session]:
response.close()
await session._close()
self._sessions.discard(session)
self._sessions.pop(session, None)

session.close = close_session
return session
Expand All @@ -554,7 +568,8 @@ async def close_session():
code = _add_before(
code,
"\n def __repr__(self):\n",
"\n # Keep track of the sessions to close them properly\n self._sessions: Set['ClientSession']= set()",
"\n # Keep track of the sessions to close them properly"
"\n self._sessions: Dict['ClientSession', Set['ClientResponse']] = dict()",
)

return code
Expand Down
Loading