mirror of https://github.com/coddrago/Heroku
564 lines
17 KiB
Python
564 lines
17 KiB
Python
# █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
|
||
# █▀█ █ █ █ █▀█ █▀▄ █
|
||
# © Copyright 2022
|
||
# https://t.me/hikariatama
|
||
#
|
||
# 🔒 Licensed under the GNU AGPLv3
|
||
# 🌐 https://www.gnu.org/licenses/agpl-3.0.html
|
||
|
||
|
||
import ast
|
||
import asyncio
|
||
import contextlib
|
||
import copy
|
||
import inspect
|
||
import logging
|
||
import time
|
||
import typing
|
||
from dataclasses import dataclass, field
|
||
from importlib.abc import SourceLoader
|
||
|
||
from telethon.hints import EntityLike
|
||
from telethon.tl.types import ChannelFull, Message, UserFull
|
||
|
||
from . import validators # skipcq: PY-W2000
|
||
from .inline.types import BotInlineMessage # skipcq: PY-W2000
|
||
from .inline.types import (
|
||
BotInlineCall,
|
||
BotMessage,
|
||
InlineCall,
|
||
InlineMessage,
|
||
InlineQuery,
|
||
InlineUnit,
|
||
)
|
||
from .pointers import PointerDict, PointerList # skipcq: PY-W2000
|
||
|
||
# Module-level logger, named after this module's import path
logger = logging.getLogger(__name__)

# Union of the plain Python types accepted as a JSON-serializable value
JSONSerializable = typing.Union[str, int, float, bool, list, dict, None]
# Reply markup: a dict, a list of dicts, or a list of lists of dicts
HikkaReplyMarkup = typing.Union[typing.List[typing.List[dict]], typing.List[dict], dict]
# Built-in collection types that are treated as list-like
ListLike = typing.Union[list, set, tuple]
|
||
|
||
|
||
class StringLoader(SourceLoader):
    """Load a python module/file from a string"""

    def __init__(self, data: str, origin: str):
        # Source is always kept as bytes; text input is encoded as UTF-8
        if isinstance(data, str):
            data = data.encode("utf-8")

        self.data = data
        self.origin = origin

    def get_source(self, _=None) -> str:
        """Return the stored source decoded as UTF-8 text"""
        return self.data.decode("utf-8")

    def get_code(self, fullname: str) -> bytes:
        """
        Compile the stored source, using `origin` as the file name

        Returns the compiled code object, or None when the source is empty
        """
        source = self.get_data(fullname)
        if not source:
            return None

        return compile(source, self.origin, "exec", dont_inherit=True)

    def get_filename(self, *args, **kwargs) -> str:
        """Return the origin passed to the constructor; arguments are ignored"""
        return self.origin

    def get_data(self, *args, **kwargs) -> bytes:
        """Return the raw source bytes; arguments are ignored"""
        return self.data
|
||
|
||
|
||
class Module:
    # NOTE(review): the string literal below follows the `strings` assignment,
    # so Python evaluates it as a plain expression statement and this class'
    # `__doc__` stays None. Kept exactly as is - confirm whether it was meant
    # to be the class docstring before moving it.
    strings = {"name": "Unknown"}

    """There is no help for this module"""

    def config_complete(self):
        """Hook: called once module.config is populated"""

    async def client_ready(self, client, db):
        """Hook: called after the client is ready (after config_loaded)"""

    async def on_unload(self):
        """Hook: called after the module is unloaded / reloaded"""

    async def on_dlmod(self, client, db):
        """
        Hook: called after the module is loaded for the first time
        with .dlmod or .loadmod

        Possible use-cases:
            - Send reaction to author's channel message
            - Join author's channel
            - Create asset folder
            - ...

        ⚠️ Note, that any error there will not interrupt module load, and will
        just send a message to logs with verbosity INFO and exception traceback
        """

    def __getattr__(self, name: str):
        # Only reached for attributes that normal lookup did not find.
        # Every handler collection is reachable under its plain name and
        # under the same name prefixed with "hikka_".
        prefix = "hikka_"
        alias = name[len(prefix):] if name.startswith(prefix) else name

        if alias == "commands":
            return get_commands(self)

        if alias == "inline_handlers":
            return get_inline_handlers(self)

        if alias == "callback_handlers":
            return get_callback_handlers(self)

        if alias == "watchers":
            return get_watchers(self)

        raise AttributeError(
            f"Module {self.__class__.__name__} has no attribute {name}"
        )
