mirror of https://github.com/coddrago/Heroku
1108 lines
34 KiB
Python
1108 lines
34 KiB
Python
# ©️ Dan Gazizullin, 2021-2023
|
|
# This file is a part of Hikka Userbot
|
|
# 🌐 https://github.com/hikariatama/Hikka
|
|
# You can redistribute it and/or modify it under the terms of the GNU AGPLv3
|
|
# 🔑 https://www.gnu.org/licenses/agpl-3.0.html
|
|
|
|
# ©️ Codrago, 2024-2025
|
|
# This file is a part of Heroku Userbot
|
|
# 🌐 https://github.com/coddrago/Heroku
|
|
# You can redistribute it and/or modify it under the terms of the GNU AGPLv3
|
|
# 🔑 https://www.gnu.org/licenses/agpl-3.0.html
|
|
|
|
|
|
import ast
|
|
import asyncio
|
|
import contextlib
|
|
import copy
|
|
import importlib
|
|
import importlib.machinery
|
|
import importlib.util
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
import typing
|
|
from dataclasses import dataclass, field
|
|
from importlib.abc import SourceLoader
|
|
|
|
import requests
|
|
from herokutl.hints import EntityLike
|
|
from herokutl.tl.functions.account import UpdateNotifySettingsRequest
|
|
from herokutl.tl.types import (
|
|
Channel,
|
|
ChannelFull,
|
|
InputPeerNotifySettings,
|
|
Message,
|
|
UserFull,
|
|
)
|
|
|
|
from . import version
|
|
from ._reference_finder import replace_all_refs
|
|
from .inline.types import (
|
|
BotInlineCall,
|
|
BotInlineMessage,
|
|
BotMessage,
|
|
InlineCall,
|
|
InlineMessage,
|
|
InlineQuery,
|
|
InlineUnit,
|
|
)
|
|
from .pointers import PointerDict, PointerList
|
|
|
|
# Names exported by a star-import of this module.
__all__ = [
    "JSONSerializable",
    "HerokuReplyMarkup",
    "ListLike",
    "Command",
    "StringLoader",
    "Module",
    "get_commands",
    "get_inline_handlers",
    "get_callback_handlers",
    "BotInlineCall",
    "BotMessage",
    "InlineCall",
    "InlineMessage",
    "InlineQuery",
    "InlineUnit",
    "BotInlineMessage",
    "PointerDict",
    "PointerList",
]

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Alias for the built-in types listed here (str, int, float, bool, list, dict, None).
JSONSerializable = typing.Union[str, int, float, bool, list, dict, None]
# Reply-markup payload: a list of lists of dicts, a list of dicts, or a single dict.
HerokuReplyMarkup = typing.Union[typing.List[typing.List[dict]], typing.List[dict], dict]
# Any of the built-in list, set or tuple types.
ListLike = typing.Union[list, set, tuple]
# A callable taking any arguments and returning an awaitable.
Command = typing.Callable[..., typing.Awaitable[typing.Any]]
|
|
|
|
|
|
class StringLoader(SourceLoader):
    """Source loader whose module text is held in memory instead of on disk.

    The text is stored as UTF-8 bytes; ``origin`` is what the loader reports
    as the file name and what is handed to ``compile`` as the code's filename.
    """

    def __init__(self, data: str, origin: str):
        # Text is encoded once here; any non-str payload is stored untouched.
        if isinstance(data, str):
            data = data.encode("utf-8")

        self.data = data
        self.origin = origin

    def get_source(self, _=None) -> str:
        """Return the stored payload decoded as UTF-8."""
        return self.data.decode("utf-8")

    def get_code(self, fullname: str) -> bytes:
        """Compile the stored payload, or return ``None`` when it is empty.

        NOTE(review): the return annotation says ``bytes``, but ``compile``
        produces a code object (and the empty case yields ``None``).
        """
        payload = self.get_data(fullname)
        if not payload:
            return None

        return compile(payload, self.origin, "exec", dont_inherit=True)

    def get_filename(self, *args, **kwargs) -> str:
        """Return ``origin``; all arguments are ignored."""
        return self.origin

    def get_data(self, *args, **kwargs) -> bytes:
        """Return the stored payload; all arguments are ignored."""
        return self.data
|
|
|
|
|
|
class Module:
    # Class-level string table; only a placeholder "name" entry is defined here.
    strings = {"name": "Unknown"}

    # NOTE(review): this literal follows an assignment, so it is a plain
    # expression statement and does NOT become ``Module.__doc__``.
    """There is no help for this module"""
|
|
|
|
def config_complete(self):
    """Hook called when ``module.config`` is populated.

    The base implementation does nothing.
    """
|
|
|
|
async def client_ready(self):
    """Hook called after the client is ready (after config_loaded).

    The base implementation does nothing.
    """
