Honor max concurrent streams #89

Merged (8 commits, May 14, 2020)
httpcore/_async/http2.py (45 changes: 32 additions & 13 deletions)

@@ -8,8 +8,8 @@
from h2.exceptions import NoAvailableStreamIDError
from h2.settings import SettingCodes, Settings

from .._backends.auto import AsyncLock, AsyncSocketStream, AutoBackend
from .._exceptions import ProtocolError
from .._backends.auto import AsyncLock, AsyncSemaphore, AsyncSocketStream, AutoBackend
from .._exceptions import PoolTimeout, ProtocolError
from .._types import URL, Headers, TimeoutDict
from .._utils import get_logger
from .base import (
@@ -67,6 +67,17 @@ def read_lock(self) -> AsyncLock:
self._read_lock = self.backend.create_lock()
return self._read_lock

@property
def max_streams_semaphore(self) -> AsyncSemaphore:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_max_streams_semaphore"):
max_streams = self.h2_state.remote_settings.max_concurrent_streams
self._max_streams_semaphore = self.backend.create_semaphore(
max_streams, exc_class=PoolTimeout
)
return self._max_streams_semaphore

async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
pass
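
The new max_streams_semaphore property is created lazily so that backend autodetection happens inside an async context, and it is bounded by the peer's advertised SETTINGS_MAX_CONCURRENT_STREAMS value. As a rough illustration of what the backend's create_semaphore(max_streams, exc_class=PoolTimeout) call could return on an asyncio backend, here is a minimal sketch; the class body, the acquire() timeout handling, and the names are assumptions, not httpcore's actual backend code.

import asyncio
from typing import Optional, Type


class AsyncSemaphore:
    # Illustrative asyncio-backed semaphore: capacity equals the server's
    # max_concurrent_streams, and an acquire that times out raises the
    # exception class supplied at creation time (PoolTimeout in the diff above).
    def __init__(self, max_value: int, exc_class: Type[Exception]) -> None:
        self.max_value = max_value
        self.exc_class = exc_class
        self._semaphore = asyncio.BoundedSemaphore(max_value)

    async def acquire(self, timeout: Optional[float] = None) -> None:
        try:
            await asyncio.wait_for(self._semaphore.acquire(), timeout)
        except asyncio.TimeoutError:
            raise self.exc_class()

    def release(self) -> None:
        self._semaphore.release()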

@@ -265,16 +276,21 @@ async def request(
b"content-length" in seen_headers or b"transfer-encoding" in seen_headers
)

await self.send_headers(method, url, headers, has_body, timeout)
if has_body:
await self.send_body(stream, timeout)

# Receive the response.
status_code, headers = await self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = AsyncByteStream(
aiterator=self.body_iter(timeout), aclose_func=self._response_closed
)
await self.connection.max_streams_semaphore.acquire()
try:
await self.send_headers(method, url, headers, has_body, timeout)
if has_body:
await self.send_body(stream, timeout)

# Receive the response.
status_code, headers = await self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = AsyncByteStream(
aiterator=self.body_iter(timeout), aclose_func=self._response_closed
)
except Exception:
self.connection.max_streams_semaphore.release()
raise

return (b"HTTP/2", status_code, reason_phrase, headers, stream)

@@ -346,4 +362,7 @@ async def body_iter(self, timeout: TimeoutDict) -> AsyncIterator[bytes]:
break

async def _response_closed(self) -> None:
await self.connection.close_stream(self.stream_id)
try:
await self.connection.close_stream(self.stream_id)
finally:
self.connection.max_streams_semaphore.release()
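
The try/finally in _response_closed() returns the slot even if closing the h2 stream raises, so a failed close cannot leak connection capacity. From the caller's perspective the release happens when the response byte stream is closed. A hypothetical usage sketch follows; the exact request() signature and the stream's iteration/close API are assumed from the constructor arguments shown above, not quoted from httpcore.

# Hypothetical caller-side view of when the stream slot is freed.
http_version, status_code, reason, headers, stream = await h2_stream.request(
    method, url, headers, body, timeout
)
try:
    async for chunk in stream:   # body_iter() yields the DATA frames
        handle(chunk)            # hypothetical consumer
finally:
    await stream.aclose()        # runs _response_closed() and releases the slot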

httpcore/_sync/http2.py (45 changes: 32 additions & 13 deletions)

@@ -8,8 +8,8 @@
from h2.exceptions import NoAvailableStreamIDError
from h2.settings import SettingCodes, Settings

from .._backends.auto import SyncLock, SyncSocketStream, SyncBackend
from .._exceptions import ProtocolError
from .._backends.auto import SyncLock, SyncSemaphore, SyncSocketStream, SyncBackend
from .._exceptions import PoolTimeout, ProtocolError
from .._types import URL, Headers, TimeoutDict
from .._utils import get_logger
from .base import (
@@ -67,6 +67,17 @@ def read_lock(self) -> SyncLock:
self._read_lock = self.backend.create_lock()
return self._read_lock

@property
def max_streams_semaphore(self) -> SyncSemaphore:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_max_streams_semaphore"):
max_streams = self.h2_state.remote_settings.max_concurrent_streams
self._max_streams_semaphore = self.backend.create_semaphore(
max_streams, exc_class=PoolTimeout
)
return self._max_streams_semaphore

def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
pass

@@ -265,16 +276,21 @@ def request(
b"content-length" in seen_headers or b"transfer-encoding" in seen_headers
)

self.send_headers(method, url, headers, has_body, timeout)
if has_body:
self.send_body(stream, timeout)

# Receive the response.
status_code, headers = self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = SyncByteStream(
iterator=self.body_iter(timeout), close_func=self._response_closed
)
self.connection.max_streams_semaphore.acquire()
try:
self.send_headers(method, url, headers, has_body, timeout)
if has_body:
self.send_body(stream, timeout)

# Receive the response.
status_code, headers = self.receive_response(timeout)
reason_phrase = get_reason_phrase(status_code)
stream = SyncByteStream(
iterator=self.body_iter(timeout), close_func=self._response_closed
)
except Exception:
self.connection.max_streams_semaphore.release()
raise

return (b"HTTP/2", status_code, reason_phrase, headers, stream)

@@ -346,4 +362,7 @@ def body_iter(self, timeout: TimeoutDict) -> Iterator[bytes]:
break

def _response_closed(self) -> None:
self.connection.close_stream(self.stream_id)
try:
self.connection.close_stream(self.stream_id)
finally:
self.connection.max_streams_semaphore.release()
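
The httpcore/_sync/http2.py changes mirror the async file, acquiring a SyncSemaphore slot around each request and releasing it when the response is closed. For completeness, a minimal sketch of a thread-based semaphore that a sync backend's create_semaphore() could return; as above, the names and timeout handling are illustrative assumptions rather than httpcore's actual code.

import threading
from typing import Optional, Type


class SyncSemaphore:
    # Illustrative thread-based counterpart: acquire() blocks until a stream
    # slot is free and raises the configured exception class on timeout.
    def __init__(self, max_value: int, exc_class: Type[Exception]) -> None:
        self.max_value = max_value
        self.exc_class = exc_class
        self._semaphore = threading.Semaphore(max_value)

    def acquire(self, timeout: Optional[float] = None) -> None:
        if not self._semaphore.acquire(timeout=timeout):
            raise self.exc_class()

    def release(self) -> None:
        self._semaphore.release()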