|
||
|
||
|
||
class DragonModule:
    """Module is running in compatibility mode with Dragon, so it might be unstable"""

    # Per-language translations of the class docstring above. Every dict
    # must use the same "_cls_doc" key.
    # fmt: off
    strings_ru = {"_cls_doc": "Модуль запущен в режиме совместимости с Dragon, поэтому он может быть нестабильным"}
    strings_de = {"_cls_doc": "Das Modul wird im Dragon-Kompatibilitätsmodus ausgeführt, daher kann es instabil sein"}
    strings_tr = {"_cls_doc": "Modül Dragon uyumluluğu modunda çalıştığı için istikrarsız olabilir"}
    strings_uz = {"_cls_doc": "Modul Dragon muvofiqligi rejimida ishlamoqda, shuning uchun u beqaror bo'lishi mumkin"}
    strings_kr = {"_cls_doc": "모듈이 드래곤 호환 모드로 실행되므로 불안정할 수 있습니다"}
    strings_hi = {"_cls_doc": "ड्रैगन संगतता मोड में चल रहा मॉड्यूल, इसलिए यह अस्थिर हो सकता है"}
    strings_ja = {"_cls_doc": "モジュールがドラゴン互換モードで実行されているため、不安定になる可能性があります"}
    strings_ar = {"_cls_doc": "يعمل الوحدة في وضع التوافق مع Dragon ، لذلك قد يكون غير مستقرًا"}
    strings_es = {"_cls_doc": "El módulo se está ejecutando en modo de compatibilidad con Dragon, por lo que puede ser inestable"}
    # Key was misspelled "_clc_doc", unlike every sibling dict
    strings_tt = {"_cls_doc": "Модуль Dragon белән ярашучанлык режимда эшли башлады, шуңа күрә ул тотрыксыз була ала"}
    # fmt: on

    def __init__(self):
        self.name = "Unknown"
        self.url = None
        # Handler registries; all start out empty
        self.commands = {}
        self.watchers = {}
        self.hikka_watchers = {}
        self.inline_handlers = {}
        self.hikka_inline_handlers = {}
        self.callback_handlers = {}
        self.hikka_callback_handlers = {}

    @property
    def hikka_commands(self):
        """Alias for `commands` (returns the very same dict)"""
        return self.commands

    @property
    def __origin__(self):
        """Origin marker of the form `<dragon NAME>`"""
        return f"<dragon {self.name}>"

    def config_complete(self):
        """No-op counterpart of `Module.config_complete`"""
        pass

    async def client_ready(self):
        """No-op; unlike `Module.client_ready` it takes no arguments"""
        pass

    async def on_unload(self):
        """No-op counterpart of `Module.on_unload`"""
        pass

    async def on_dlmod(self):
        """No-op; unlike `Module.on_dlmod` it takes no arguments"""
        pass
|
||
|
||
|
||
class Library:
    """Base class that every external library must inherit from"""
|
||
|
||
|
||
class LoadError(Exception):
    """Tells user, why your module can't be loaded, if raised in `client_ready`"""

    def __init__(self, error_message: str):  # skipcq: PYL-W0231
        # Only the message is stored; Exception.__init__ is not invoked here
        # (the skipcq marker silences the linter warning about that)
        self._error = error_message

    def __str__(self) -> str:
        # The stored message is the whole string form of the error
        return self._error