|
|
|
|
def internal_init(self):
    """Copy the shared handles from ``self.allmodules`` onto this instance.

    Called after the class is initialized in order to pass the client and db.
    Do not call it yourself.
    """
    registry = self.allmodules

    # Every handle is exposed under both its public and underscore name.
    self.db = self._db = registry.db
    self.client = self._client = registry.client
    self.lookup = registry.lookup
    self.get_prefix = registry.get_prefix
    self.inline = registry.inline
    self.allclients = registry.allclients
    self.tg_id = self._tg_id = self._client.tg_id
|
|
|
|
async def on_unload(self):
    """Hook called after the module is unloaded or reloaded.

    The base implementation does nothing.
    """
|
|
|
|
async def on_dlmod(self):
    """
    Hook called after the module is loaded for the first time with .dlmod
    or .loadmod. The base implementation does nothing.

    Possible use-cases:
    - Send reaction to author's channel message
    - Create asset folder
    - ...

    ⚠️ Note that any error raised here will not interrupt the module load; it
    will just send a message to the logs with verbosity INFO and the
    exception traceback
    """
|
|
|
|
async def invoke(
    self,
    command: str,
    args: typing.Optional[str] = None,
    peer: typing.Optional[EntityLike] = None,
    message: typing.Optional[Message] = None,
    edit: bool = False,
) -> Message:
    """
    Run another registered command as if it had been typed.

    :param command: Name of the command to invoke
    :param args: Argument string appended after the command name
    :param peer: Peer to send the command text to; takes priority over `message`
    :param message: Message to respond to (or edit) when no peer is given
    :param edit: Edit `message` instead of responding to it
    :returns: The message that carries the command text
    :raises ValueError: If the command is unknown or neither `peer` nor
        `message` is given
    """
    if command not in self.allmodules.commands:
        raise ValueError(f"Command {command} not found")

    if not (message or peer):
        raise ValueError("Either peer or message must be specified")

    cmd = f"{self.get_prefix()}{command} {args or ''}".strip()

    # Produce the carrier message first, then feed it to the handler.
    if peer:
        message = await self._client.send_message(peer, cmd)
    elif edit:
        message = await message.edit(cmd)
    else:
        message = await message.respond(cmd)

    await self.allmodules.commands[command](message)
    return message
|
|
|
|
# Each mapping below is recomputed on every read by the matching ``get_*``
# helper. The paired setter accepts and discards any assigned value.

@property
def commands(self) -> typing.Dict[str, Command]:
    """Commands found on this module."""
    return get_commands(self)

@commands.setter
def commands(self, _):
    """Ignore assignment."""

@property
def heroku_commands(self) -> typing.Dict[str, Command]:
    """Commands found on this module (same lookup as ``commands``)."""
    return get_commands(self)

@heroku_commands.setter
def heroku_commands(self, _):
    """Ignore assignment."""

@property
def inline_handlers(self) -> typing.Dict[str, Command]:
    """Inline handlers found on this module."""
    return get_inline_handlers(self)

@inline_handlers.setter
def inline_handlers(self, _):
    """Ignore assignment."""

@property
def heroku_inline_handlers(self) -> typing.Dict[str, Command]:
    """Inline handlers found on this module (same lookup as ``inline_handlers``)."""
    return get_inline_handlers(self)

@heroku_inline_handlers.setter
def heroku_inline_handlers(self, _):
    """Ignore assignment."""

@property
def callback_handlers(self) -> typing.Dict[str, Command]:
    """Callback handlers found on this module."""
    return get_callback_handlers(self)

@callback_handlers.setter
def callback_handlers(self, _):
    """Ignore assignment."""

@property
def heroku_callback_handlers(self) -> typing.Dict[str, Command]:
    """Callback handlers found on this module (same lookup as ``callback_handlers``)."""
    return get_callback_handlers(self)

@heroku_callback_handlers.setter
def heroku_callback_handlers(self, _):
    """Ignore assignment."""

@property
def watchers(self) -> typing.Dict[str, Command]:
    """Watchers found on this module."""
    return get_watchers(self)

@watchers.setter
def watchers(self, _):
    """Ignore assignment."""

@property
def heroku_watchers(self) -> typing.Dict[str, Command]:
    """Watchers found on this module (same lookup as ``watchers``)."""
    return get_watchers(self)

@heroku_watchers.setter
def heroku_watchers(self, _):
    """Ignore assignment."""
|
|
|
|
async def animate(
    self,
    message: typing.Union[Message, InlineMessage],
    frames: typing.List[str],
    interval: typing.Union[float, int],
    *,
    inline: bool = False,
) -> typing.Any:
    """
    Animate a message by showing each frame in turn.

    :param message: Message to animate
    :param frames: A list of strings which are the frames of the animation
    :param interval: Delay between frames in seconds; values below 0.1 are
        raised to 0.1 with a warning
    :param inline: Whether to use the inline bot for the animation
    :returns: The last value bound to `message` (the result of the most recent
        `inline.form` / `utils.answer` call, or the input when nothing replaced it)

    Please, note that if you set `inline=True`, first frame will be shown with an empty
    button due to the limitations of Telegram API
    """
    from . import utils

    # The local is never read in this function (hence the noqa).
    # NOTE(review): presumably picked up by name elsewhere, e.g. by logging — confirm.
    with contextlib.suppress(AttributeError):
        _heroku_client_id_logging_tag = copy.copy(self.client.tg_id)  # noqa: F841

    if interval < 0.1:
        logger.warning(
            "Resetting animation interval to 0.1s, because it may get you in"
            " floodwaits"
        )
        interval = 0.1

    for frame in frames:
        if isinstance(message, Message):
            if inline:
                # Replaces `message` with whatever the inline form call returns.
                message = await self.inline.form(
                    message=message,
                    text=frame,
                    reply_markup={"text": "\u0020\u2800", "data": "empty"},
                )
            else:
                message = await utils.answer(message, frame)
        elif isinstance(message, InlineMessage) and inline:
            await message.edit(frame)

        # The delay also runs after the final frame.
        await asyncio.sleep(interval)

    return message
