Source code for lightbulb.loaders

# -*- coding: utf-8 -*-
# Copyright (c) 2023-present tandemdude
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import annotations

__all__ = ["Loadable", "Loader"]

import abc
import functools
import logging
import typing as t

import hikari
import linkd

from lightbulb import di
from lightbulb import tasks
from lightbulb.commands import groups

if t.TYPE_CHECKING:
    from collections.abc import Awaitable
    from collections.abc import Callable
    from collections.abc import Sequence

    from lightbulb import client as client_
    from lightbulb.internal import types

# Type variables used by the Loader decorators so that the decorated object is
# returned with exactly the type it was passed in with.
# The first two bounds are string forward references because ``types`` is only
# imported under ``t.TYPE_CHECKING``.
CommandOrGroupT = t.TypeVar("CommandOrGroupT", bound="types.CommandOrGroup")
ErrorHandlerT = t.TypeVar("ErrorHandlerT", bound="types.ErrorHandler")
# Event class accepted by listener callbacks; restricted to hikari events.
EventT = t.TypeVar("EventT", bound=hikari.Event)

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)


class Loadable(abc.ABC):
    """Interface for a feature that can be attached to and detached from a client instance."""

    __slots__ = ()

    @abc.abstractmethod
    async def load(self, client: client_.Client) -> None:
        """
        Attach this feature to the given client.

        Args:
            client: The client instance that the feature should be added to.

        Returns:
            :obj:`None`

        Warning:
            Implementations **must** be idempotent. If the feature is already loaded
            then calling this method again **must not** load it a second time.
        """

    @abc.abstractmethod
    async def unload(self, client: client_.Client) -> None:
        """
        Detach this feature from the given client.

        Args:
            client: The client instance that the feature should be removed from.

        Returns:
            :obj:`None`

        Warning:
            Implementations **must** be idempotent. If the feature is already unloaded
            then calling this method again **must not** unload it a second time.
        """
class _CommandLoadable(Loadable):
    """Loadable that registers a command or group with the client and unregisters it again."""

    # TODO - check this is correctly idempotent
    __slots__ = ("_command", "_defer_guilds", "_global", "_guilds")

    def __init__(
        self,
        command: types.CommandOrGroup,
        guilds: Sequence[hikari.Snowflakeish] | None,
        global_: bool | None,
        defer_guilds: bool,
    ) -> None:
        """
        Args:
            command: The command class or group to register.
            guilds: Guilds passed through to ``client.register`` when not deferring.
            global_: Global flag passed through to ``client.register`` when not deferring.
            defer_guilds: If true, register with ``defer_guilds=True`` and ignore
                ``guilds`` / ``global_``.
        """
        self._command = command
        self._guilds = guilds
        self._global = global_
        self._defer_guilds = defer_guilds

    async def load(self, client: client_.Client) -> None:
        """Tag the command with the extension currently being loaded, then register it."""
        if isinstance(self._command, groups.Group):
            # Bypasses any custom __setattr__ on the group
            # NOTE(review): presumably Group instances are frozen - confirm.
            object.__setattr__(self._command, "extension", client._current_extension_being_loaded)
        else:
            self._command._command_data.extension = client._current_extension_being_loaded

        if self._defer_guilds:
            client.register(self._command, defer_guilds=True)
        else:
            client.register(self._command, guilds=self._guilds, global_=self._global)

    async def unload(self, client: client_.Client) -> None:
        """Unregister the command or group from the client."""
        client.unregister(self._command)


