# -*- coding: utf-8 -*-
# Copyright (c) 2023-present tandemdude
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import annotations
__all__ = ["sync_application_commands"]
import collections
import dataclasses
import inspect
import logging
import typing as t
import hikari
from lightbulb.commands import commands
from lightbulb.commands import groups
from lightbulb.internal import constants
from lightbulb.internal.utils import non_undefined_or
if t.TYPE_CHECKING:
from collections.abc import Sequence
from lightbulb import client as client_
LOGGER = logging.getLogger(__name__)
@dataclasses.dataclass(slots=True)
class _CommandBuilderCollection:
slash: hikari.api.SlashCommandBuilder | None = None
user: hikari.api.ContextMenuCommandBuilder | None = None
message: hikari.api.ContextMenuCommandBuilder | None = None
def put(self, bld: hikari.api.CommandBuilder) -> None:
if isinstance(bld, hikari.api.SlashCommandBuilder):
self.slash = bld
elif isinstance(bld, hikari.api.ContextMenuCommandBuilder):
if bld.type is hikari.CommandType.USER:
self.user = bld
else:
self.message = bld
else:
raise TypeError("unrecognised builder type")
def _hikari_command_to_builder(
cmd: hikari.PartialCommand, default_integration_types: list[hikari.ApplicationIntegrationType]
) -> hikari.api.SlashCommandBuilder | hikari.api.ContextMenuCommandBuilder:
bld: hikari.api.SlashCommandBuilder | hikari.api.ContextMenuCommandBuilder
if desc := getattr(cmd, "description", None):
bld = hikari.impl.SlashCommandBuilder(cmd.name, description=desc)
for option in getattr(cmd, "options", []) or []:
bld.add_option(option)
else:
bld = hikari.impl.ContextMenuCommandBuilder(type=cmd.type, name=cmd.name)
if cmd.guild_id is None:
bld = bld.set_integration_types(cmd.integration_types or default_integration_types).set_context_types(
cmd.context_types or list(hikari.ApplicationContextType)
)
return (
bld.set_default_member_permissions(cmd.default_member_permissions)
.set_is_nsfw(cmd.is_nsfw)
.set_name_localizations(cmd.name_localizations)
.set_id(cmd.id)
)
async def _get_existing_and_registered_commands(
client: client_.Client,
application: hikari.Application,
guild: hikari.UndefinedOr[hikari.Snowflakeish],
default_integration_types: list[hikari.ApplicationIntegrationType],
) -> tuple[dict[str, _CommandBuilderCollection], dict[str, _CommandBuilderCollection]]:
existing: dict[str, _CommandBuilderCollection] = collections.defaultdict(_CommandBuilderCollection)
registered: dict[str, _CommandBuilderCollection] = collections.defaultdict(_CommandBuilderCollection)
existing_commands = await client.rest.fetch_application_commands(application, guild=guild)
client._created_commands[guild or constants.GLOBAL_COMMAND_KEY] = existing_commands
for existing_command in existing_commands:
existing[existing_command.name].put(
_hikari_command_to_builder(existing_command, list(application.integration_types_config.keys()))
)
for collection in client._command_invocation_mapping.get(
constants.GLOBAL_COMMAND_KEY if guild is hikari.UNDEFINED else guild, {}
).values():
for item in [collection.slash, collection.user, collection.message]:
if item is None:
continue
command_data = item._command_data
# Get the parent command, luckily groups can only go two levels deep
root = getattr(command_data.parent, "parent", command_data.parent) or item
assert isinstance(root, groups.Group) or (inspect.isclass(root) and issubclass(root, commands.CommandBase))
builder = await root.as_command_builder(client.default_locale, client.localization_provider)
if guild is hikari.UNDEFINED:
builder = builder.set_integration_types(
builder.integration_types or default_integration_types
).set_context_types(builder.context_types or list(hikari.ApplicationContextType))
registered[builder.name].put(builder)
return existing, registered
def _serialize_builder(bld: hikari.api.CommandBuilder) -> dict[str, t.Any]:
def serialize_option(opt: hikari.CommandOption) -> dict[str, t.Any]:
return {
"type": opt.type,
"name": opt.name,
"description": opt.description,
"is_required": opt.is_required,
"choices": opt.choices or [],
"options": [serialize_option(o) for o in (opt.options or [])],
"channel_types": opt.channel_types or [],
"autocomplete": opt.autocomplete,
"min_value": opt.min_value,
"max_value": opt.max_value,
"name_localizations": opt.name_localizations,
"description_localizations": opt.description_localizations,
"min_length": opt.min_length,
"max_length": opt.max_length,
}
out: dict[str, t.Any] = {
"name": bld.name,
"integration_types": list(sorted(bld.integration_types or [])),
"contexts": list(sorted(bld.context_types or [])),
"is_nsfw": non_undefined_or(bld.is_nsfw, False),
"name_localizations": bld.name_localizations,
}
if isinstance(bld, hikari.api.SlashCommandBuilder):
out["description"] = bld.description
out["description_localizations"] = bld.description_localizations
out["options"] = [serialize_option(opt) for opt in bld.options]
return out
def _get_commands_to_set(
existing: dict[str, _CommandBuilderCollection],
registered: dict[str, _CommandBuilderCollection],
delete_unknown: bool,
) -> Sequence[hikari.api.CommandBuilder] | None:
created, deleted, updated, unchanged = 0, 0, 0, 0
commands_to_set: list[hikari.api.CommandBuilder] = []
for name in {*existing.keys(), *registered.keys()}:
existing_cmds, registered_cmds = existing[name], registered[name]
for existing_bld, registered_bld in zip(
[existing_cmds.slash, existing_cmds.user, existing_cmds.message],
[registered_cmds.slash, registered_cmds.user, registered_cmds.message],
):
if existing_bld is None and registered_bld is None:
continue
if existing_bld is None:
assert registered_bld is not None
commands_to_set.append(registered_bld)
created += 1
elif registered_bld is None:
if delete_unknown:
deleted += 1
else:
commands_to_set.append(existing_bld)
else:
if _serialize_builder(existing_bld) != _serialize_builder(registered_bld):
commands_to_set.append(registered_bld)
updated += 1
else:
commands_to_set.append(existing_bld)
unchanged += 1
LOGGER.debug("created: %s, deleted: %s, updated: %s, unchanged: %s", created, deleted, updated, unchanged)
return commands_to_set if any([created, deleted, updated]) else None
[docs]
async def sync_application_commands(client: client_.Client) -> None:
"""
Synchronise the commands registered to the given client with discord.
Args:
client: The client which has the commands to synchronise registered.
Returns:
:obj:`None`
"""
client._created_commands.clear()
application = await client._ensure_application()
default_integration_types = list(application.integration_types_config.keys())
LOGGER.info("syncing global commands")
existing_global_commands, registered_global_commands = await _get_existing_and_registered_commands(
client, application, hikari.UNDEFINED, default_integration_types
)
global_commands_to_set = _get_commands_to_set(
existing_global_commands, registered_global_commands, client.delete_unknown_commands
)
if global_commands_to_set is not None:
client._created_commands[constants.GLOBAL_COMMAND_KEY] = await client.rest.set_application_commands(
application, global_commands_to_set
)
LOGGER.info("finished syncing global commands")
for guild in client._command_invocation_mapping:
if guild == constants.GLOBAL_COMMAND_KEY:
continue
LOGGER.info("syncing commands for guild '%s'", guild)
existing_guild_commands, registered_guild_commands = await _get_existing_and_registered_commands(
client, application, guild, default_integration_types
)
guild_commands_to_set = _get_commands_to_set(
existing_guild_commands, registered_guild_commands, client.delete_unknown_commands
)
if guild_commands_to_set is not None:
client._created_commands[guild] = await client.rest.set_application_commands(
application, guild_commands_to_set, guild=guild
)
LOGGER.info("finished syncing commands for guild '%s'", guild)