|
|
|
|
def get(
    self,
    key: str,
    default: typing.Optional[JSONSerializable] = None,
) -> JSONSerializable:
    """Read `key` from the database under this module's class name.

    :param key: Key to look up
    :param default: Passed through to the database call as the fallback
    """
    owner = self.__class__.__name__
    return self._db.get(owner, key, default)
|
|
|
|
def set(self, key: str, value: JSONSerializable) -> bool:
    """Store `value` under `key` in the database under this module's class name.

    :param key: Key to write
    :param value: Value to store
    :returns: Whatever the underlying database ``set`` call returns. The
        original discarded it, so callers always received ``None`` despite
        the ``bool`` annotation.
    """
    return self._db.set(self.__class__.__name__, key, value)
|
|
|
|
def pointer(
    self,
    key: str,
    default: typing.Optional[JSONSerializable] = None,
    item_type: typing.Optional[typing.Any] = None,
) -> typing.Union[JSONSerializable, PointerList, PointerDict]:
    """Ask the database for a pointer to `key` under this module's class name.

    `default` and `item_type` are forwarded unchanged to the database call.
    """
    owner = self.__class__.__name__
    return self._db.pointer(owner, key, default, item_type)
|
|
|
|
async def _decline(
    self,
    call: InlineCall,
    channel: EntityLike,
    event: asyncio.Event,
):
    """Record that joining `channel` was declined and release the waiter.

    Adds the channel id to the stored "declined_joins" list (de-duplicated),
    marks `event` with ``status = False``, sets it, and edits the inline
    message to show the decline notice.
    """
    from . import utils

    already_declined = self._db.get("heroku.main", "declined_joins", [])
    updated = list(set(already_declined + [channel.id]))
    self._db.set("heroku.main", "declined_joins", updated)

    # `status` is an ad-hoc attribute on the Event carrying the outcome.
    event.status = False
    event.set()

    notice = (
        "✖️ <b>Declined joining <a"
        f' href="https://t.me/{channel.username}">{utils.escape_html(channel.title)}</a></b>'
    )
    await call.edit(notice, photo="https://imgur.com/a/XpwmHo6.png")
|
|
|
|
async def request_join(
    self,
    peer: EntityLike,
    reason: str,
    assure_joined: typing.Optional[bool] = False,
) -> bool:
    """
    Request to join a channel.

    :param peer: The channel to join.
    :param reason: The reason for joining.
    :param assure_joined: If set, module will not be loaded unless the required channel is joined.
                          ⚠️ Works only in `client_ready`!
                          ⚠️ If user declines to join channel, he will not be asked to
                          join again, so unless he joins it manually, module will not be loaded
                          ever.
    :return: Status of the request.
    :rtype: bool
    :raises LoadError: If `assure_joined` is set and the request was declined
        (now or previously).
    :raises TypeError: If `peer` does not resolve to a `Channel`.
    :notice: This method will block module loading until the request is approved or declined.
    """
    from . import utils

    event = asyncio.Event()
    # Update notification settings for the inline bot before the request is sent.
    await self.client(
        UpdateNotifySettingsRequest(
            peer=self.inline.bot_username,
            settings=InputPeerNotifySettings(show_previews=False, silent=False),
        )
    )

    channel = await self.client.get_entity(peer)
    # A previously declined channel is never asked about again.
    if channel.id in self._db.get("heroku.main", "declined_joins", []):
        if assure_joined:
            raise LoadError(
                f"You need to join @{channel.username} in order to use this module"
            )

        return False

    if not isinstance(channel, Channel):
        raise TypeError("`peer` field must be a channel")

    # A missing `left` attribute is treated as "left"; re-fetch before trusting it.
    if getattr(channel, "left", True):
        channel = await self.client.force_get_entity(peer)

    if not getattr(channel, "left", True):
        return True

    await self.inline.bot.send_photo(
        self.tg_id,
        "https://imgur.com/a/gWKLn7h.png",
        caption=(
            self._client.loader.lookup("translations")
            .strings("requested_join")
            .format(
                self.__class__.__name__,
                channel.username,
                utils.escape_html(channel.title),
                utils.escape_html(reason),
            )
        ),
        reply_markup=self.inline.generate_markup(
            [
                {
                    "text": "💫 Approve",
                    "callback": self.lookup("loader").approve_internal,
                    "args": (channel, event),
                },
                {
                    "text": "✖️ Decline",
                    "callback": self._decline,
                    "args": (channel, event),
                },
            ]
        ),
    )

    # Exposed on the instance while the request is pending.
    self.heroku_wait_channel_approve = (
        self.__class__.__name__,
        channel,
        reason,
    )
    # `status` is an ad-hoc attribute on the Event; `_decline` sets it to False.
    # NOTE(review): presumably the approve callback sets it to True — confirm.
    event.status = False
    await event.wait()

    with contextlib.suppress(AttributeError):
        delattr(self, "heroku_wait_channel_approve")

    if assure_joined and not event.status:
        raise LoadError(
            f"You need to join @{channel.username} in order to use this module"
        )

    return event.status
