diff --git a/changelog/1132.breaking.0.rst b/changelog/1132.breaking.0.rst new file mode 100644 index 0000000000..c8ef9bbf26 --- /dev/null +++ b/changelog/1132.breaking.0.rst @@ -0,0 +1 @@ +Removed ``loop`` and ``asyncio_debug`` parameters from :class:`Client`. diff --git a/changelog/1132.breaking.1.rst b/changelog/1132.breaking.1.rst new file mode 100644 index 0000000000..11893430a2 --- /dev/null +++ b/changelog/1132.breaking.1.rst @@ -0,0 +1 @@ +The majority of the library now assumes that there is an asyncio event loop running. diff --git a/changelog/1132.deprecate.rst b/changelog/1132.deprecate.rst new file mode 100644 index 0000000000..e8134a17e7 --- /dev/null +++ b/changelog/1132.deprecate.rst @@ -0,0 +1 @@ +Deprecated :attr:`Client.loop`. Use :func:`asyncio.get_running_loop` instead. diff --git a/changelog/1132.feature.rst b/changelog/1132.feature.rst new file mode 100644 index 0000000000..32470148e2 --- /dev/null +++ b/changelog/1132.feature.rst @@ -0,0 +1 @@ +Add :meth:`Client.setup_hook`. diff --git a/changelog/1132.misc.rst b/changelog/1132.misc.rst new file mode 100644 index 0000000000..9273c50aab --- /dev/null +++ b/changelog/1132.misc.rst @@ -0,0 +1 @@ +:meth:`Client.run` now uses :func:`asyncio.run` under-the-hood, instead of custom runner logic. diff --git a/changelog/641.breaking.0.rst b/changelog/641.breaking.0.rst new file mode 100644 index 0000000000..b998f47a72 --- /dev/null +++ b/changelog/641.breaking.0.rst @@ -0,0 +1 @@ +|commands| Make :meth:`.ext.commands.Bot.load_extensions`, :meth:`.ext.commands.Bot.load_extension`, :meth:`.ext.commands.Bot.unload_extension`, :meth:`.ext.commands.Bot.reload_extension`, :meth:`.ext.commands.Bot.add_cog` and :meth:`.ext.commands.Bot.remove_cog` asynchronous. diff --git a/changelog/641.breaking.1.rst b/changelog/641.breaking.1.rst new file mode 100644 index 0000000000..9b3a0fb2fa --- /dev/null +++ b/changelog/641.breaking.1.rst @@ -0,0 +1 @@ +|commands| :meth:`.ext.commands.Cog.cog_load` is now called *after* the cog finished loading. diff --git a/changelog/641.feature.0.rst b/changelog/641.feature.0.rst new file mode 100644 index 0000000000..ee9e11b621 --- /dev/null +++ b/changelog/641.feature.0.rst @@ -0,0 +1 @@ +|commands| :meth:`.ext.commands.Cog.cog_load` and :meth:`.ext.commands.Cog.cog_unload` can now be either asynchronous or not. diff --git a/changelog/641.feature.1.rst b/changelog/641.feature.1.rst new file mode 100644 index 0000000000..7a56ba23fe --- /dev/null +++ b/changelog/641.feature.1.rst @@ -0,0 +1 @@ +|commands| The ``setup`` and ``teardown`` functions utilized by :ref:`ext_commands_extensions` can now be asynchronous. diff --git a/disnake/client.py b/disnake/client.py index 0fed3eb8e4..c8b229839a 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -4,7 +4,6 @@ import asyncio import logging -import signal import sys import traceback import types @@ -82,7 +81,7 @@ from .widget import Widget if TYPE_CHECKING: - from typing_extensions import NotRequired + from typing_extensions import Never, NotRequired from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime from .app_commands import APIApplicationCommand, MessageCommand, SlashCommand, UserCommand @@ -113,41 +112,6 @@ _log = logging.getLogger(__name__) -def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: - tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} - - if not tasks: - return - - _log.info("Cleaning up after %d tasks.", len(tasks)) - for task in tasks: - task.cancel() - - loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) - _log.info("All tasks finished cancelling.") - - for task in tasks: - if task.cancelled(): - continue - if task.exception() is not None: - loop.call_exception_handler( - { - "message": "Unhandled exception during Client.run shutdown.", - "exception": task.exception(), - "task": task, - } - ) - - -def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: - try: - _cancel_tasks(loop) - loop.run_until_complete(loop.shutdown_asyncgens()) - finally: - _log.info("Closing the event loop.") - loop.close() - - class SessionStartLimit: """A class that contains information about the current session start limit, at the time when the client connected for the first time. @@ -237,13 +201,6 @@ class Client: .. versionchanged:: 1.3 Allow disabling the message cache and change the default size to ``1000``. - loop: Optional[:class:`asyncio.AbstractEventLoop`] - The :class:`asyncio.AbstractEventLoop` to use for asynchronous operations. - Defaults to ``None``, in which case the default event loop is used via - :func:`asyncio.get_event_loop()`. - asyncio_debug: :class:`bool` - Whether to enable asyncio debugging when the client starts. - Defaults to False. connector: Optional[:class:`aiohttp.BaseConnector`] The connector to use for connection pooling. proxy: Optional[:class:`str`] @@ -361,8 +318,6 @@ class Client: ---------- ws The websocket gateway the client is currently connected to. Could be ``None``. - loop: :class:`asyncio.AbstractEventLoop` - The event loop that the client uses for asynchronous operations. session_start_limit: Optional[:class:`SessionStartLimit`] Information about the current session start limit. Only available after initiating the connection. @@ -378,8 +333,6 @@ class Client: def __init__( self, *, - asyncio_debug: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, shard_id: Optional[int] = None, shard_count: Optional[int] = None, enable_debug_events: bool = False, @@ -405,23 +358,27 @@ def __init__( # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore - if loop is None: - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() - else: - self.loop: asyncio.AbstractEventLoop = loop - - self.loop.set_debug(asyncio_debug) self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} self.session_start_limit: Optional[SessionStartLimit] = None + if connector: + try: + asyncio.get_running_loop() + except RuntimeError: + raise RuntimeError( + ( + "`connector` was created outside of an asyncio loop, which will likely cause" + "issues later down the line due to the client and `connector` running on" + "different asyncio loops; consider moving client instantiation to an '`async" + "main`' function and then manually asyncio.run it" + ) + ) from None + self.http: HTTPClient = HTTPClient( connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=assume_unsync_clock, - loop=self.loop, ) self._handlers: Dict[str, Callable] = { @@ -504,7 +461,6 @@ def _get_state( handlers=self._handlers, hooks=self._hooks, http=self.http, - loop=self.loop, max_messages=max_messages, application_id=application_id, heartbeat_timeout=heartbeat_timeout, @@ -525,6 +481,28 @@ def _handle_first_connect(self) -> None: return self._first_connect.set() + @property + def loop(self): + """:class:`asyncio.AbstractEventLoop`: Same as :func:`asyncio.get_running_loop`. + + .. deprecated:: 3.0 + Use :func:`asyncio.get_running_loop` directly. + """ + warnings.warn( + "Accessing `Client.loop` is deprecated. Use `asyncio.get_running_loop()` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + return asyncio.get_running_loop() + + @loop.setter + def loop(self, _value: Never) -> None: + warnings.warn( + "Setting `Client.loop` is deprecated and has no effect. Use `asyncio.get_running_loop()` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + @property def latency(self) -> float: """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. @@ -1015,12 +993,34 @@ async def before_identify_hook(self, shard_id: Optional[int], *, initial: bool = if not initial: await asyncio.sleep(5.0) + async def setup_hook(self) -> None: + """A hook that allows you to perform asynchronous setup like + initiating database connections or loading cogs/extensions after + the bot is logged in but before it has connected to the websocket. + + This is only called once, in :meth:`.login`, before any events are + dispatched, making it a better solution than doing such setup in + the :func:`disnake.on_ready` event. + + .. warning:: + Since this is called *before* the websocket connection is made, + anything that waits for the websocket will deadlock, which includes + methods like :meth:`.wait_for`, :meth:`.wait_until_ready` + and :meth:`.wait_until_first_connect`. + + .. versionadded:: 3.0 + """ + # login state management async def login(self, token: str) -> None: """|coro| - Logs in the client with the specified credentials. + Logs in the client with the specified credentials and calls + :meth:`.setup_hook`. + + .. versionchanged:: 3.0 + Now also calls :meth:`.setup_hook`. Parameters ---------- @@ -1044,6 +1044,8 @@ async def login(self, token: str) -> None: data = await self.http.static_login(token.strip()) self._connection.user = ClientUser(state=self._connection, data=data) + await self.setup_hook() + async def connect( self, *, reconnect: bool = True, ignore_session_start_limit: bool = False ) -> None: @@ -1245,10 +1247,14 @@ async def start( TypeError An unexpected keyword argument was received. """ - await self.login(token) - await self.connect( - reconnect=reconnect, ignore_session_start_limit=ignore_session_start_limit - ) + try: + await self.login(token) + await self.connect( + reconnect=reconnect, ignore_session_start_limit=ignore_session_start_limit + ) + finally: + if not self.is_closed(): + await self.close() def run(self, *args: Any, **kwargs: Any) -> None: """A blocking call that abstracts away the event loop @@ -1258,57 +1264,26 @@ def run(self, *args: Any, **kwargs: Any) -> None: function should not be used. Use :meth:`start` coroutine or :meth:`connect` + :meth:`login`. - Roughly Equivalent to: :: + Equivalent to: :: try: - loop.run_until_complete(start(*args, **kwargs)) + asyncio.run(start(*args, **kwargs)) except KeyboardInterrupt: - loop.run_until_complete(close()) - # cancel all tasks lingering - finally: - loop.close() + return .. warning:: This function must be the last function to call due to the fact that it is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. - """ - loop = self.loop + .. versionchanged:: 3.0 + Changed to use :func:`asyncio.run`, instead of custom logic. + """ try: - loop.add_signal_handler(signal.SIGINT, lambda: loop.stop()) - loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop()) - except NotImplementedError: - pass - - async def runner() -> None: - try: - await self.start(*args, **kwargs) - finally: - if not self.is_closed(): - await self.close() - - def stop_loop_on_completion(f) -> None: - loop.stop() - - future = asyncio.ensure_future(runner(), loop=loop) - future.add_done_callback(stop_loop_on_completion) - try: - loop.run_forever() + asyncio.run(self.start(*args, **kwargs)) except KeyboardInterrupt: - _log.info("Received signal to terminate bot and event loop.") - finally: - future.remove_done_callback(stop_loop_on_completion) - _log.info("Cleaning up tasks.") - _cleanup_loop(loop) - - if not future.cancelled(): - try: - return future.result() - except KeyboardInterrupt: - # I am unsure why this gets raised here but suppress it anyway - return None + return # properties @@ -1798,7 +1773,7 @@ def check(reaction, user): arguments that mirrors the parameters passed in the :ref:`event `. """ - future = self.loop.create_future() + future = asyncio.get_running_loop().create_future() if check is None: def _check(*args) -> bool: diff --git a/disnake/context_managers.py b/disnake/context_managers.py index 120c75ffa9..2e94d2e9fb 100644 --- a/disnake/context_managers.py +++ b/disnake/context_managers.py @@ -26,7 +26,6 @@ def _typing_done_callback(fut: asyncio.Future) -> None: class Typing: def __init__(self, messageable: Union[Messageable, ThreadOnlyGuildChannel]) -> None: - self.loop: asyncio.AbstractEventLoop = messageable._state.loop self.messageable: Union[Messageable, ThreadOnlyGuildChannel] = messageable async def do_typing(self) -> None: @@ -42,7 +41,7 @@ async def do_typing(self) -> None: await asyncio.sleep(5) def __enter__(self) -> Self: - self.task: asyncio.Task = self.loop.create_task(self.do_typing()) + self.task: asyncio.Task = asyncio.create_task(self.do_typing()) self.task.add_done_callback(_typing_done_callback) return self diff --git a/disnake/ext/commands/bot.py b/disnake/ext/commands/bot.py index 825f96e6ae..b659d4fcd8 100644 --- a/disnake/ext/commands/bot.py +++ b/disnake/ext/commands/bot.py @@ -10,8 +10,6 @@ from .interaction_bot_base import InteractionBotBase if TYPE_CHECKING: - import asyncio - import aiohttp from typing_extensions import Self @@ -237,8 +235,6 @@ def __init__( sync_commands: bool = ..., sync_commands_debug: bool = ..., sync_commands_on_cog_unload: bool = ..., - asyncio_debug: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, shard_id: Optional[int] = None, shard_count: Optional[int] = None, enable_debug_events: bool = False, @@ -289,8 +285,6 @@ def __init__( sync_commands: bool = ..., sync_commands_debug: bool = ..., sync_commands_on_cog_unload: bool = ..., - asyncio_debug: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, shard_ids: Optional[List[int]] = None, # instead of shard_id shard_count: Optional[int] = None, enable_debug_events: bool = False, @@ -438,8 +432,6 @@ def __init__( sync_commands: bool = ..., sync_commands_debug: bool = ..., sync_commands_on_cog_unload: bool = ..., - asyncio_debug: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, shard_id: Optional[int] = None, shard_count: Optional[int] = None, enable_debug_events: bool = False, @@ -483,8 +475,6 @@ def __init__( sync_commands: bool = ..., sync_commands_debug: bool = ..., sync_commands_on_cog_unload: bool = ..., - asyncio_debug: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, shard_ids: Optional[List[int]] = None, # instead of shard_id shard_count: Optional[int] = None, enable_debug_events: bool = False, diff --git a/disnake/ext/commands/bot_base.py b/disnake/ext/commands/bot_base.py index 1bba906c82..21277f0264 100644 --- a/disnake/ext/commands/bot_base.py +++ b/disnake/ext/commands/bot_base.py @@ -399,8 +399,8 @@ def after_invoke(self, coro: CFT) -> CFT: # extensions - def _remove_module_references(self, name: str) -> None: - super()._remove_module_references(name) + async def _remove_module_references(self, name: str) -> None: + await super()._remove_module_references(name) # remove all the commands from the module for cmd in self.all_commands.copy().values(): if cmd.module and _is_submodule(name, cmd.module): diff --git a/disnake/ext/commands/cog.py b/disnake/ext/commands/cog.py index 01fd59937c..3d7b4a3ea2 100644 --- a/disnake/ext/commands/cog.py +++ b/disnake/ext/commands/cog.py @@ -476,17 +476,21 @@ def has_message_error_handler(self) -> bool: @_cog_special_method async def cog_load(self) -> None: - """A special method that is called as a task when the cog is added.""" + """A special method that is called when the cog is added. + + .. versionchanged:: 3.0 + This is now ``await``\\ed directly instead of being scheduled as a task. + + This is now run when the cog has fully finished loading. + """ pass @_cog_special_method - def cog_unload(self) -> None: + async def cog_unload(self) -> None: """A special method that is called when the cog gets removed. - This function **cannot** be a coroutine. It must be a regular - function. - - Subclasses must replace this if they want special unloading behaviour. + .. versionchanged:: 3.0 + This can now be a coroutine. """ pass @@ -724,7 +728,7 @@ async def cog_after_message_command_invoke(self, inter: ApplicationCommandIntera """Similar to :meth:`cog_after_slash_command_invoke` but for message commands.""" pass - def _inject(self, bot: AnyBot) -> Self: + async def _inject(self, bot: AnyBot) -> Self: from .bot import AutoShardedInteractionBot, InteractionBot cls = self.__class__ @@ -771,9 +775,6 @@ def _inject(self, bot: AnyBot) -> Self: bot.remove_message_command(to_undo.name) raise - if not hasattr(self.cog_load.__func__, "__cog_special_method__"): - bot.loop.create_task(disnake.utils.maybe_coroutine(self.cog_load)) - # check if we're overriding the default if cls.bot_check is not Cog.bot_check: if isinstance(bot, (InteractionBot, AutoShardedInteractionBot)): @@ -830,9 +831,12 @@ def _inject(self, bot: AnyBot) -> Self: except NotImplementedError: pass + if not hasattr(self.cog_load.__func__, "__cog_special_method__"): + await disnake.utils.maybe_coroutine(self.cog_load) + return self - def _eject(self, bot: AnyBot) -> None: + async def _eject(self, bot: AnyBot) -> None: cls = self.__class__ try: @@ -896,7 +900,7 @@ def _eject(self, bot: AnyBot) -> None: except NotImplementedError: pass try: - self.cog_unload() + await disnake.utils.maybe_coroutine(self.cog_unload) except Exception as e: _log.error( "An error occurred while unloading the %s cog.", self.qualified_name, exc_info=e diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 8658aa369a..480b87d3d8 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -11,6 +11,7 @@ import sys import time import types +from functools import partial from typing import TYPE_CHECKING, Any, Dict, Generic, List, Mapping, Optional, Set, TypeVar, Union import disnake @@ -97,14 +98,14 @@ async def close(self) -> None: for extension in tuple(self.__extensions): try: - self.unload_extension(extension) + await self.unload_extension(extension) except Exception as error: error.__suppress_context__ = True _log.error("Failed to unload extension %r", extension, exc_info=error) for cog in tuple(self.__cogs): try: - self.remove_cog(cog) + await self.remove_cog(cog) except Exception as error: error.__suppress_context__ = True _log.exception("Failed to remove cog %r", cog, exc_info=error) @@ -115,12 +116,11 @@ async def close(self) -> None: async def login(self, token: str) -> None: await super().login(token=token) # type: ignore - loop: asyncio.AbstractEventLoop = self.loop # type: ignore if self.reload: - loop.create_task(self._watchdog()) + asyncio.create_task(self._watchdog()) # prefetch - loop.create_task(self._fill_owners()) + asyncio.create_task(self._fill_owners()) async def is_owner(self, user: Union[disnake.User, disnake.Member]) -> bool: """|coro| @@ -157,7 +157,7 @@ async def is_owner(self, user: Union[disnake.User, disnake.Member]) -> bool: else: return user.id in self.owner_ids - def add_cog(self, cog: Cog, *, override: bool = False) -> None: + async def add_cog(self, cog: Cog, *, override: bool = False) -> None: """Adds a "cog" to the bot. A cog is a class that has its own event listeners and commands. @@ -171,6 +171,9 @@ def add_cog(self, cog: Cog, *, override: bool = False) -> None: :exc:`.ClientException` is raised when a cog with the same name is already loaded. + .. versionchanged:: 3.0 + This is now a coroutine. + Parameters ---------- cog: :class:`.Cog` @@ -199,10 +202,10 @@ def add_cog(self, cog: Cog, *, override: bool = False) -> None: if existing is not None: if not override: raise disnake.ClientException(f"Cog named {cog_name!r} already loaded") - self.remove_cog(cog_name) + await self.remove_cog(cog_name) # NOTE: Should be covariant - cog = cog._inject(self) # type: ignore + cog = await cog._inject(self) # type: ignore self.__cogs[cog_name] = cog def get_cog(self, name: str) -> Optional[Cog]: @@ -224,7 +227,7 @@ def get_cog(self, name: str) -> Optional[Cog]: """ return self.__cogs.get(name) - def remove_cog(self, name: str) -> Optional[Cog]: + async def remove_cog(self, name: str) -> Optional[Cog]: """Removes a cog from the bot and returns it. All registered commands and event listeners that the @@ -236,6 +239,9 @@ def remove_cog(self, name: str) -> Optional[Cog]: :attr:`command_sync_flags.sync_on_cog_actions <.CommandSyncFlags.sync_on_cog_actions>` isn't disabled. + .. versionchanged:: 3.0 + This is now a coroutine. + Parameters ---------- name: :class:`str` @@ -254,7 +260,7 @@ def remove_cog(self, name: str) -> Optional[Cog]: if help_command and help_command.cog is cog: help_command.cog = None # NOTE: Should be covariant - cog._eject(self) # type: ignore + await cog._eject(self) # type: ignore return cog @@ -265,12 +271,12 @@ def cogs(self) -> Mapping[str, Cog]: # extensions - def _remove_module_references(self, name: str) -> None: + async def _remove_module_references(self, name: str) -> None: # find all references to the module # remove the cogs registered from the module for cogname, cog in self.__cogs.copy().items(): if _is_submodule(name, cog.__module__): - self.remove_cog(cogname) + await self.remove_cog(cogname) # remove all the listeners from the module for event_list in self.extra_events.copy().values(): remove = [ @@ -282,14 +288,14 @@ def _remove_module_references(self, name: str) -> None: for index in reversed(remove): del event_list[index] - def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: + async def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: try: func = lib.teardown except AttributeError: pass else: try: - func(self) + await disnake.utils.maybe_coroutine(partial(func, self)) except Exception as error: error.__suppress_context__ = True _log.error("Exception in extension finalizer %r", key, exc_info=error) @@ -301,7 +307,7 @@ def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: if _is_submodule(name, module): del sys.modules[module] - def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None: + async def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None: # precondition: key not in self.__extensions lib = importlib.util.module_from_spec(spec) sys.modules[key] = lib @@ -318,11 +324,11 @@ def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) raise errors.NoEntryPointError(key) from None try: - setup(self) + await disnake.utils.maybe_coroutine(partial(setup, self)) except Exception as e: del sys.modules[key] - self._remove_module_references(lib.__name__) - self._call_module_finalizers(lib, key) + await self._remove_module_references(lib.__name__) + await self._call_module_finalizers(lib, key) raise errors.ExtensionFailed(key, e) from e else: self.__extensions[key] = lib @@ -333,7 +339,7 @@ def _resolve_name(self, name: str, package: Optional[str]) -> str: except ImportError as e: raise errors.ExtensionNotFound(name) from e - def load_extension(self, name: str, *, package: Optional[str] = None) -> None: + async def load_extension(self, name: str, *, package: Optional[str] = None) -> None: """Loads an extension. An extension is a python module that contains commands, cogs, or @@ -343,6 +349,9 @@ def load_extension(self, name: str, *, package: Optional[str] = None) -> None: the entry point on what to do when the extension is loaded. This entry point must have a single argument, the ``bot``. + .. versionchanged:: 3.0 + This is now a coroutine. + Parameters ---------- name: :class:`str` @@ -377,9 +386,9 @@ def load_extension(self, name: str, *, package: Optional[str] = None) -> None: if spec is None: raise errors.ExtensionNotFound(name) - self._load_from_module_spec(spec, name) + await self._load_from_module_spec(spec, name) - def unload_extension(self, name: str, *, package: Optional[str] = None) -> None: + async def unload_extension(self, name: str, *, package: Optional[str] = None) -> None: """Unloads an extension. When the extension is unloaded, all commands, listeners, and cogs are @@ -390,6 +399,9 @@ def unload_extension(self, name: str, *, package: Optional[str] = None) -> None: parameter, the ``bot``, similar to ``setup`` from :meth:`~.Bot.load_extension`. + .. versionchanged:: 3.0 + This is now a coroutine. + Parameters ---------- name: :class:`str` @@ -416,10 +428,10 @@ def unload_extension(self, name: str, *, package: Optional[str] = None) -> None: if lib is None: raise errors.ExtensionNotLoaded(name) - self._remove_module_references(lib.__name__) - self._call_module_finalizers(lib, name) + await self._remove_module_references(lib.__name__) + await self._call_module_finalizers(lib, name) - def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: + async def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: """Atomically reloads an extension. This replaces the extension with the same extension, only refreshed. This is @@ -427,6 +439,9 @@ def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: except done in an atomic way. That is, if an operation fails mid-reload then the bot will roll-back to the prior working state. + .. versionchanged:: 3.0 + This is now a coroutine. + Parameters ---------- name: :class:`str` @@ -467,9 +482,9 @@ def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: try: # Unload and then load the module... - self._remove_module_references(lib.__name__) - self._call_module_finalizers(lib, name) - self.load_extension(name) + await self._remove_module_references(lib.__name__) + await self._call_module_finalizers(lib, name) + await self.load_extension(name) except Exception: # if the load failed, the remnants should have been # cleaned from the load_extension function call @@ -481,18 +496,21 @@ def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: sys.modules.update(modules) raise - def load_extensions(self, path: str) -> None: + async def load_extensions(self, path: str) -> None: """Loads all extensions in a directory. .. versionadded:: 2.4 + .. versionchanged:: 3.0 + This is now a coroutine. + Parameters ---------- path: :class:`str` The path to search for extensions """ for extension in disnake.utils.search_directory(path): - self.load_extension(extension) + await self.load_extension(extension) @property def extensions(self) -> Mapping[str, types.ModuleType]: @@ -535,7 +553,7 @@ async def _watchdog(self) -> None: for name in extensions: try: - self.reload_extension(name) + await self.reload_extension(name) except errors.ExtensionError as e: reload_log.exception(e) else: diff --git a/disnake/ext/commands/cooldowns.py b/disnake/ext/commands/cooldowns.py index 354754550a..4fd74bc773 100644 --- a/disnake/ext/commands/cooldowns.py +++ b/disnake/ext/commands/cooldowns.py @@ -279,11 +279,10 @@ class _Semaphore: overkill for what is basically a counter. """ - __slots__ = ("value", "loop", "_waiters") + __slots__ = ("value", "_waiters") def __init__(self, number: int) -> None: self.value: int = number - self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() self._waiters: Deque[asyncio.Future] = deque() def __repr__(self) -> str: @@ -308,7 +307,7 @@ async def acquire(self, *, wait: bool = False) -> bool: return False while self.value <= 0: - future = self.loop.create_future() + future = asyncio.get_running_loop().create_future() self._waiters.append(future) try: await future diff --git a/disnake/ext/commands/interaction_bot_base.py b/disnake/ext/commands/interaction_bot_base.py index 110066b679..545590ce5d 100644 --- a/disnake/ext/commands/interaction_bot_base.py +++ b/disnake/ext/commands/interaction_bot_base.py @@ -790,7 +790,7 @@ async def _sync_application_commands(self) -> None: if not isinstance(self, disnake.Client): raise NotImplementedError("This method is only usable in disnake.Client subclasses") - if not self._command_sync_flags._sync_enabled or self._is_closed or self.loop.is_closed(): + if not self._command_sync_flags._sync_enabled or self._is_closed: return # We assume that all commands are already cached. @@ -894,7 +894,6 @@ async def _delayed_command_sync(self) -> None: or self._sync_queued.locked() or not self.is_ready() or self._is_closed - or self.loop.is_closed() ): return # We don't do this task on login or in parallel with a similar task @@ -907,7 +906,7 @@ def _schedule_app_command_preparation(self) -> None: if not isinstance(self, disnake.Client): raise NotImplementedError("Command sync is only possible in disnake.Client subclasses") - self.loop.create_task( + asyncio.create_task( self._prepare_application_commands(), name="disnake: app_command_preparation" ) @@ -915,7 +914,7 @@ def _schedule_delayed_command_sync(self) -> None: if not isinstance(self, disnake.Client): raise NotImplementedError("This method is only usable in disnake.Client subclasses") - self.loop.create_task(self._delayed_command_sync(), name="disnake: delayed_command_sync") + asyncio.create_task(self._delayed_command_sync(), name="disnake: delayed_command_sync") # Error handlers diff --git a/disnake/ext/tasks/__init__.py b/disnake/ext/tasks/__init__.py index 6532c3d088..5b8ba7f0ea 100644 --- a/disnake/ext/tasks/__init__.py +++ b/disnake/ext/tasks/__init__.py @@ -7,7 +7,6 @@ import inspect import sys import traceback -import warnings from collections.abc import Sequence from typing import ( TYPE_CHECKING, @@ -50,18 +49,21 @@ class SleepHandle: - __slots__ = ("future", "loop", "handle") + __slots__ = ("future", "handle") - def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None: - self.loop = loop - self.future: asyncio.Future[bool] = loop.create_future() + def __init__(self, dt: datetime.datetime) -> None: + self.future: asyncio.Future[bool] = asyncio.get_running_loop().create_future() relative_delta = disnake.utils.compute_timedelta(dt) - self.handle = loop.call_later(relative_delta, self.future.set_result, True) + self.handle = asyncio.get_running_loop().call_later( + relative_delta, self.future.set_result, True + ) def recalculate(self, dt: datetime.datetime) -> None: self.handle.cancel() relative_delta = disnake.utils.compute_timedelta(dt) - self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) + self.handle = asyncio.get_running_loop().call_later( + relative_delta, self.future.set_result, True + ) def wait(self) -> asyncio.Future[bool]: return self.future @@ -90,14 +92,12 @@ def __init__( time: Union[datetime.time, Sequence[datetime.time]] = MISSING, count: Optional[int] = None, reconnect: bool = True, - loop: asyncio.AbstractEventLoop = MISSING, ) -> None: """.. note: If you overwrite ``__init__`` arguments, make sure to redefine .clone too. """ self.coro: LF = coro self.reconnect: bool = reconnect - self.loop: asyncio.AbstractEventLoop = loop self.count: Optional[int] = count self._current_loop = 0 self._handle: SleepHandle = MISSING @@ -139,7 +139,7 @@ async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> Non await coro(*args, **kwargs) def _try_sleep_until(self, dt: datetime.datetime): - self._handle = SleepHandle(dt=dt, loop=self.loop) + self._handle = SleepHandle(dt=dt) return self._handle.wait() async def _loop(self, *args: Any, **kwargs: Any) -> None: @@ -214,7 +214,6 @@ def clone(self) -> Self: time=self._time, count=self.count, reconnect=self.reconnect, - loop=self.loop, ) instance._before_loop = self._before_loop instance._after_loop = self._after_loop @@ -324,12 +323,7 @@ def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: if self._injected is not None: args = (self._injected, *args) - if self.loop is MISSING: - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - self.loop = asyncio.get_event_loop() - - self._task = self.loop.create_task(self._loop(*args, **kwargs)) + self._task = asyncio.create_task(self._loop(*args, **kwargs)) return self._task def stop(self) -> None: @@ -721,7 +715,6 @@ def loop( time: Union[datetime.time, Sequence[datetime.time]] = ..., count: Optional[int] = None, reconnect: bool = True, - loop: asyncio.AbstractEventLoop = ..., ) -> Callable[[LF], Loop[LF]]: ... @@ -775,9 +768,6 @@ def loop( Whether to handle errors and restart the task using an exponential back-off algorithm similar to the one used in :meth:`disnake.Client.connect`. - loop: :class:`asyncio.AbstractEventLoop` - The loop to use to register the task, if not given - defaults to :func:`asyncio.get_event_loop`. Raises ------ diff --git a/disnake/gateway.py b/disnake/gateway.py index 97508574c3..af7a6215d8 100644 --- a/disnake/gateway.py +++ b/disnake/gateway.py @@ -170,6 +170,7 @@ def __init__( *args: Any, ws: HeartbeatWebSocket, interval: float, + loop: asyncio.AbstractEventLoop, shard_id: Optional[int] = None, **kwargs: Any, ) -> None: @@ -177,6 +178,7 @@ def __init__( self.ws: HeartbeatWebSocket = ws self._main_thread_id: int = ws.thread_id self.interval: float = interval + self.loop = loop self.daemon: bool = True self.shard_id: Optional[int] = shard_id self.msg = "Keeping shard ID %s websocket alive with sequence %s." @@ -197,7 +199,7 @@ def run(self) -> None: self.shard_id, ) coro = self.ws.close(4000) - f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) + f = asyncio.run_coroutine_threadsafe(coro, loop=self.loop) try: f.result() @@ -210,7 +212,7 @@ def run(self) -> None: data = self.get_payload() _log.debug(self.msg, self.shard_id, data["d"]) coro = self.ws.send_heartbeat(data) - f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) + f = asyncio.run_coroutine_threadsafe(coro, loop=self.loop) try: # block until sending is complete total = 0 @@ -277,7 +279,6 @@ class HeartbeatWebSocket(Protocol): HEARTBEAT: Final[Literal[1, 3]] thread_id: int - loop: asyncio.AbstractEventLoop _max_heartbeat_timeout: float async def close(self, code: int) -> None: @@ -345,10 +346,10 @@ class DiscordWebSocket: GUILD_SYNC: Final[Literal[12]] = 12 def __init__( - self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop + self, + socket: aiohttp.ClientWebSocketResponse, ) -> None: self.socket: aiohttp.ClientWebSocketResponse = socket - self.loop: asyncio.AbstractEventLoop = loop # an empty dispatcher to prevent crashes self._dispatch: DispatchFunc = lambda event, *args: None @@ -420,7 +421,7 @@ async def from_client( gateway = await client.http.get_gateway(encoding=params.encoding, zlib=params.zlib) socket = await client.http.ws_connect(gateway) - ws = cls(socket, loop=client.loop) + ws = cls(socket) # dynamically add attributes needed ws.token = client.http.token # type: ignore @@ -483,7 +484,7 @@ def wait_for( asyncio.Future A future to wait for. """ - future = self.loop.create_future() + future = asyncio.get_running_loop().create_future() entry = EventListener(event=event, predicate=predicate, result=result, future=future) self._dispatch_listeners.append(entry) return future @@ -587,8 +588,12 @@ async def received_message(self, raw_msg: Union[str, bytes], /) -> None: if op == self.HELLO: interval: float = data["heartbeat_interval"] / 1000.0 self._keep_alive = KeepAliveHandler( - ws=self, interval=interval, shard_id=self.shard_id + ws=self, + interval=interval, + shard_id=self.shard_id, + loop=asyncio.get_running_loop(), # share loop to the thread ) + self._keep_alive.name = "disnake heartbeat thread" # send a heartbeat immediately await self.send_as_json(self._keep_alive.get_payload()) self._keep_alive.start() @@ -890,12 +895,10 @@ class DiscordVoiceWebSocket: def __init__( self, socket: aiohttp.ClientWebSocketResponse, - loop: asyncio.AbstractEventLoop, *, hook: Optional[HookFunc] = None, ) -> None: self.ws: aiohttp.ClientWebSocketResponse = socket - self.loop: asyncio.AbstractEventLoop = loop self._keep_alive: Optional[VoiceKeepAliveHandler] = None self._close_code: Optional[int] = None self.secret_key: Optional[List[int]] = None @@ -957,7 +960,7 @@ async def from_client( gateway = f"wss://{client.endpoint}/?v=4" http = client._state.http socket = await http.ws_connect(gateway, compress=15) - ws = cls(socket, loop=client.loop, hook=hook) + ws = cls(socket, hook=hook) ws.gateway = gateway ws._connection = client ws._max_heartbeat_timeout = 60.0 @@ -1011,7 +1014,9 @@ async def received_message(self, msg: VoicePayload) -> None: await self.load_secret_key(data) elif op == self.HELLO: interval: float = data["heartbeat_interval"] / 1000.0 - self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0)) + self._keep_alive = VoiceKeepAliveHandler( + ws=self, interval=min(interval, 5.0), loop=asyncio.get_running_loop() + ) self._keep_alive.start() await self._hook(self, msg) @@ -1027,7 +1032,7 @@ async def initial_connection(self, data: VoiceReadyPayload) -> None: struct.pack_into(">H", packet, 2, 70) # 70 = Length struct.pack_into(">I", packet, 4, state.ssrc) state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) - recv = await self.loop.sock_recv(state.socket, 74) + recv = await asyncio.get_running_loop().sock_recv(state.socket, 74) _log.debug("received packet in initial_connection: %s", recv) # the ip is ascii starting at the 8th byte and ending at the first null diff --git a/disnake/http.py b/disnake/http.py index e5c22622ef..3ff7c76e8a 100644 --- a/disnake/http.py +++ b/disnake/http.py @@ -223,13 +223,11 @@ def __init__( self, connector: Optional[aiohttp.BaseConnector] = None, *, - loop: asyncio.AbstractEventLoop, proxy: Optional[str] = None, proxy_auth: Optional[aiohttp.BasicAuth] = None, unsync_clock: bool = True, ) -> None: - self.loop: asyncio.AbstractEventLoop = loop - self.connector = connector + self.connector = connector or MISSING self.__session: aiohttp.ClientSession = MISSING # filled in static_login self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() self._global_over: asyncio.Event = asyncio.Event() @@ -359,7 +357,7 @@ async def request( delta, ) maybe_lock.defer() - self.loop.call_later(delta, lock.release) + asyncio.get_running_loop().call_later(delta, lock.release) # the request was successful so just return the text/json if 300 > response.status >= 200: @@ -450,6 +448,8 @@ async def close(self) -> None: # login management async def static_login(self, token: str) -> user.User: + if self.connector is MISSING: + self.connector = aiohttp.TCPConnector(limit=0) # Necessary to get aiohttp to stop complaining about session creation self.__session = aiohttp.ClientSession( connector=self.connector, ws_response_class=DiscordClientWebSocketResponse diff --git a/disnake/player.py b/disnake/player.py index 8012d640b4..1248b1d00f 100644 --- a/disnake/player.py +++ b/disnake/player.py @@ -689,11 +689,19 @@ def read(self) -> bytes: class AudioPlayer(threading.Thread): DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 - def __init__(self, source: AudioSource, client: VoiceClient, *, after=None) -> None: + def __init__( + self, + source: AudioSource, + client: VoiceClient, + loop: asyncio.AbstractEventLoop, + *, + after=None, + ) -> None: threading.Thread.__init__(self) self.daemon: bool = True self.source: AudioSource = source self.client: VoiceClient = client + self.loop = loop self.after: Optional[Callable[[Optional[Exception]], Any]] = after self._end: threading.Event = threading.Event() @@ -803,6 +811,6 @@ def _set_source(self, source: AudioSource) -> None: def _speak(self, speaking: bool) -> None: try: - asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop) + asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.loop) except Exception as e: _log.info("Speaking call in player failed: %s", e) diff --git a/disnake/shard.py b/disnake/shard.py index 3903d413aa..98ef7ff504 100644 --- a/disnake/shard.py +++ b/disnake/shard.py @@ -93,7 +93,6 @@ def __init__( self._client: Client = client self._dispatch: Callable[..., None] = client.dispatch self._queue_put: Callable[[EventItem], None] = queue_put - self.loop: asyncio.AbstractEventLoop = self._client.loop self._disconnect: bool = False self._reconnect = client._reconnect self._backoff: ExponentialBackoff = ExponentialBackoff() @@ -113,7 +112,7 @@ def id(self) -> int: return self.ws.shard_id # type: ignore def launch(self) -> None: - self._task = self.loop.create_task(self.worker()) + self._task = asyncio.create_task(self.worker()) def _cancel_task(self) -> None: if self._task is not None and not self._task.done(): @@ -347,7 +346,6 @@ def __init__( self, *, asyncio_debug: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, shard_ids: Optional[List[int]] = None, # instead of Client's shard_id: Optional[int] shard_count: Optional[int] = None, enable_debug_events: bool = False, @@ -409,7 +407,6 @@ def _get_state(self, **options: Any) -> AutoShardedConnectionState: handlers=self._handlers, hooks=self._hooks, http=self.http, - loop=self.loop, **options, ) @@ -537,7 +534,8 @@ async def close(self) -> None: pass to_close = [ - asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values() + asyncio.ensure_future(shard.close(), loop=asyncio.get_running_loop()) + for shard in self.__shards.values() ] if to_close: await asyncio.wait(to_close) diff --git a/disnake/state.py b/disnake/state.py index f4885513d7..a09923bfad 100644 --- a/disnake/state.py +++ b/disnake/state.py @@ -115,14 +115,12 @@ class ChunkRequest: def __init__( self, guild_id: int, - loop: asyncio.AbstractEventLoop, resolver: Callable[[int], Any], *, cache: bool = True, ) -> None: self.guild_id: int = guild_id self.resolver: Callable[[int], Any] = resolver - self.loop: asyncio.AbstractEventLoop = loop self.cache: bool = cache self.nonce: str = os.urandom(16).hex() self.buffer: List[Member] = [] @@ -141,7 +139,7 @@ def add_members(self, members: List[Member]) -> None: guild._add_member(member) async def wait(self) -> List[Member]: - future = self.loop.create_future() + future = asyncio.get_running_loop().create_future() self.waiters.append(future) try: return await future @@ -149,7 +147,7 @@ async def wait(self) -> List[Member]: self.waiters.remove(future) def get_future(self) -> asyncio.Future[List[Member]]: - future = self.loop.create_future() + future = asyncio.get_running_loop().create_future() self.waiters.append(future) return future @@ -193,7 +191,6 @@ def __init__( handlers: Dict[str, Callable], hooks: Dict[str, Callable], http: HTTPClient, - loop: asyncio.AbstractEventLoop, max_messages: Optional[int] = 1000, application_id: Optional[int] = None, heartbeat_timeout: float = 60.0, @@ -205,7 +202,6 @@ def __init__( chunk_guilds_at_startup: Optional[bool] = None, member_cache_flags: Optional[MemberCacheFlags] = None, ) -> None: - self.loop: asyncio.AbstractEventLoop = loop self.http: HTTPClient = http self.max_messages: Optional[int] = max_messages if self.max_messages is not None and self.max_messages <= 0: @@ -639,7 +635,7 @@ async def query_members( guild_id = guild.id ws = self._get_websocket(guild_id) - request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) + request = ChunkRequest(guild.id, self._get_guild, cache=cache) self._chunk_requests[request.nonce] = request try: @@ -1392,7 +1388,7 @@ async def chunk_guild( request = self._chunk_requests.get(guild.id) if request is None: self._chunk_requests[guild.id] = request = ChunkRequest( - guild.id, self.loop, self._get_guild, cache=cache + guild.id, self._get_guild, cache=cache ) await self.chunker(guild.id, nonce=request.nonce) @@ -2237,7 +2233,7 @@ async def _delay_ready(self) -> None: future = asyncio.ensure_future(self.chunk_guild(guild)) current_bucket.append(future) else: - future = self.loop.create_future() + future = asyncio.get_running_loop().create_future() future.set_result([]) processed.append((guild, future)) diff --git a/disnake/ui/modal.py b/disnake/ui/modal.py index adf21ffa9c..ef7e446e19 100644 --- a/disnake/ui/modal.py +++ b/disnake/ui/modal.py @@ -232,9 +232,8 @@ def __init__(self, state: ConnectionState) -> None: self._modals: Dict[Tuple[int, str], Modal] = {} def add_modal(self, user_id: int, modal: Modal) -> None: - loop = asyncio.get_running_loop() self._modals[(user_id, modal.custom_id)] = modal - loop.create_task(self.handle_timeout(user_id, modal.custom_id, modal.timeout)) + asyncio.create_task(self.handle_timeout(user_id, modal.custom_id, modal.timeout)) def remove_modal(self, user_id: int, modal_custom_id: str) -> Modal: return self._modals.pop((user_id, modal_custom_id)) diff --git a/disnake/ui/view.py b/disnake/ui/view.py index 71c2965074..68a6fc4840 100644 --- a/disnake/ui/view.py +++ b/disnake/ui/view.py @@ -178,12 +178,11 @@ def __init__(self, *, timeout: Optional[float] = 180.0) -> None: self.children.append(item) self.__weights = _ViewWeights(self.children) - loop = asyncio.get_running_loop() self.id: str = os.urandom(16).hex() self.__cancel_callback: Optional[Callable[[View], None]] = None self.__timeout_expiry: Optional[float] = None self.__timeout_task: Optional[asyncio.Task[None]] = None - self.__stopped: asyncio.Future[bool] = loop.create_future() + self.__stopped: asyncio.Future[bool] = asyncio.get_running_loop().create_future() def __repr__(self) -> str: return f"<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>" @@ -389,12 +388,11 @@ async def _scheduled_task(self, item: Item, interaction: MessageInteraction): def _start_listening_from_store(self, store: ViewStore) -> None: self.__cancel_callback = partial(store.remove_view) if self.timeout: - loop = asyncio.get_running_loop() if self.__timeout_task is not None: self.__timeout_task.cancel() self.__timeout_expiry = time.monotonic() + self.timeout - self.__timeout_task = loop.create_task(self.__timeout_task_impl()) + self.__timeout_task = asyncio.create_task(self.__timeout_task_impl()) def _dispatch_timeout(self) -> None: if self.__stopped.done(): diff --git a/disnake/voice_client.py b/disnake/voice_client.py index a6cc13e0ba..2bd50d11a5 100644 --- a/disnake/voice_client.py +++ b/disnake/voice_client.py @@ -14,6 +14,7 @@ - When that's all done, we receive opcode 4 from the vWS. - Finally we can transmit data to endpoint:port. """ + from __future__ import annotations import asyncio @@ -187,8 +188,6 @@ class VoiceClient(VoiceProtocol): The endpoint we are connecting to. channel: :class:`abc.Connectable` The voice channel connected to. - loop: :class:`asyncio.AbstractEventLoop` - The event loop that the voice client is running on. """ endpoint_ip: str @@ -206,7 +205,6 @@ def __init__(self, client: Client, channel: abc.Connectable) -> None: state = client._connection self.token: str = MISSING self.socket: socket.socket = MISSING - self.loop: asyncio.AbstractEventLoop = state.loop self._state: ConnectionState = state # this will be used in the AudioPlayer thread self._connected: threading.Event = threading.Event() @@ -376,7 +374,7 @@ async def connect(self, *, reconnect: bool, timeout: float) -> None: raise if self._runner is MISSING: - self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) + self._runner = asyncio.create_task(self.poll_voice_ws(reconnect)) async def potential_reconnect(self) -> bool: # Attempt to stop the player thread from playing early @@ -585,7 +583,7 @@ def play( if not self.encoder and not source.is_opus(): self.encoder = opus.Encoder() - self._player = AudioPlayer(source, self, after=after) + self._player = AudioPlayer(source, self, asyncio.get_running_loop(), after=after) self._player.start() def is_playing(self) -> bool: diff --git a/docs/ext/commands/cogs.rst b/docs/ext/commands/cogs.rst index 5a5f5389f0..024799704c 100644 --- a/docs/ext/commands/cogs.rst +++ b/docs/ext/commands/cogs.rst @@ -63,7 +63,7 @@ Once you have defined your cogs, you need to tell the bot to register the cogs t .. code-block:: python3 - bot.add_cog(Greetings(bot)) + await bot.add_cog(Greetings(bot)) This binds the cog to the bot, adding all commands and listeners to the bot automatically. @@ -71,7 +71,7 @@ Note that we reference the cog by name, which we can override through :ref:`ext_ .. code-block:: python3 - bot.remove_cog('Greetings') + await bot.remove_cog('Greetings') Using Cogs ---------- diff --git a/docs/ext/commands/extensions.rst b/docs/ext/commands/extensions.rst index d3db0436c5..d7699124ec 100644 --- a/docs/ext/commands/extensions.rst +++ b/docs/ext/commands/extensions.rst @@ -36,6 +36,11 @@ In this example we define a simple command, and when the extension is loaded thi Extensions are usually used in conjunction with cogs. To read more about them, check out the documentation, :ref:`ext_commands_cogs`. +.. admonition:: Async + :class: helpful + + ``setup`` (as well as ``teardown``, see below) can be ``async`` too! + .. note:: Extension paths are ultimately similar to the import mechanism. What this means is that if there is a folder, then it must be dot-qualified. For example to load an extension in ``plugins/hello.py`` then we use the string ``plugins.hello``. @@ -47,7 +52,7 @@ When you make a change to the extension and want to reload the references, the l .. code-block:: python3 - >>> bot.reload_extension('hello') + >>> await bot.reload_extension('hello') Once the extension reloads, any changes that we did will be applied. This is useful if we want to add or remove functionality without restarting our bot. If an error occurred during the reloading process, the bot will pretend as if the reload never happened. diff --git a/examples/basic_voice.py b/examples/basic_voice.py index 45046c780f..126312a340 100644 --- a/examples/basic_voice.py +++ b/examples/basic_voice.py @@ -9,7 +9,7 @@ import asyncio import os -from typing import Any, Dict, Optional +from typing import Any, Dict import disnake import youtube_dl # type: ignore @@ -43,11 +43,8 @@ def __init__(self, source: disnake.AudioSource, *, data: Dict[str, Any], volume: self.title = data.get("title") @classmethod - async def from_url( - cls, url, *, loop: Optional[asyncio.AbstractEventLoop] = None, stream: bool = False - ): - loop = loop or asyncio.get_event_loop() - data: Any = await loop.run_in_executor( + async def from_url(cls, url, *, stream: bool = False): + data: Any = await asyncio.get_running_loop().run_in_executor( None, lambda: ytdl.extract_info(url, download=not stream) ) @@ -94,7 +91,7 @@ async def stream(self, ctx, *, url: str): async def _play_url(self, ctx, *, url: str, stream: bool): await self.ensure_voice(ctx) async with ctx.typing(): - player = await YTDLSource.from_url(url, loop=self.bot.loop, stream=stream) + player = await YTDLSource.from_url(url, stream=stream) ctx.voice_client.play( player, after=lambda e: print(f"Player error: {e}") if e else None ) @@ -137,7 +134,11 @@ async def on_ready(): print(f"Logged in as {bot.user} (ID: {bot.user.id})\n------") -bot.add_cog(Music(bot)) +async def setup_hook(): + await bot.add_cog(Music(bot)) + + +bot.setup_hook = setup_hook if __name__ == "__main__": bot.run(os.getenv("BOT_TOKEN")) diff --git a/examples/interactions/subcmd.py b/examples/interactions/subcmd.py index ef9da28c38..119dffa6fd 100644 --- a/examples/interactions/subcmd.py +++ b/examples/interactions/subcmd.py @@ -71,7 +71,11 @@ async def on_ready(): print(f"Logged in as {bot.user} (ID: {bot.user.id})\n------") -bot.add_cog(MyCog()) +async def setup_hook(): + await bot.add_cog(MyCog()) + + +bot.setup_hook = setup_hook if __name__ == "__main__": bot.run(os.getenv("BOT_TOKEN")) diff --git a/pyproject.toml b/pyproject.toml index a184953fdd..c6d79ca6d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,15 +8,11 @@ build-backend = "setuptools.build_meta" name = "disnake" description = "A Python wrapper for the Discord API" readme = "README.md" -authors = [ - { name = "Disnake Development" } -] +authors = [{ name = "Disnake Development" }] requires-python = ">=3.8" keywords = ["disnake", "discord", "discord api"] license = { text = "MIT" } -dependencies = [ - "aiohttp>=3.7.0,<4.0", -] +dependencies = ["aiohttp>=3.7.0,<4.0"] classifiers = [ "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: MIT License", @@ -44,14 +40,11 @@ Repository = "https://github.com/DisnakeDev/disnake" [project.optional-dependencies] speed = [ "orjson~=3.6", - # taken from aiohttp[speedups] - "aiodns>=1.1", + # taken from aiohttp[speedups] "aiodns>=1.1", "Brotli", 'cchardet; python_version < "3.10"', ] -voice = [ - "PyNaCl>=1.3.0,<1.6", -] +voice = ["PyNaCl>=1.3.0,<1.6"] docs = [ "sphinx==7.0.1", "sphinxcontrib-trio~=1.1.2", @@ -64,9 +57,7 @@ docs = [ discord = ["discord-disnake"] [tool.pdm.dev-dependencies] -nox = [ - "nox==2022.11.21", -] +nox = ["nox==2022.11.21"] tools = [ "pre-commit~=3.0", "slotscheck~=0.16.4", @@ -74,9 +65,7 @@ tools = [ "check-manifest==0.49", "ruff==0.3.4", ] -changelog = [ - "towncrier==23.6.0", -] +changelog = ["towncrier==23.6.0"] codemod = [ # run codemods on the respository (mostly automated typing) "libcst~=1.1.0", @@ -97,11 +86,7 @@ test = [ "looptime~=0.2", "coverage[toml]~=6.5.0", ] -build = [ - "wheel~=0.40.0", - "build~=0.10.0", - "twine~=4.0.2", -] +build = ["wheel~=0.40.0", "build~=0.10.0", "twine~=4.0.2"] [tool.pdm.scripts] black = { composite = ["lint black"], help = "Run black" } @@ -109,7 +94,10 @@ docs = { cmd = "nox -Rs docs --", help = "Build the documentation for developmen lint = { cmd = "nox -Rs lint --", help = "Check all files for linting errors" } pyright = { cmd = "nox -Rs pyright --", help = "Run pyright" } setup_env = { cmd = "pdm install -d -G speed -G docs -G voice", help = "Set up the local environment and all dependencies" } -post_setup_env = { composite = ["python -m ensurepip --default-pip", "pre-commit install --install-hooks"] } +post_setup_env = { composite = [ + "python -m ensurepip --default-pip", + "pre-commit install --install-hooks", +] } test = { cmd = "nox -Rs test --", help = "Run pytest" } # legacy tasks for those who still type `task` @@ -136,14 +124,15 @@ target-version = "py38" select = [ # commented out codes are intended to be enabled in future prs "F", # pyflakes - "E", "W", # pycodestyle + "E", + "W", # pycodestyle # "D", # pydocstyle - "D2", # pydocstyle, docstring formatting - "D4", # pydocstyle, docstring structure/content + "D2", # pydocstyle, docstring formatting + "D4", # pydocstyle, docstring structure/content # "ANN", # flake8-annotations - "S", # flake8-bandit - "B", # flake8-bugbear - "C4", # flake8-comprehensions + "S", # flake8-bandit + "B", # flake8-bugbear + "C4", # flake8-comprehensions "DTZ", # flake8-datetimez # "EM", # flake8-errmsg "G", # flake8-logging-format @@ -161,8 +150,10 @@ select = [ "PLE", # pylint error # "PLR", # pylint refactor "PLW", # pylint warnings - "TRY002", "TRY004", "TRY201", # tryceratops - "I", # isort + "TRY002", + "TRY004", + "TRY201", # tryceratops + "I", # isort ] ignore = [ # star imports @@ -211,14 +202,14 @@ ignore = [ "PLE0237", # pyright seems to catch this already # temporary disables, to fix later - "D205", # blank line required between summary and description - "D401", # first line of docstring should be in imperative mood - "D417", # missing argument description in docstring - "B904", # within an except clause raise from error or from none - "B026", # backwards star-arg unpacking - "E501", # line too long - "E731", # assigning lambdas to variables - "T201", # print statements + "D205", # blank line required between summary and description + "D401", # first line of docstring should be in imperative mood + "D417", # missing argument description in docstring + "B904", # within an except clause raise from error or from none + "B026", # backwards star-arg unpacking + "E501", # line too long + "E731", # assigning lambdas to variables + "T201", # print statements ] [tool.ruff.lint.per-file-ignores] @@ -269,35 +260,35 @@ title_format = false underlines = "-~" issue_format = ":issue:`{issue}`" - [[tool.towncrier.type]] - directory = "breaking" - name = "Breaking Changes" - showcontent = true +[[tool.towncrier.type]] +directory = "breaking" +name = "Breaking Changes" +showcontent = true - [[tool.towncrier.type]] - directory = "deprecate" - name = "Deprecations" - showcontent = true +[[tool.towncrier.type]] +directory = "deprecate" +name = "Deprecations" +showcontent = true - [[tool.towncrier.type]] - directory = "feature" - name = "New Features" - showcontent = true +[[tool.towncrier.type]] +directory = "feature" +name = "New Features" +showcontent = true - [[tool.towncrier.type]] - directory = "bugfix" - name = "Bug Fixes" - showcontent = true +[[tool.towncrier.type]] +directory = "bugfix" +name = "Bug Fixes" +showcontent = true - [[tool.towncrier.type]] - directory = "doc" - name = "Documentation" - showcontent = true +[[tool.towncrier.type]] +directory = "doc" +name = "Documentation" +showcontent = true - [[tool.towncrier.type]] - directory = "misc" - name = "Miscellaneous" - showcontent = true +[[tool.towncrier.type]] +directory = "misc" +name = "Miscellaneous" +showcontent = true [tool.slotscheck] @@ -313,17 +304,8 @@ exclude-modules = ''' [tool.pyright] typeCheckingMode = "strict" -include = [ - "disnake", - "docs", - "examples", - "test_bot", - "tests", - "*.py", -] -ignore = [ - "disnake/ext/mypy_plugin", -] +include = ["disnake", "docs", "examples", "test_bot", "tests", "*.py"] +ignore = ["disnake/ext/mypy_plugin"] # this is one of the diagnostics that aren't enabled by default, even in strict mode reportUnnecessaryTypeIgnoreComment = true @@ -358,15 +340,8 @@ asyncio_mode = "strict" [tool.coverage.run] branch = true -include = [ - "disnake/*", - "tests/*", -] -omit = [ - "disnake/ext/mypy_plugin/*", - "disnake/types/*", - "disnake/__main__.py", -] +include = ["disnake/*", "tests/*"] +omit = ["disnake/ext/mypy_plugin/*", "disnake/types/*", "disnake/__main__.py"] [tool.coverage.report] precision = 1 diff --git a/test_bot/__main__.py b/test_bot/__main__.py index 37c5afa288..a58ea00739 100644 --- a/test_bot/__main__.py +++ b/test_bot/__main__.py @@ -52,9 +52,12 @@ async def on_ready(self) -> None: ) # fmt: on - def add_cog(self, cog: commands.Cog, *, override: bool = False) -> None: + async def setup_hook(self) -> None: + await self.load_extensions(os.path.join(__package__, Config.cogs_folder)) + + async def add_cog(self, cog: commands.Cog, *, override: bool = False) -> None: logger.info("Loading cog %s", cog.qualified_name) - return super().add_cog(cog, override=override) + return await super().add_cog(cog, override=override) async def _handle_error( self, ctx: Union[commands.Context, disnake.AppCommandInter], error: Exception, prefix: str @@ -98,5 +101,4 @@ async def on_message_command_error( if __name__ == "__main__": bot = TestBot() - bot.load_extensions(os.path.join(__package__, Config.cogs_folder)) bot.run(Config.token) diff --git a/test_bot/cogs/events.py b/test_bot/cogs/events.py index 50c4833b6d..33f1ac7dbd 100644 --- a/test_bot/cogs/events.py +++ b/test_bot/cogs/events.py @@ -28,5 +28,5 @@ async def on_guild_scheduled_event_unsubscribe(self, event, user) -> None: print("Scheduled event unsubscribe", event, user, sep="\n", end="\n\n") -def setup(bot) -> None: - bot.add_cog(EventListeners(bot)) +async def setup(bot) -> None: + await bot.add_cog(EventListeners(bot)) diff --git a/test_bot/cogs/guild_scheduled_events.py b/test_bot/cogs/guild_scheduled_events.py index 1ffcb92295..703a94bdd1 100644 --- a/test_bot/cogs/guild_scheduled_events.py +++ b/test_bot/cogs/guild_scheduled_events.py @@ -53,5 +53,5 @@ async def create_event( await inter.response.send_message(str(gse.image)) -def setup(bot: commands.Bot) -> None: - bot.add_cog(GuildScheduledEvents(bot)) +async def setup(bot: commands.Bot) -> None: + await bot.add_cog(GuildScheduledEvents(bot)) diff --git a/test_bot/cogs/injections.py b/test_bot/cogs/injections.py index 192ca10137..9b75adf25e 100644 --- a/test_bot/cogs/injections.py +++ b/test_bot/cogs/injections.py @@ -123,5 +123,5 @@ async def discerned_injections( await inter.response.send_message(f"```py\n{pformat(locals())}\n```") -def setup(bot) -> None: - bot.add_cog(InjectionSlashCommands(bot)) +async def setup(bot) -> None: + await bot.add_cog(InjectionSlashCommands(bot)) diff --git a/test_bot/cogs/localization.py b/test_bot/cogs/localization.py index 9bba8d1495..dae2ec2c20 100644 --- a/test_bot/cogs/localization.py +++ b/test_bot/cogs/localization.py @@ -79,5 +79,5 @@ async def cmd_msg(self, inter: disnake.AppCmdInter[commands.Bot], msg: disnake.M await inter.response.send_message(msg.content[::-1]) -def setup(bot) -> None: - bot.add_cog(Localizations(bot)) +async def setup(bot) -> None: + await bot.add_cog(Localizations(bot)) diff --git a/test_bot/cogs/message_commands.py b/test_bot/cogs/message_commands.py index d4101b8e41..a3f4e914fe 100644 --- a/test_bot/cogs/message_commands.py +++ b/test_bot/cogs/message_commands.py @@ -13,5 +13,5 @@ async def reverse(self, inter: disnake.MessageCommandInteraction[commands.Bot]) await inter.response.send_message(inter.target.content[::-1]) -def setup(bot) -> None: - bot.add_cog(MessageCommands(bot)) +async def setup(bot) -> None: + await bot.add_cog(MessageCommands(bot)) diff --git a/test_bot/cogs/misc.py b/test_bot/cogs/misc.py index c10081e967..a5b5ef2a46 100644 --- a/test_bot/cogs/misc.py +++ b/test_bot/cogs/misc.py @@ -50,5 +50,5 @@ async def attachment_desc_edit( await inter.response.send_message(".", view=view) -def setup(bot) -> None: - bot.add_cog(Misc(bot)) +async def setup(bot) -> None: + await bot.add_cog(Misc(bot)) diff --git a/test_bot/cogs/modals.py b/test_bot/cogs/modals.py index c5d514a25c..998fedb057 100644 --- a/test_bot/cogs/modals.py +++ b/test_bot/cogs/modals.py @@ -74,5 +74,5 @@ async def create_tag_low(self, inter: disnake.AppCmdInter[commands.Bot]) -> None await modal_inter.response.send_message(embed=embed) -def setup(bot: commands.Bot) -> None: - bot.add_cog(Modals(bot)) +async def setup(bot: commands.Bot) -> None: + await bot.add_cog(Modals(bot)) diff --git a/test_bot/cogs/slash_commands.py b/test_bot/cogs/slash_commands.py index e7a2437d56..60b546f3fd 100644 --- a/test_bot/cogs/slash_commands.py +++ b/test_bot/cogs/slash_commands.py @@ -73,5 +73,5 @@ async def largenumber( await inter.send(f"Is int: {isinstance(largenum, int)}") -def setup(bot) -> None: - bot.add_cog(SlashCommands(bot)) +async def setup(bot) -> None: + await bot.add_cog(SlashCommands(bot)) diff --git a/test_bot/cogs/user_commands.py b/test_bot/cogs/user_commands.py index e8c67efdca..cd588f8f40 100644 --- a/test_bot/cogs/user_commands.py +++ b/test_bot/cogs/user_commands.py @@ -15,5 +15,5 @@ async def avatar( await inter.response.send_message(user.display_avatar.url, ephemeral=True) -def setup(bot) -> None: - bot.add_cog(UserCommands(bot)) +async def setup(bot) -> None: + await bot.add_cog(UserCommands(bot)) diff --git a/tests/ext/tasks/test_loops.py b/tests/ext/tasks/test_loops.py index 796b16f0c5..c5103db92c 100644 --- a/tests/ext/tasks/test_loops.py +++ b/tests/ext/tasks/test_loops.py @@ -49,7 +49,6 @@ def clone(self): instance._time = self._time instance.count = self.count instance.reconnect = self.reconnect - instance.loop = self.loop instance._before_loop = self._before_loop instance._after_loop = self._after_loop instance._error = self._error diff --git a/tests/test_events.py b/tests/test_events.py index 15cc467151..4c558e8b31 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -43,8 +43,9 @@ async def on_message_edit(self, *args: Any) -> None: # Client.wait_for +@pytest.mark.asyncio @pytest.mark.parametrize("event", ["thread_create", Event.thread_create]) -def test_wait_for(client_or_bot: disnake.Client, event) -> None: +async def test_wait_for(client_or_bot: disnake.Client, event) -> None: coro = client_or_bot.wait_for(event) assert len(client_or_bot._listeners["thread_create"]) == 1 coro.close() # close coroutine to avoid warning @@ -95,22 +96,24 @@ async def on_guild_role_create(self, *args: Any) -> None: # @commands.Cog.listener +@pytest.mark.asyncio @pytest.mark.parametrize("event", ["on_automod_rule_update", Event.automod_rule_update]) -def test_listener(bot: commands.Bot, event) -> None: +async def test_listener(bot: commands.Bot, event) -> None: class Cog(commands.Cog): @commands.Cog.listener(event) async def callback(self, *args: Any) -> None: ... - bot.add_cog(Cog()) + await bot.add_cog(Cog()) assert len(bot.extra_events["on_automod_rule_update"]) == 1 -def test_listener__implicit(bot: commands.Bot) -> None: +@pytest.mark.asyncio +async def test_listener__implicit(bot: commands.Bot) -> None: class Cog(commands.Cog): @commands.Cog.listener() async def on_automod_rule_update(self, *args: Any) -> None: ... - bot.add_cog(Cog()) + await bot.add_cog(Cog()) assert len(bot.extra_events["on_automod_rule_update"]) == 1