Skip to content

Commit

Permalink
Implement async with x operations for discord.ext.ipcx.Client (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
No767 authored Dec 7, 2024
1 parent 1333e1e commit f3b2a62
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
1 change: 1 addition & 0 deletions changelog.d/56.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement `async with x` operations for `discord.ext.ipcx.Client`
35 changes: 28 additions & 7 deletions discord/ext/ipcx/client.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
from __future__ import annotations

import asyncio
import logging
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

import aiohttp

from .errors import NotConnectedError

if TYPE_CHECKING:
from types import TracebackType

from typing_extensions import Self

log = logging.getLogger(__name__)


class Client:
"""
Handles webserver side requests to the bot process.
Operations with ``async with`` will automatically initialize the client and automatically cleans up.
Parameters
----------
host: str
Expand All @@ -38,14 +47,31 @@ def __init__(
self.port = port
self.multicast_port = multicast_port

self.session = None
self.session: Optional[aiohttp.ClientSession] = None

@property
def url(self):
return "ws://{0.host}:{1}".format(
self, self.port if self.port else self.multicast_port
)

async def __aenter__(self) -> Self:
await self._get_session()
return self

async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
await self.close()

async def _get_session(self) -> aiohttp.ClientSession:
if not self.session:
self.session = aiohttp.ClientSession()
return self.session

async def close(self) -> None:
"""Properly closes the :class:`aiohttp.ClientSession` session used for connections
Expand All @@ -58,11 +84,6 @@ async def close(self) -> None:
if self.session:
await self.session.close()

async def _get_session(self) -> aiohttp.ClientSession:
if not self.session:
self.session = aiohttp.ClientSession()
return self.session

async def get_port(self) -> int:
"""Attempts to obtain the provided port.
Expand Down

0 comments on commit f3b2a62

Please sign in to comment.