|
|
|
|
async def import_lib(
    self,
    url: str,
    *,
    suspend_on_error: typing.Optional[bool] = False,
    _did_requirements: bool = False,
) -> "Library":
    """
    Import library from url and register it in :obj:`Modules`

    :param url: Url to import
    :param suspend_on_error: Will raise :obj:`loader.SelfSuspend` if library can't be loaded
    :param _did_requirements: Internal flag; set on the retry after a
        dependency installation so that installation is attempted only once
    :return: :obj:`Library`
    :raise: SelfSuspend if :attr:`suspend_on_error` is True and error occurred
    :raise: HTTPError if library is not found
    :raise: ImportError if library doesn't have any class which is a subclass of :obj:`loader.Library`
    :raise: ImportError if library name doesn't end with `Lib`
    :raise: RuntimeError if library throws in :method:`init`
    :raise: RuntimeError if library classname exists in :obj:`Modules`.libraries
    """

    from . import utils  # Avoiding circular import
    from .loader import USER_INSTALL, VALID_PIP_PACKAGES
    from .translations import Strings

    def _raise(e: Exception):
        # With `suspend_on_error` every failure becomes a SelfSuspend.
        if suspend_on_error:
            raise SelfSuspend("Required library is not available or is corrupted.")

        raise e

    if not utils.check_url(url):
        _raise(ValueError("Invalid url for library"))

    code = await utils.run_sync(requests.get, url)
    code.raise_for_status()
    code = code.text

    # Optional "# scope: heroku_min X.Y.Z" marker declares a minimum version.
    if re.search(r"# ?scope: ?heroku_min", code):
        ver = tuple(
            map(
                int,
                re.search(r"# ?scope: ?heroku_min ((\d+\.){2}\d+)", code)[1].split(
                    "."
                ),
            )
        )

        if version.__version__ < ver:
            _raise(
                RuntimeError(
                    f"Library requires Heroku version {'{}.{}.{}'.format(*ver)}+"
                )
            )

    module = f"heroku.libraries.{url.replace('%', '%%').replace('.', '%d')}"
    origin = f"<library {url}>"

    spec = importlib.machinery.ModuleSpec(
        module,
        StringLoader(code, origin),
        origin=origin,
    )
    try:
        instance = importlib.util.module_from_spec(spec)
        sys.modules[module] = instance
        spec.loader.exec_module(instance)
    except ImportError as e:
        logger.info(
            "Library loading failed, attemping dependency installation (%s)",
            e.name,
        )
        # Let's try to reinstall dependencies
        try:
            requirements = list(
                filter(
                    lambda x: not x.startswith(("-", "_", ".")),
                    map(
                        str.strip,
                        VALID_PIP_PACKAGES.search(code)[1].split(),
                    ),
                )
            )
        except TypeError:
            logger.warning(
                "No valid pip packages specified in code, attemping"
                " installation from error"
            )
            # `ImportError.name` may be None; never pass that to pip.
            requirements = [e.name] if e.name else []

        logger.debug("Installing requirements: %s", requirements)

        if not requirements or _did_requirements:
            _raise(e)

        pip = await asyncio.create_subprocess_exec(
            sys.executable,
            "-m",
            "pip",
            "install",
            "--upgrade",
            "-q",
            "--disable-pip-version-check",
            "--no-warn-script-location",
            *["--user"] if USER_INSTALL else [],
            *requirements,
        )

        rc = await pip.wait()

        if rc != 0:
            _raise(e)

        importlib.invalidate_caches()

        # Try again, exactly once. The original called `self._mod_import_lib`,
        # which is not defined in this file; call this method directly.
        return await self.import_lib(
            url,
            suspend_on_error=suspend_on_error,
            _did_requirements=True,
        )

    # Instantiate the first class in the module that subclasses Library.
    lib_obj = next(
        (
            value()
            for value in vars(instance).values()
            if inspect.isclass(value) and issubclass(value, Library)
        ),
        None,
    )

    if not lib_obj:
        _raise(ImportError("Invalid library. No class found"))

    if not lib_obj.__class__.__name__.endswith("Lib"):
        _raise(
            ImportError(
                "Invalid library. Classname {} does not end with 'Lib'".format(
                    lib_obj.__class__.__name__
                )
            )
        )

    # Stats are sent unless the source opts out or the db setting is off.
    if (
        all(
            line.replace(" ", "") != "#scope:no_stats" for line in code.splitlines()
        )
        and self._db.get("heroku.main", "stats", True)
        and url is not None
        and utils.check_url(url)
    ):
        with contextlib.suppress(Exception):
            await self.lookup("loader")._send_stats(url)

    lib_obj.source_url = url.strip("/")
    lib_obj.allmodules = self.allmodules
    lib_obj.internal_init()

    # Reuse a registered library of the same name when neither side has a
    # tuple version, or when the registered one is at least as new.
    for old_lib in self.allmodules.libraries:
        if old_lib.name == lib_obj.name and (
            not isinstance(getattr(old_lib, "version", None), tuple)
            and not isinstance(getattr(lib_obj, "version", None), tuple)
            or old_lib.version >= lib_obj.version
        ):
            logger.debug("Using existing instance of library %s", old_lib.name)
            return old_lib

    if hasattr(lib_obj, "init"):
        if not callable(lib_obj.init):
            _raise(ValueError("Library init() must be callable"))

        try:
            await lib_obj.init()
        except Exception:
            _raise(RuntimeError("Library init() failed"))

    if hasattr(lib_obj, "config"):
        if not isinstance(lib_obj.config, LibraryConfig):
            _raise(
                RuntimeError("Library config must be a `LibraryConfig` instance")
            )

        libcfg = lib_obj.db.get(
            lib_obj.__class__.__name__,
            "__config__",
            {},
        )

        # Value priority: stored config, then environment variable, then default.
        for conf in lib_obj.config:
            with contextlib.suppress(Exception):
                lib_obj.config.set_no_raise(
                    conf,
                    (
                        libcfg[conf]
                        if conf in libcfg
                        else os.environ.get(f"{lib_obj.__class__.__name__}.{conf}")
                        or lib_obj.config.getdef(conf)
                    ),
                )

    if hasattr(lib_obj, "strings"):
        lib_obj.strings = Strings(lib_obj, self.translator)

    lib_obj.translator = self.translator

    # An older registered instance is notified and then replaced.
    for old_lib in self.allmodules.libraries:
        if old_lib.name == lib_obj.name:
            if hasattr(old_lib, "on_lib_update") and callable(
                old_lib.on_lib_update
            ):
                await old_lib.on_lib_update(lib_obj)

            replace_all_refs(old_lib, lib_obj)
            logger.debug(
                "Replacing existing instance of library %s with updated object",
                lib_obj.name,
            )
            return lib_obj

    self.allmodules.libraries += [lib_obj]
    return lib_obj