class _ListenerLoadable(Loadable):
    """Loadable that subscribes a callback to one or more hikari event types."""

    __slots__ = ("_callback", "_event_types", "_wrapped_callback")

    def __init__(self, callback: Callable[[EventT], Awaitable[None]], *event_types: type[EventT]) -> None:
        """
        Args:
            callback: The listener callback to subscribe.
            *event_types: The event classes the callback should be subscribed to.
        """
        self._callback = callback
        self._event_types = event_types
        # The DI-context wrapper that was actually subscribed, if any. It has to be
        # remembered so that the very same object can be found and unsubscribed later.
        self._wrapped_callback: Callable[..., Awaitable[t.Any]] | None = None

    async def load(self, client: client_.Client) -> None:
        """
        Subscribe the callback to each event type on the client's event manager.

        If the client's ``_app`` does not expose a :obj:`hikari.api.EventManager` as
        ``event_manager``, a warning is logged and nothing is subscribed.

        Calling this again while the listener is already subscribed is a no-op.
        """
        em = getattr(getattr(client, "_app", None), "event_manager", None)
        if not isinstance(em, hikari.api.EventManager):
            LOGGER.warning("skipping loading listener - bot is not event manager aware")
            return

        # Only build a new wrapper when the stored one is not already subscribed on this
        # event manager. Building a fresh wrapper on every call would defeat the
        # membership check below (a new function object is never "already subscribed"),
        # causing duplicate subscriptions and orphaning the previous wrapper.
        current = self._wrapped_callback
        already_subscribed = current is not None and any(
            current in em.get_listeners(event) for event in self._event_types
        )
        if not already_subscribed:

            @functools.wraps(self._callback)
            async def _wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
                async with client.di.enter_context(di.Contexts.DEFAULT), client.di.enter_context(di.Contexts.LISTENER):
                    return await self._callback(*args, **kwargs)

            # Without DI the raw callback is subscribed directly.
            self._wrapped_callback = _wrapped if linkd.DI_ENABLED else None

        target = self._wrapped_callback or self._callback
        for event in self._event_types:
            if target in em.get_listeners(event):
                continue
            em.subscribe(event, target)  # type: ignore[reportArgumentType]

    async def unload(self, client: client_.Client) -> None:
        """
        Unsubscribe the callback from each event type it is currently subscribed to.

        Does nothing if the client's ``_app`` does not expose a
        :obj:`hikari.api.EventManager` as ``event_manager``.
        """
        em = getattr(getattr(client, "_app", None), "event_manager", None)
        if not isinstance(em, hikari.api.EventManager):
            return

        target = self._wrapped_callback or self._callback
        for event in self._event_types:
            if target not in em.get_listeners(event):
                continue
            em.unsubscribe(event, target)  # type: ignore[reportArgumentType]


class _ErrorHandlerLoadable(Loadable):
    """Loadable that registers an error handler with the client at a given priority."""

    __slots__ = ("_callback", "_priority")

    def __init__(self, callback: types.ErrorHandler, priority: int) -> None:
        """
        Args:
            callback: The error handler function.
            priority: The priority to register the handler at.
        """
        self._callback = callback
        self._priority = priority

    async def load(self, client: client_.Client) -> None:
        """Register the handler unless it is already registered at this priority."""
        if self._callback in client._error_handlers.get(self._priority, []):
            return
        client.error_handler(self._callback, priority=self._priority)

    async def unload(self, client: client_.Client) -> None:
        """Remove the handler if it is currently registered at this priority."""
        if self._callback not in client._error_handlers.get(self._priority, []):
            return
        client.remove_error_handler(self._callback)


class _TaskLoadable(Loadable):
    """Loadable that adds a repeating task to the client."""

    __slots__ = ("_task",)

    def __init__(self, task: tasks.Task) -> None:
        """
        Args:
            task: The task to add to the client.
        """
        self._task = task

    async def load(self, client: client_.Client) -> None:
        """Add the task to the client unless the client already has it."""
        if self._task in client._tasks:
            return
        # Keep whatever ``client.task`` returns so that unload refers to the same object.
        self._task = client.task(self._task)

    async def unload(self, client: client_.Client) -> None:
        """Remove the task from the client, cancelling it, if the client has it."""
        if self._task not in client._tasks:
            return
        client.remove_task(self._task, cancel=True)