|
||
|
||
|
||
class CoreOverwriteError(LoadError):
    """Is being raised when core module or command is overwritten"""

    def __init__(
        self,
        module: typing.Optional[str] = None,
        command: typing.Optional[str] = None,
    ):
        # A truthy `module` wins; otherwise the error is about a command
        if module:
            self.type = "module"
        else:
            self.type = "command"

        self.target = module or command
        # `type` and `target` must be set before this call: str(self) uses them
        super().__init__(str(self))

    def __str__(self) -> str:
        if self.type == "module":
            return f"Module {self.target} will not be overwritten, because it's core"

        return f"Command {self.target} will not be overwritten, because it's core"
|
||
|
||
|
||
class CoreUnloadError(Exception):
    """Is being raised when user tries to unload core module"""

    def __init__(self, module: str):
        # No positional args are forwarded; the message is built in __str__
        super().__init__()
        self.module = module

    def __str__(self) -> str:
        """Describe which module refused to unload"""
        return f"Module {self.module} will not be unloaded, because it's core"
|
||
|
||
|
||
class SelfUnload(Exception):
    """Silently unloads module, if raised in `client_ready`"""

    def __init__(self, error_message: str = ""):
        # The message lives on the instance; nothing is passed up to Exception
        self._error = error_message
        super().__init__()

    def __str__(self) -> str:
        """Return the message given at construction (empty by default)"""
        return self._error
|
||
|
||
|
||
class SelfSuspend(Exception):
    """
    Silently suspends module, if raised in `client_ready`

    When raised, commands and watchers are not registered.
    The module is not unloaded from the db and gets unfrozen after restart,
    unless the exception is raised again
    """

    def __init__(self, error_message: str = ""):
        # The message lives on the instance; nothing is passed up to Exception
        self._error = error_message
        super().__init__()

    def __str__(self) -> str:
        """Return the message given at construction (empty by default)"""
        return self._error
|
||
|
||
|
||
class StopLoop(Exception):
    """Stops the loop in which it is raised"""
|
||
|
||
|
||
class ModuleConfig(dict):
    """
    Stores config for modules and apparently libraries

    The dict part holds plain `option -> value` pairs, while `_config`
    keeps the `ConfigValue` object behind every option.
    """

    def __init__(self, *entries):
        """
        :param entries: Either `ConfigValue` objects (new format), or a flat
                        legacy sequence `key, default, doc, key, default, ...`
        """
        if all(isinstance(entry, ConfigValue) for entry in entries):
            # New config format processing
            self._config = {config.option: config for config in entries}
        else:
            # Legacy config processing: entries come in (key, default, doc)
            # triples; `zip` drops a trailing incomplete triple
            self._config = {
                key: ConfigValue(option=key, default=default, doc=doc)
                for key, default, doc in zip(
                    entries[0::3],
                    entries[1::3],
                    entries[2::3],
                )
            }

        super().__init__(
            {option: config.value for option, config in self._config.items()}
        )

    def getdoc(self, key: str, message: typing.Optional[Message] = None) -> str:
        """
        Get the documentation by key

        :param key: Option name
        :param message: Passed to the doc callable, if the doc is callable
        :raises KeyError: If there is no such option
        """
        ret = self._config[key].doc

        if callable(ret):
            try:
                # Compatibility tweak
                # does nothing in Hikka
                ret = ret(message)
            except Exception:
                # Doc callables that take no arguments are called bare
                ret = ret()

        return ret

    def getdef(self, key: str) -> typing.Any:
        """
        Get the default value by key

        :raises KeyError: If there is no such option
        """
        return self._config[key].default

    def __setitem__(self, key: str, value: typing.Any):
        """
        Assign through the `ConfigValue` first, then mirror into the dict

        :raises KeyError: If there is no such option
        """
        self._config[key].value = value
        super().__setitem__(key, value)

    def set_no_raise(self, key: str, value: typing.Any):
        """
        Same as item assignment, but goes through `ConfigValue.set_no_raise`

        :raises KeyError: If there is no such option
        """
        self._config[key].set_no_raise(value)
        super().__setitem__(key, value)

    def __getitem__(self, key: str) -> typing.Any:
        """Return the current value of the option, or None if it is unknown"""
        try:
            return self._config[key].value
        except KeyError:
            return None

    def reload(self):
        """Copy the current value of every `ConfigValue` back into the dict"""
        for key, config in self._config.items():
            super().__setitem__(key, config.value)