|
|
|
|
|
|
class Library:
    """All external libraries must have a class-inheritant from this class"""

    def internal_init(self):
        """Copy the shared handles from ``self.allmodules`` onto this instance.

        Also sets ``name`` to the class name.
        """
        self.name = self.__class__.__name__
        self.db = self.allmodules.db
        self._db = self.allmodules.db
        self.client = self.allmodules.client
        self._client = self.allmodules.client
        self.tg_id = self._client.tg_id
        self._tg_id = self._client.tg_id
        self.lookup = self.allmodules.lookup
        self.get_prefix = self.allmodules.get_prefix
        self.inline = self.allmodules.inline
        self.allclients = self.allmodules.allclients

    def _lib_get(
        self,
        key: str,
        default: typing.Optional[JSONSerializable] = None,
    ) -> JSONSerializable:
        """Read `key` from the database under this library's class name."""
        return self._db.get(self.__class__.__name__, key, default)

    def _lib_set(self, key: str, value: JSONSerializable) -> bool:
        """Store `value` under `key` in the database under this library's class name.

        :returns: Whatever the underlying database ``set`` call returns. The
            original discarded it, so the result was always ``None`` despite
            the ``bool`` annotation.
        """
        return self._db.set(self.__class__.__name__, key, value)

    def _lib_pointer(
        self,
        key: str,
        default: typing.Optional[JSONSerializable] = None,
    ) -> typing.Union[JSONSerializable, PointerDict, PointerList]:
        """Ask the database for a pointer to `key` under this library's class name."""
        return self._db.pointer(self.__class__.__name__, key, default)
|
|
|
|
|
|
class LoadError(Exception):
    """Raise in `client_ready` to tell the user why the module can't be loaded."""

    def __init__(self, error_message: str):  # skipcq: PYL-W0231
        # Only the message is kept; the parent initializer is not called here.
        self._error = error_message

    def __str__(self) -> str:
        """Return the message given at construction."""
        return self._error
|
|
|
|
|
|
class CoreOverwriteError(LoadError):
    """Raised when a core module or command would be overwritten."""

    def __init__(
        self,
        module: typing.Optional[str] = None,
        command: typing.Optional[str] = None,
    ):
        # A truthy `module` wins; otherwise the error is about `command`.
        if module:
            self.type = "module"
        else:
            self.type = "command"

        self.target = module or command
        super().__init__(str(self))

    def __str__(self) -> str:
        """Describe which module or command was protected."""
        label = "Module" if self.type == "module" else "command"
        return f"{label} {self.target} will not be overwritten, because it's core"