class Loader:
    """
    Class used for loading features into the client from extensions.

    Args:
        should_load_hook: Synchronous or asynchronous function which will be called when the loader
            is added to the client. Returns a boolean indicating whether this loader should be loaded
            or not. If it returns :obj:`False`, the loader **will not** be loaded and none of its
            features will be added to the client.
    """

    __slots__ = ("_loadables", "_should_load_hook")

    def __init__(self, should_load_hook: Callable[[], types.MaybeAwaitable[bool]] = lambda: True) -> None:
        self._should_load_hook = should_load_hook
        # Loadables in the order they were added; loading and unloading both follow this order.
        self._loadables: list[Loadable] = []

    async def add_to_client(self, client: client_.Client) -> None:
        """
        Add the features contained within this loader to the given client.

        Args:
            client: The client to add this loader's features to.

        Returns:
            :obj:`None`
        """
        for loadable in self._loadables:
            await loadable.load(client)

    async def remove_from_client(self, client: client_.Client) -> None:
        """
        Remove the features contained within this loader from the given client. If any single
        loadable's unload method raises an exception then the remaining loadables will still be unloaded.

        Args:
            client: The client to remove this loader's features from.

        Returns:
            :obj:`None`
        """
        for loadable in self._loadables:
            try:
                await loadable.unload(client)
            except Exception as e:
                # Deliberately best-effort: log the failure and keep unloading the rest.
                LOGGER.warning("error while unloading loadable %r", loadable, exc_info=e)

    def add(self, loadable: Loadable) -> Loadable:
        """
        Add the given loadable to this loader.

        Args:
            loadable: The loadable to add.

        Returns:
            The added loadable, unchanged.
        """
        self._loadables.append(loadable)
        return loadable

    @t.overload
    def command(
        self, *, guilds: Sequence[hikari.Snowflakeish] | None = None, global_: bool | None = None
    ) -> Callable[[CommandOrGroupT], CommandOrGroupT]: ...
    @t.overload
    def command(self, *, defer_guilds: t.Literal[True]) -> Callable[[CommandOrGroupT], CommandOrGroupT]: ...
    @t.overload
    def command(
        self,
        command: CommandOrGroupT,
        *,
        guilds: Sequence[hikari.Snowflakeish] | None = None,
        global_: bool | None = None,
    ) -> CommandOrGroupT: ...
    @t.overload
    def command(self, command: CommandOrGroupT, *, defer_guilds: t.Literal[True]) -> CommandOrGroupT: ...

    def command(
        self,
        command: CommandOrGroupT | None = None,
        *,
        guilds: Sequence[hikari.Snowflakeish] | None = None,
        global_: bool | None = None,
        defer_guilds: bool = False,
    ) -> CommandOrGroupT | Callable[[CommandOrGroupT], CommandOrGroupT]:
        """
        Register a command or group with this loader. This method can be used as a
        function, or a first or second order decorator.

        Args:
            command: The command class or command group to register with the client.
            guilds: The guilds to create the command or group in.
            global_: Whether the command should be registered globally.
            defer_guilds: Whether the guilds to create this command in should be resolved when the client is
                started. If :obj:`True`, the client's ``deferred_registration_callback`` will be used to resolve
                which guilds to create the command in. You can also use this to conditionally prevent the command
                from being registered to any guilds.

        Returns:
            The registered command or group, unchanged.

        Example:

            .. code-block:: python

                loader = lightbulb.Loader()

                # valid
                @loader.command
                # also valid
                @loader.command(guilds=[...])
                class Example(
                    lightbulb.SlashCommand,
                    ...
                ):
                    ...

                # also valid
                loader.command(Example, guilds=[...])

        See Also:
            :meth:`~lightbulb.client.Client.register`
        """
        # Used as a function or first-order decorator
        if command is not None:
            self.add(_CommandLoadable(command, guilds, global_, defer_guilds))
            return command

        # Used as a second-order decorator
        def _inner(command_: CommandOrGroupT) -> CommandOrGroupT:
            # Forward every option given to the outer call, not only ``guilds``, so that
            # ``@loader.command(global_=True)`` and ``@loader.command(defer_guilds=True)``
            # take effect.
            self.add(_CommandLoadable(command_, guilds, global_, defer_guilds))
            return command_

        return _inner

    def listener(
        self, *event_types: type[EventT]
    ) -> Callable[[Callable["t.Concatenate[EventT, ...]", Awaitable[None]]], Callable[[EventT], Awaitable[None]]]:
        """
        Decorator to register a listener with this loader. Also enables dependency injection on the listener
        callback.

        If the client's application does not provide an :obj:`hikari.api.event_manager.EventManager` when this
        loader is added to the client, a warning is logged and the listener is not subscribed.

        Args:
            *event_types: The event class(es) for the listener to listen to.

        Raises:
            :obj:`ValueError`: If no event types are given.

        Example:

            .. code-block:: python

                loader = lightbulb.Loader()

                @loader.listener(hikari.MessageCreateEvent)
                async def message_create_listener(event: hikari.MessageCreateEvent) -> None:
                    ...
        """
        if not event_types:
            raise ValueError("you must specify at least one event type")

        def _inner(
            callback: Callable["t.Concatenate[EventT, ...]", Awaitable[None]],
        ) -> Callable[[EventT], Awaitable[None]]:
            # The DI-enabled callable is both stored in the loadable and handed back to the caller.
            wrapped = t.cast("Callable[[EventT], Awaitable[None]]", di.with_di(callback))
            self.add(_ListenerLoadable(wrapped, *event_types))
            return wrapped

        return _inner

    @t.overload
    def error_handler(self, *, priority: int = 0) -> Callable[[ErrorHandlerT], ErrorHandlerT]: ...
    @t.overload
    def error_handler(self, func: ErrorHandlerT, *, priority: int = 0) -> ErrorHandlerT: ...

    def error_handler(
        self, func: ErrorHandlerT | None = None, *, priority: int = 0
    ) -> ErrorHandlerT | Callable[[ErrorHandlerT], ErrorHandlerT]:
        """
        Register an error handler function to call when an :obj:`~lightbulb.commands.execution.ExecutionPipeline`
        fails. Also enables dependency injection for the error handler function.

        The function must take the exception as its first argument, which will be an instance of
        :obj:`~lightbulb.exceptions.ExecutionPipelineFailedException`. The function **must** return a boolean
        indicating whether the exception was successfully handled. Non-boolean return values will be cast to
        booleans.

        Args:
            func: The function to register as a command error handler.
            priority: The priority that this handler should be registered at. Higher priority handlers
                will be executed first.

        Returns:
            The DI-enabled handler when ``func`` is given, otherwise a decorator that registers
            the function it is applied to.
        """
        if func is not None:
            wrapped = di.with_di(func)
            self.add(_ErrorHandlerLoadable(wrapped, priority))  # type: ignore[reportArgumentType]
            return t.cast("ErrorHandlerT", wrapped)

        # Used as a second-order decorator
        def _inner(func_: ErrorHandlerT) -> ErrorHandlerT:
            return self.error_handler(func_, priority=priority)

        return _inner

    def task(
        self, trigger: tasks.Trigger, /, auto_start: bool = True, max_failures: int = 1, max_invocations: int = -1
    ) -> Callable[[tasks.TaskFunc], tasks.Task]:
        """
        Second order decorator to register a repeating task with this loader. Task functions will have
        dependency injection enabled on them automatically. Task functions **must** be asynchronous.

        Args:
            trigger: The trigger function to use to resolve the interval between task executions.
            auto_start: Whether the task should be started automatically. This means that if the task is added to
                the client upon the client being started, the task will also be started; it will also be started
                if being added to an already-started client.
            max_failures: The maximum number of failed attempts to execute the task before it is cancelled.
                Setting this to a negative number will prevent the task from being cancelled, regardless of
                how often the task fails.
            max_invocations: The maximum number of times the task can be invoked before being stopped. Setting
                this to a negative number will disable this behaviour, allowing unlimited invocations.

        Returns:
            A decorator that wraps the function in a :obj:`~lightbulb.tasks.Task`, adds it to this
            loader and returns the task.

        Example:

            .. code-block:: python

                loader = lightbulb.Loader()

                @loader.task(lightbulb.uniformtrigger(minutes=1))
                async def print_hi() -> None:
                    print("HI")
        """

        def _inner(func: tasks.TaskFunc) -> tasks.Task:
            task_obj = tasks.Task(func, trigger, auto_start, max_failures, max_invocations)
            self.add(_TaskLoadable(task_obj))
            return task_obj

        return _inner