diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index 9b3ce2d8..c1b0a5ce 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -8,8 +8,8 @@ from h2.exceptions import NoAvailableStreamIDError from h2.settings import SettingCodes, Settings -from .._backends.auto import AsyncLock, AsyncSocketStream, AutoBackend -from .._exceptions import ProtocolError +from .._backends.auto import AsyncLock, AsyncSemaphore, AsyncSocketStream, AutoBackend +from .._exceptions import PoolTimeout, ProtocolError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger from .base import ( @@ -67,6 +67,17 @@ def read_lock(self) -> AsyncLock: self._read_lock = self.backend.create_lock() return self._read_lock + @property + def max_streams_semaphore(self) -> AsyncSemaphore: + # We do this lazily, to make sure backend autodetection always + # runs within an async context. + if not hasattr(self, "_max_streams_semaphore"): + max_streams = self.h2_state.remote_settings.max_concurrent_streams + self._max_streams_semaphore = self.backend.create_semaphore( + max_streams, exc_class=PoolTimeout + ) + return self._max_streams_semaphore + async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None: pass @@ -265,16 +276,21 @@ async def request( b"content-length" in seen_headers or b"transfer-encoding" in seen_headers ) - await self.send_headers(method, url, headers, has_body, timeout) - if has_body: - await self.send_body(stream, timeout) - - # Receive the response. - status_code, headers = await self.receive_response(timeout) - reason_phrase = get_reason_phrase(status_code) - stream = AsyncByteStream( - aiterator=self.body_iter(timeout), aclose_func=self._response_closed - ) + await self.connection.max_streams_semaphore.acquire() + try: + await self.send_headers(method, url, headers, has_body, timeout) + if has_body: + await self.send_body(stream, timeout) + + # Receive the response. + status_code, headers = await self.receive_response(timeout) + reason_phrase = get_reason_phrase(status_code) + stream = AsyncByteStream( + aiterator=self.body_iter(timeout), aclose_func=self._response_closed + ) + except Exception: + self.connection.max_streams_semaphore.release() + raise return (b"HTTP/2", status_code, reason_phrase, headers, stream) @@ -346,4 +362,7 @@ async def body_iter(self, timeout: TimeoutDict) -> AsyncIterator[bytes]: break async def _response_closed(self) -> None: - await self.connection.close_stream(self.stream_id) + try: + await self.connection.close_stream(self.stream_id) + finally: + self.connection.max_streams_semaphore.release() diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 47039235..35213600 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -8,8 +8,8 @@ from h2.exceptions import NoAvailableStreamIDError from h2.settings import SettingCodes, Settings -from .._backends.auto import SyncLock, SyncSocketStream, SyncBackend -from .._exceptions import ProtocolError +from .._backends.auto import SyncLock, SyncSemaphore, SyncSocketStream, SyncBackend +from .._exceptions import PoolTimeout, ProtocolError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger from .base import ( @@ -67,6 +67,17 @@ def read_lock(self) -> SyncLock: self._read_lock = self.backend.create_lock() return self._read_lock + @property + def max_streams_semaphore(self) -> SyncSemaphore: + # We do this lazily, to make sure backend autodetection always + # runs within an async context. + if not hasattr(self, "_max_streams_semaphore"): + max_streams = self.h2_state.remote_settings.max_concurrent_streams + self._max_streams_semaphore = self.backend.create_semaphore( + max_streams, exc_class=PoolTimeout + ) + return self._max_streams_semaphore + def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None: pass @@ -265,16 +276,21 @@ def request( b"content-length" in seen_headers or b"transfer-encoding" in seen_headers ) - self.send_headers(method, url, headers, has_body, timeout) - if has_body: - self.send_body(stream, timeout) - - # Receive the response. - status_code, headers = self.receive_response(timeout) - reason_phrase = get_reason_phrase(status_code) - stream = SyncByteStream( - iterator=self.body_iter(timeout), close_func=self._response_closed - ) + self.connection.max_streams_semaphore.acquire() + try: + self.send_headers(method, url, headers, has_body, timeout) + if has_body: + self.send_body(stream, timeout) + + # Receive the response. + status_code, headers = self.receive_response(timeout) + reason_phrase = get_reason_phrase(status_code) + stream = SyncByteStream( + iterator=self.body_iter(timeout), close_func=self._response_closed + ) + except Exception: + self.connection.max_streams_semaphore.release() + raise return (b"HTTP/2", status_code, reason_phrase, headers, stream) @@ -346,4 +362,7 @@ def body_iter(self, timeout: TimeoutDict) -> Iterator[bytes]: break def _response_closed(self) -> None: - self.connection.close_stream(self.stream_id) + try: + self.connection.close_stream(self.stream_id) + finally: + self.connection.max_streams_semaphore.release()