|
|
|
|
|
|
class CoreUnloadError(Exception):
    """Raised when the user tries to unload a core module."""

    def __init__(self, module: str):
        super().__init__()
        self.module = module

    def __str__(self) -> str:
        """Name the module that was protected from unloading."""
        name = self.module
        return f"Module {name} will not be unloaded, because it's core"
|
|
|
|
|
|
class SelfUnload(Exception):
    """Raise in `client_ready` to silently unload the module."""

    def __init__(self, error_message: str = ""):
        self._error = error_message
        super().__init__()

    def __str__(self) -> str:
        """Return the message given at construction (empty by default)."""
        return self._error
|
|
|
|
|
|
class SelfSuspend(Exception):
    """
    Raise in `client_ready` to silently suspend the module.

    Commands and watchers will not be registered if raised.
    The module won't be unloaded from the db and will be unfrozen after a
    restart, unless the exception is raised again.
    """

    def __init__(self, error_message: str = ""):
        self._error = error_message
        super().__init__()

    def __str__(self) -> str:
        """Return the message given at construction (empty by default)."""
        return self._error
|
|
|
|
|
|
class StopLoop(Exception):
    """Stops the loop in which it is raised"""
|
|
|
|
|
|
class ModuleConfig(dict):
    """Stores config for modules and apparently libraries.

    The authoritative state lives in ``self._config`` (option name ->
    ``ConfigValue``). The inherited dict is kept as a mirror of the current
    values.
    """

    def __init__(self, *entries: typing.Union[str, "ConfigValue"]):
        """Build the config from ``ConfigValue`` objects or the legacy flat form.

        The legacy form is a flat sequence of ``key, default, docstring``
        triples; an incomplete trailing triple is dropped.
        """
        if all(isinstance(entry, ConfigValue) for entry in entries):
            # New config format processing
            self._config = {config.option: config for config in entries}
        else:
            # Legacy config processing: every third item, offset 0/1/2.
            self._config = {
                key: ConfigValue(option=key, default=default, doc=doc)
                for key, default, doc in zip(
                    entries[0::3], entries[1::3], entries[2::3]
                )
            }

        super().__init__(
            {option: config.value for option, config in self._config.items()}
        )

    def getdoc(self, key: str, message: typing.Optional[Message] = None) -> str:
        """Get the documentation by key.

        A callable doc is invoked with `message` first; if that raises, it is
        invoked again without arguments.
        """
        ret = self._config[key].doc

        if callable(ret):
            try:
                # Compatibility tweak
                # does nothing in Heroku
                ret = ret(message)
            except Exception:
                ret = ret()

        return ret

    def getdef(self, key: str) -> typing.Any:
        """Get the default value by key"""
        return self._config[key].default

    def __setitem__(self, key: str, value: typing.Any):
        """Assign through the ``ConfigValue`` and mirror the stored result."""
        entry = self._config[key]
        entry.value = value
        # Mirror what was actually stored (after ConfigValue's own processing),
        # so dict views agree with __getitem__.
        super().__setitem__(key, entry.value)

    def set_no_raise(self, key: str, value: typing.Any):
        """Assign via ``ConfigValue.set_no_raise`` and mirror the stored result."""
        entry = self._config[key]
        entry.set_no_raise(value)
        super().__setitem__(key, entry.value)

    def __getitem__(self, key: str) -> typing.Any:
        """Return the current value, or ``None`` for an unknown key."""
        try:
            return self._config[key].value
        except KeyError:
            return None

    def reload(self):
        """Re-copy every current value from ``self._config`` into the dict."""
        for key in self._config:
            super().__setitem__(key, self._config[key].value)

    def change_validator(
        self,
        key: str,
        validator: typing.Callable[[JSONSerializable], JSONSerializable],
    ):
        """Replace the validator attached to `key`."""
        self._config[key].validator = validator
|
|
|
|
|
|
# Alias: library configs are the same class as module configs.
LibraryConfig = ModuleConfig
|
|
|
|
|
|
class _Placeholder:
    """Placeholder used to detect that no explicit value was supplied,
    so the default value is going to be set"""
|
|
|
|
|
|
async def wrap(func: typing.Callable[[], typing.Awaitable]) -> typing.Any:
    """Await ``func()`` and return its result, or ``None`` if it raises ``Exception``."""
    try:
        return await func()
    except Exception:
        return None
|
|
|
|
|
|
def syncwrap(func: typing.Callable[[], typing.Any]) -> typing.Any:
    """Call ``func()`` and return its result, or ``None`` if it raises ``Exception``."""
    try:
        return func()
    except Exception:
        return None