|
||
|
||
|
||
# Libraries use the very same config container as modules
LibraryConfig = ModuleConfig
|
||
|
||
|
||
class _Placeholder:
    """
    Placeholder to determine if the default value is going to be set

    `ConfigValue` uses an instance of this class as the initial `value`
    and swaps it for `default` in `__post_init__`.
    """
|
||
|
||
|
||
async def wrap(func: typing.Callable[[], typing.Awaitable]):
    """
    Call `func` without arguments and await its result, never raising

    :param func: Callable that returns an awaitable (e.g. a coroutine function)
    :return: The awaited result, or None if an `Exception` was raised

    The failure is no longer dropped silently: it is written to the module
    logger with DEBUG level including the traceback.
    """
    try:
        return await func()
    except Exception:
        # Best-effort call: keep the caller alive, but leave a trace
        logging.getLogger(__name__).debug(
            "Suppressed exception in wrapped awaitable",
            exc_info=True,
        )
        return None
|
||
|
||
|
||
def syncwrap(func: typing.Callable):
    """
    Call `func` without arguments, never raising

    :param func: Callable to invoke
    :return: Whatever `func` returns, or None if an `Exception` was raised

    The failure is no longer dropped silently: it is written to the module
    logger with DEBUG level including the traceback.
    """
    try:
        return func()
    except Exception:
        # Best-effort call: keep the caller alive, but leave a trace
        logging.getLogger(__name__).debug(
            "Suppressed exception in wrapped callable",
            exc_info=True,
        )
        return None
|
||
|
||
|
||
@dataclass(repr=True)
class ConfigValue:
    """
    Single config option: its name, default, documentation, current value,
    optional validator and optional change callback

    Assigning `value` goes through the custom `__setattr__` below, which
    normalizes and validates the new value before storing it.
    """

    option: str
    default: typing.Any = None
    doc: typing.Union[callable, str] = "No description"
    # Starts as a `_Placeholder`, so that "not passed" can be detected
    value: typing.Any = field(default_factory=_Placeholder)
    validator: typing.Optional[callable] = None
    on_change: typing.Optional[typing.Union[typing.Awaitable, typing.Callable]] = None

    def __post_init__(self):
        # No explicit value was given -> fall back to the default. This
        # assignment also runs through `__setattr__`.
        if isinstance(self.value, _Placeholder):
            self.value = self.default

    def set_no_raise(self, value: typing.Any) -> bool:
        """
        Sets the config value w/o ValidationError being raised
        Should not be used outside of internal code

        NOTE(review): annotated as `bool`, but `__setattr__` has no `return`
        statement, so this actually returns None
        """
        return self.__setattr__("value", value, ignore_validation=True)

    def __setattr__(
        self,
        key: str,
        value: typing.Any,
        *,
        ignore_validation: bool = False,
    ) -> bool:
        """
        Store an attribute; `value` gets normalized and validated first

        :param key: Attribute name
        :param value: New attribute value
        :param ignore_validation: If set, a failed validation resets the
                                  value to `default` instead of raising, and
                                  `on_change` is not invoked
        :raises validators.ValidationError: If validation fails and
                                            `ignore_validation` is not set
        """
        if key == "value":
            # Try to parse the value as a Python literal; anything that can't
            # be parsed (including non-string input) is kept as is
            try:
                value = ast.literal_eval(value)
            except Exception:
                pass

            # Convert value to list if it's a set or tuple, so it doesn't
            # get messed up by json conversions
            if isinstance(value, (set, tuple)):
                value = list(value)

            # String items of a list are stripped of surrounding whitespace
            if isinstance(value, list):
                value = [
                    item.strip() if isinstance(item, str) else item for item in value
                ]

            if self.validator is not None:
                if value is not None:
                    try:
                        value = self.validator.validate(value)
                    except validators.ValidationError as e:
                        if not ignore_validation:
                            raise e

                        logger.debug(
                            "Config value was broken (%s), so it was reset to %s",
                            value,
                            self.default,
                        )

                        value = self.default
                else:
                    # None is replaced with an empty value, picked by the
                    # validator's `internal_id`; other ids keep None
                    defaults = {
                        "String": "",
                        "Integer": 0,
                        "Boolean": False,
                        "Series": [],
                        "Float": 0.0,
                    }

                    if self.validator.internal_id in defaults:
                        logger.debug(
                            "Config value was None, so it was reset to %s",
                            defaults[self.validator.internal_id],
                        )
                        value = defaults[self.validator.internal_id]

            # This attribute will tell the `Loader` to save this value in db
            self._save_marker = True

        object.__setattr__(self, key, value)

        # Notify the change callback: coroutine functions are scheduled on
        # the event loop, anything else is called right away. Both go through
        # the wrap / syncwrap helpers.
        if key == "value" and not ignore_validation and callable(self.on_change):
            if inspect.iscoroutinefunction(self.on_change):
                asyncio.ensure_future(wrap(self.on_change))
            else:
                syncwrap(self.on_change)