|
|
|
|
|
|
@dataclass(repr=True)
class ConfigValue:
    # Option name; used as the key in ModuleConfig.
    option: str
    # Value used when none is supplied and when a broken value is reset.
    default: typing.Any = None
    # Description text, or a callable producing it.
    doc: typing.Union[typing.Callable[[], str], str] = "No description"
    # Current value; starts as a _Placeholder until __post_init__ resolves it.
    value: typing.Any = field(default_factory=_Placeholder)
    # Used below through its `.validate(...)` method and `.internal_id` attribute.
    validator: typing.Optional[
        typing.Callable[[JSONSerializable], JSONSerializable]
    ] = None
    # Called (or scheduled, when a coroutine function) after a validated change.
    on_change: typing.Optional[
        typing.Union[typing.Callable[[], typing.Awaitable], typing.Callable]
    ] = None

    def __post_init__(self):
        """Fall back to `default` when no explicit value was given."""
        if isinstance(self.value, _Placeholder):
            self.value = self.default

    def set_no_raise(self, value: typing.Any) -> bool:
        """
        Set the config value without a ValidationError being raised.
        Should not be used outside of internal code.

        NOTE(review): annotated ``bool``, but this returns the result of
        ``__setattr__``, which returns nothing.
        """
        return self.__setattr__("value", value, ignore_validation=True)

    def __setattr__(
        self,
        key: str,
        value: typing.Any,
        *,
        ignore_validation: bool = False,
    ):
        """Store an attribute; assignments to ``value`` are normalised first.

        For ``value``: the input is passed through ``ast.literal_eval`` when
        possible, sets/tuples become lists, string items of a list are
        stripped, and the validator (if any) is applied. With
        `ignore_validation` a failed validation resets to `default` instead
        of raising, and `on_change` is not triggered.
        """
        setting_value = key == "value"

        if setting_value:
            # Anything literal_eval cannot handle is kept as it was.
            with contextlib.suppress(Exception):
                value = ast.literal_eval(value)

            # Convert value to list if it's tuple just not to mess up
            # with json conversions
            if isinstance(value, (set, tuple)):
                value = list(value)

            if isinstance(value, list):
                value = [
                    item.strip() if isinstance(item, str) else item for item in value
                ]

            if self.validator is not None:
                if value is None:
                    # None is replaced by a type-appropriate empty value for
                    # the validator kinds listed here.
                    empty_by_kind = {
                        "String": "",
                        "Integer": 0,
                        "Boolean": False,
                        "Series": [],
                        "Float": 0.0,
                    }
                    kind = self.validator.internal_id

                    if kind in empty_by_kind:
                        logger.debug(
                            "Config value was None, so it was reset to %s",
                            empty_by_kind[kind],
                        )
                        value = empty_by_kind[kind]
                else:
                    from . import validators

                    try:
                        value = self.validator.validate(value)
                    except validators.ValidationError as e:
                        if not ignore_validation:
                            raise e

                        logger.debug(
                            "Config value was broken (%s), so it was reset to %s",
                            value,
                            self.default,
                        )

                        value = self.default

            # This attribute will tell the `Loader` to save this value in db
            self._save_marker = True

        object.__setattr__(self, key, value)

        if setting_value and not ignore_validation and callable(self.on_change):
            if inspect.iscoroutinefunction(self.on_change):
                asyncio.ensure_future(wrap(self.on_change))
            else:
                syncwrap(self.on_change)
|
|
|
|
|
|
def _get_members(
    mod: Module,
    ending: str,
    attribute: typing.Optional[str] = None,
    strict: bool = False,
) -> dict:
    """Collect callables of `mod` selected by name suffix or marker attribute.

    A callable qualifies when its name ends with `ending` (or equals it when
    `strict`), or when `attribute` is given and truthy on the callable.
    Properties declared on the class are skipped. Keys are lower-cased; for
    name matches the last occurrence of `ending` and anything after it is cut off.
    """

    def _name_matches(name: str) -> bool:
        return name == ending if strict else name.endswith(ending)

    owner = type(mod)
    found = {}

    for name in dir(mod):
        # Skip properties so that reading them is never triggered here.
        if isinstance(getattr(owner, name, None), property):
            continue

        member = getattr(mod, name)
        if not callable(member):
            continue

        by_name = _name_matches(name)
        if not (by_name or attribute and getattr(member, attribute, False)):
            continue

        key = name.rsplit(ending, maxsplit=1)[0] if by_name else name
        found[key.lower()] = member

    return found
|
|
|
|
|
|
class CacheRecordEntity:
    """Cache entry holding a deep copy of a resolved entity with an expiry time.

    Identity (hash and equality) is derived from the hashable key only.
    """

    def __init__(
        self,
        hashable_entity: "Hashable",  # type: ignore  # noqa: F821
        resolved_entity: EntityLike,
        exp: int,
    ):
        # Both inputs are deep-copied so later mutation by the caller has no effect.
        self.entity = copy.deepcopy(resolved_entity)
        self._hashable_entity = copy.deepcopy(hashable_entity)
        # Expiry is an absolute timestamp, `exp` seconds from now.
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    @property
    def expired(self) -> bool:
        """True once the current time is past the expiry timestamp."""
        return time.time() > self._exp

    def __eq__(self, record: "CacheRecordEntity") -> bool:
        return hash(self) == hash(record)

    def __hash__(self) -> int:
        return hash(self._hashable_entity)

    def __str__(self) -> str:
        return f"CacheRecordEntity of {self.entity}"

    def __repr__(self) -> str:
        kind = type(self.entity).__name__
        return f"CacheRecordEntity(entity={kind}(...), exp={self._exp})"