|
||
|
||
|
||
def _get_members(
|
||
mod: Module,
|
||
ending: str,
|
||
attribute: typing.Optional[str] = None,
|
||
strict: bool = False,
|
||
) -> dict:
|
||
"""Get method of module, which end with ending"""
|
||
return {
|
||
(
|
||
method_name.rsplit(ending, maxsplit=1)[0]
|
||
if (method_name == ending if strict else method_name.endswith(ending))
|
||
else method_name
|
||
).lower(): getattr(mod, method_name)
|
||
for method_name in dir(mod)
|
||
if callable(getattr(mod, method_name))
|
||
and (
|
||
(method_name == ending if strict else method_name.endswith(ending))
|
||
or attribute
|
||
and getattr(getattr(mod, method_name), attribute, False)
|
||
)
|
||
}
|
||
|
||
|
||
class CacheRecord:
    """Cached copy of a resolved entity together with its expiration time"""

    def __init__(
        self,
        hashable_entity: "Hashable",  # type: ignore
        resolved_entity: EntityLike,
        exp: int,
    ):
        # Deep copies are stored instead of the passed objects
        entity_copy = copy.deepcopy(resolved_entity)
        key_copy = copy.deepcopy(hashable_entity)

        self.entity = entity_copy
        self._hashable_entity = key_copy
        # `exp` is a lifetime in seconds, counted from the moment of creation
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    def expired(self):
        """Tell whether the expiration moment has already passed"""
        return time.time() > self._exp

    def __eq__(self, record: "CacheRecord"):
        # Records are considered equal when their hashes match
        return hash(self) == hash(record)

    def __hash__(self):
        # Identity of a record is the hash of its entity key
        return hash(self._hashable_entity)

    def __str__(self):
        return f"CacheRecord of {self.entity}"

    def __repr__(self):
        return f"CacheRecord(entity={type(self.entity).__name__}(...), exp={self._exp})"
|
||
|
||
|
||
class CacheRecordPerms:
    """Cached copy of resolved permissions for an (entity, user) pair"""

    def __init__(
        self,
        hashable_entity: "Hashable",  # type: ignore
        hashable_user: "Hashable",  # type: ignore
        resolved_perms: EntityLike,
        exp: int,
    ):
        # Deep copies are stored instead of the passed objects
        perms_copy = copy.deepcopy(resolved_perms)
        entity_key = copy.deepcopy(hashable_entity)
        user_key = copy.deepcopy(hashable_user)

        self.perms = perms_copy
        self._hashable_entity = entity_key
        self._hashable_user = user_key
        # `exp` is a lifetime in seconds, counted from the moment of creation
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    def expired(self):
        """Tell whether the expiration moment has already passed"""
        return time.time() > self._exp

    def __eq__(self, record: "CacheRecordPerms"):
        # Records are considered equal when their hashes match
        return hash(self) == hash(record)

    def __hash__(self):
        # Identity of a record is the hash of the (entity, user) key pair
        key = (self._hashable_entity, self._hashable_user)
        return hash(key)

    def __str__(self):
        return f"CacheRecordPerms of {self.perms}"

    def __repr__(self):
        return (
            f"CacheRecordPerms(perms={type(self.perms).__name__}(...), exp={self._exp})"
        )
|
||
|
||
|
||
class CacheRecordFullChannel:
    """Cached full channel info together with its expiration time"""

    def __init__(self, channel_id: int, full_channel: "ChannelFull", exp: int):
        """
        :param channel_id: Id the record is stored and hashed under
        :param full_channel: Full channel object to cache (stored as is)
        :param exp: Lifetime of the record in seconds, counted from now
        """
        self.channel_id = channel_id
        self.full_channel = full_channel
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    def expired(self):
        """Tell whether the expiration moment has already passed"""
        return self._exp < time.time()

    def __eq__(self, record: "CacheRecordFullChannel"):
        # Same hash-based equality as the other cache records of this file
        return hash(record) == hash(self)

    def __hash__(self):
        # Used to read `_hashable_entity` / `_hashable_user`, which this class
        # never sets, so hashing and comparing always raised AttributeError.
        # The channel id is the only identity this record has.
        return hash(self.channel_id)

    def __str__(self):
        return f"CacheRecordFullChannel of {self.channel_id}"

    def __repr__(self):
        return (
            f"CacheRecordFullChannel(channel_id={self.channel_id}(...),"
            f" exp={self._exp})"
        )
|
||
|
||
|
||
class CacheRecordFullUser:
    """Cached full user info together with its expiration time"""

    def __init__(self, user_id: int, full_user: "UserFull", exp: int):
        """
        :param user_id: Id the record is stored and hashed under
        :param full_user: Full user object to cache (stored as is)
        :param exp: Lifetime of the record in seconds, counted from now
        """
        self.user_id = user_id
        self.full_user = full_user
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    def expired(self):
        """Tell whether the expiration moment has already passed"""
        return self._exp < time.time()

    def __eq__(self, record: "CacheRecordFullUser"):
        # Same hash-based equality as the other cache records of this file
        return hash(record) == hash(self)

    def __hash__(self):
        # Used to read `_hashable_entity` / `_hashable_user`, which this class
        # never sets, so hashing and comparing always raised AttributeError.
        # The user id is the only identity this record has.
        return hash(self.user_id)

    def __str__(self):
        return f"CacheRecordFullUser of {self.user_id}"

    def __repr__(self):
        # The label used to say `channel_id`, although this is a user id
        return f"CacheRecordFullUser(user_id={self.user_id}(...), exp={self._exp})"
|
||
|
||
|
||
def get_commands(mod: Module) -> dict:
    """
    Collect the commands of the module

    Picks callables whose name ends with `cmd` or which carry a truthy
    `is_command` attribute.
    """
    return _get_members(mod, ending="cmd", attribute="is_command")
|
||
|
||
|
||
def get_inline_handlers(mod: Module) -> dict:
    """
    Collect the inline handlers of the module

    Picks callables whose name ends with `_inline_handler` or which carry
    a truthy `is_inline_handler` attribute.
    """
    return _get_members(
        mod,
        ending="_inline_handler",
        attribute="is_inline_handler",
    )
|
||
|
||
|
||
def get_callback_handlers(mod: Module) -> dict:
    """
    Collect the callback handlers of the module

    Picks callables whose name ends with `_callback_handler` or which carry
    a truthy `is_callback_handler` attribute.
    """
    return _get_members(
        mod,
        ending="_callback_handler",
        attribute="is_callback_handler",
    )
|
||
|
||
|
||
def get_watchers(mod: Module) -> dict:
    """
    Collect the watchers of the module

    Picks the callable named exactly `watcher` (strict match) and callables
    which carry a truthy `is_watcher` attribute.
    """
    return _get_members(mod, ending="watcher", attribute="is_watcher", strict=True)
|