|
|
|
|
|
|
class CacheRecordPerms:
    """Cache entry holding a deep copy of resolved permissions with an expiry time.

    Identity (hash and equality) is derived from the (entity, user) key pair.
    """

    def __init__(
        self,
        hashable_entity: "Hashable",  # type: ignore  # noqa: F821
        hashable_user: "Hashable",  # type: ignore  # noqa: F821
        resolved_perms: EntityLike,
        exp: int,
    ):
        # All inputs are deep-copied so later mutation by the caller has no effect.
        self.perms = copy.deepcopy(resolved_perms)
        self._hashable_entity = copy.deepcopy(hashable_entity)
        self._hashable_user = copy.deepcopy(hashable_user)
        # Expiry is an absolute timestamp, `exp` seconds from now.
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    @property
    def expired(self) -> bool:
        """True once the current time is past the expiry timestamp."""
        return time.time() > self._exp

    def __eq__(self, record: "CacheRecordPerms") -> bool:
        return hash(self) == hash(record)

    def __hash__(self) -> int:
        key = (self._hashable_entity, self._hashable_user)
        return hash(key)

    def __str__(self) -> str:
        return f"CacheRecordPerms of {self.perms}"

    def __repr__(self) -> str:
        kind = type(self.perms).__name__
        return f"CacheRecordPerms(perms={kind}(...), exp={self._exp})"
|
|
|
|
|
|
class CacheRecordFullChannel:
    """Cache entry for a full-channel object, keyed by channel id, with expiry."""

    def __init__(self, channel_id: int, full_channel: "ChannelFull", exp: int):
        self.channel_id = channel_id
        self.full_channel = full_channel
        # Expiry is an absolute timestamp, `exp` seconds from now.
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    @property
    def expired(self) -> bool:
        """True once the current time is past the expiry timestamp."""
        return self._exp < time.time()

    def __eq__(self, record: "CacheRecordFullChannel") -> bool:
        return hash(record) == hash(self)

    def __hash__(self) -> int:
        # The original hashed `_hashable_entity` / `_hashable_user`, which this
        # class never sets, so hashing and `==` raised AttributeError.
        return hash(self.channel_id)

    def __str__(self) -> str:
        return f"CacheRecordFullChannel of {self.channel_id}"

    def __repr__(self) -> str:
        return (
            f"CacheRecordFullChannel(channel_id={self.channel_id}(...),"
            f" exp={self._exp})"
        )
|
|
|
|
|
|
class CacheRecordFullUser:
    """Cache entry for a full-user object, keyed by user id, with expiry."""

    def __init__(self, user_id: int, full_user: "UserFull", exp: int):
        self.user_id = user_id
        self.full_user = full_user
        # Expiry is an absolute timestamp, `exp` seconds from now.
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    @property
    def expired(self) -> bool:
        """True once the current time is past the expiry timestamp."""
        return self._exp < time.time()

    def __eq__(self, record: "CacheRecordFullUser") -> bool:
        return hash(record) == hash(self)

    def __hash__(self) -> int:
        # The original hashed `_hashable_entity` / `_hashable_user`, which this
        # class never sets, so hashing and `==` raised AttributeError.
        return hash(self.user_id)

    def __str__(self) -> str:
        return f"CacheRecordFullUser of {self.user_id}"

    def __repr__(self) -> str:
        # The field is a user id; the original label said `channel_id`.
        return f"CacheRecordFullUser(user_id={self.user_id}(...), exp={self._exp})"
|
|
|
|
|
|
def get_commands(mod: Module) -> dict:
    """Return the module's commands: callables named ``*cmd`` or flagged ``is_command``."""
    return _get_members(mod, ending="cmd", attribute="is_command")
|
|
|
|
|
|
def get_inline_handlers(mod: Module) -> dict:
    """Return the module's inline handlers: callables named ``*_inline_handler``
    or flagged ``is_inline_handler``."""
    return _get_members(
        mod,
        ending="_inline_handler",
        attribute="is_inline_handler",
    )
|
|
|
|
|
|
def get_callback_handlers(mod: Module) -> dict:
    """Return the module's callback handlers: callables named ``*_callback_handler``
    or flagged ``is_callback_handler``."""
    return _get_members(
        mod,
        ending="_callback_handler",
        attribute="is_callback_handler",
    )
|
|
|
|
|
|
def get_watchers(mod: Module) -> dict:
    """Return the module's watchers: a callable named exactly ``watcher``
    (strict match) or any callable flagged ``is_watcher``."""
    return _get_members(mod, ending="watcher", attribute="is_watcher", strict=True)
|