#             █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
#             █▀█ █ █ █ █▀█ █▀▄ █
#              © Copyright 2022
#           https://t.me/hikariatama
#
# 🔒      Licensed under the GNU AGPLv3
# 🌐 https://www.gnu.org/licenses/agpl-3.0.html

import ast
import asyncio
import contextlib
import copy
import inspect
import logging
import time
import typing
from dataclasses import dataclass, field
from importlib.abc import SourceLoader

from telethon.hints import EntityLike
from telethon.tl.types import ChannelFull, Message, UserFull

from . import validators  # skipcq: PY-W2000
from .inline.types import BotInlineMessage  # skipcq: PY-W2000
from .inline.types import (
    BotInlineCall,
    BotMessage,
    InlineCall,
    InlineMessage,
    InlineQuery,
    InlineUnit,
)
from .pointers import PointerDict, PointerList  # skipcq: PY-W2000

logger = logging.getLogger(__name__)


JSONSerializable = typing.Union[str, int, float, bool, list, dict, None]
HikkaReplyMarkup = typing.Union[typing.List[typing.List[dict]], typing.List[dict], dict]
ListLike = typing.Union[list, set, tuple]


class StringLoader(SourceLoader):
    """Load a python module/file from a string"""

    def __init__(self, data: str, origin: str):
        """
        :param data: Module source; `str` is encoded to UTF-8, anything else
                     is stored as-is
        :param origin: Value reported as the module's filename
        """
        self.data = data.encode("utf-8") if isinstance(data, str) else data
        self.origin = origin

    def get_source(self, _=None) -> str:
        """Return the stored source decoded as UTF-8"""
        return self.data.decode("utf-8")

    def get_code(self, fullname: str) -> typing.Optional[typing.Any]:
        """
        Compile the stored source

        :return: Code object produced by `compile`, or `None` if the stored
                 source is empty
        """
        source = self.get_data(fullname)
        if not source:
            return None

        return compile(source, self.origin, "exec", dont_inherit=True)

    def get_filename(self, *args, **kwargs) -> str:
        """Return the origin passed to the constructor"""
        return self.origin

    def get_data(self, *args, **kwargs) -> bytes:
        """Return the raw stored source"""
        return self.data


class Module:
    # Base class exposing the lifecycle hooks and the lazily-computed
    # `commands` / `inline_handlers` / `callback_handlers` / `watchers` mappings.
    #
    # NOTE: the string literal below follows `strings`, so it is a plain
    # expression statement and NOT the class docstring: `Module.__doc__` stays
    # unset. It is intentionally left in place to keep that behavior.
    strings = {"name": "Unknown"}

    """There is no help for this module"""

    def config_complete(self):
        """Called when module.config is populated"""

    async def client_ready(self, client, db):
        """Called after client is ready (after config_loaded)"""

    async def on_unload(self):
        """Called after unloading / reloading module"""

    async def on_dlmod(self, client, db):
        """
        Called after the module is first time loaded with .dlmod or .loadmod

        Possible use-cases:
        - Send reaction to author's channel message
        - Join author's channel
        - Create asset folder
        - ...

        ⚠️ Note, that any error there will not interrupt module load, and will just
        send a message to logs with verbosity INFO and exception traceback
        """

    def __getattr__(self, name: str):
        """
        Resolve the handler mappings on demand

        Only invoked when regular attribute lookup fails, so each mapping is
        recomputed by introspection on every access.

        :raises AttributeError: For any other missing attribute
        """
        if name in {"hikka_commands", "commands"}:
            return get_commands(self)

        if name in {"hikka_inline_handlers", "inline_handlers"}:
            return get_inline_handlers(self)

        if name in {"hikka_callback_handlers", "callback_handlers"}:
            return get_callback_handlers(self)

        if name in {"hikka_watchers", "watchers"}:
            return get_watchers(self)

        raise AttributeError(
            f"Module {self.__class__.__name__} has no attribute {name}"
        )


class DragonModule:
    """Module is running in compatibility mode with Dragon, so it might be unstable"""

    # fmt: off
    strings_ru = {"_cls_doc": "Модуль запущен в режиме совместимости с Dragon, поэтому он может быть нестабильным"}
    strings_de = {"_cls_doc": "Das Modul wird im Dragon-Kompatibilitätsmodus ausgeführt, daher kann es instabil sein"}
    strings_tr = {"_cls_doc": "Modül Dragon uyumluluğu modunda çalıştığı için istikrarsız olabilir"}
    strings_uz = {"_cls_doc": "Modul Dragon muvofiqligi rejimida ishlamoqda, shuning uchun u beqaror bo'lishi mumkin"}
    strings_kr = {"_cls_doc": "모듈이 드래곤 호환 모드로 실행되므로 불안정할 수 있습니다"}
    strings_hi = {"_cls_doc": "ड्रैगन संगतता मोड में चल रहा मॉड्यूल, इसलिए यह अस्थिर हो सकता है"}
    strings_ja = {"_cls_doc": "モジュールがドラゴン互換モードで実行されているため、不安定になる可能性があります"}
    strings_ar = {"_cls_doc": "يعمل الوحدة في وضع التوافق مع Dragon ، لذلك قد يكون غير مستقرًا"}
    strings_es = {"_cls_doc": "El módulo se está ejecutando en modo de compatibilidad con Dragon, por lo que puede ser inestable"}
    # Key is `_cls_doc`, the same key every other translation above uses
    strings_tt = {"_cls_doc": "Модуль Dragon белән ярашучанлык режимда эшли башлады, шуңа күрә ул тотрыксыз була ала"}
    # fmt: on

    def __init__(self):
        self.name = "Unknown"
        self.url = None
        self.commands = {}
        self.watchers = {}
        self.hikka_watchers = {}
        self.inline_handlers = {}
        self.hikka_inline_handlers = {}
        self.callback_handlers = {}
        self.hikka_callback_handlers = {}

    @property
    def hikka_commands(self):
        """Alias of `commands`"""
        return self.commands

    @property
    def __origin__(self):
        return f""

    def config_complete(self):
        pass

    async def client_ready(self):
        pass

    async def on_unload(self):
        pass

    async def on_dlmod(self):
        pass


class Library:
    """All external libraries must have a class-inheritant from this class"""


class LoadError(Exception):
    """Tells user, why your module can't be loaded, if raised in `client_ready`"""

    def __init__(self, error_message: str):  # skipcq: PYL-W0231
        self._error = error_message

    def __str__(self) -> str:
        return self._error


class CoreOverwriteError(LoadError):
    """Is being raised when core module or command is overwritten"""

    def __init__(
        self,
        module: typing.Optional[str] = None,
        command: typing.Optional[str] = None,
    ):
        # `type` and `target` must be set before `str(self)` is evaluated
        self.type = "module" if module else "command"
        self.target = module or command
        super().__init__(str(self))

    def __str__(self) -> str:
        if self.type == "module":
            return f"Module {self.target} will not be overwritten, because it's core"

        return f"Command {self.target} will not be overwritten, because it's core"


class CoreUnloadError(Exception):
    """Is being raised when user tries to unload core module"""

    def __init__(self, module: str):
        self.module = module
        super().__init__()

    def __str__(self) -> str:
        return f"Module {self.module} will not be unloaded, because it's core"


class SelfUnload(Exception):
    """Silently unloads module, if raised in `client_ready`"""

    def __init__(self, error_message: str = ""):
        super().__init__()
        self._error = error_message

    def __str__(self) -> str:
        return self._error


class SelfSuspend(Exception):
    """
    Silently suspends module, if raised in `client_ready`
    Commands and watcher will not be registered if raised
    Module won't be unloaded from db and will be unfreezed after restart, unless
    the exception is raised again
    """

    def __init__(self, error_message: str = ""):
        super().__init__()
        self._error = error_message

    def __str__(self) -> str:
        return self._error


class StopLoop(Exception):
    """Stops the loop, in which is raised"""


class ModuleConfig(dict):
    """Stores config for modules and apparently libraries"""

    def __init__(self, *entries):
        """
        :param entries: Either `ConfigValue` instances only (new format), or a
                        flat legacy sequence `key, default, doc, key, default,
                        doc, ...`
        """
        if all(isinstance(entry, ConfigValue) for entry in entries):
            # New config format processing
            self._config = {config.option: config for config in entries}
        else:
            # Legacy config processing: every three consecutive entries form
            # one option. `zip` drops a trailing incomplete triple.
            keys = entries[0::3]
            defaults = entries[1::3]
            docstrings = entries[2::3]

            self._config = {
                key: ConfigValue(option=key, default=default, doc=doc)
                for key, default, doc in zip(keys, defaults, docstrings)
            }

        super().__init__(
            {option: config.value for option, config in self._config.items()}
        )

    def getdoc(self, key: str, message: typing.Optional[Message] = None) -> str:
        """Get the documentation by key"""
        ret = self._config[key].doc

        if callable(ret):
            try:
                # Compatibility tweak
                # does nothing in Hikka
                ret = ret(message)
            except Exception:
                # Fall back to calling the doc factory without arguments
                ret = ret()

        return ret

    def getdef(self, key: str) -> typing.Any:
        """Get the default value by key"""
        return self._config[key].default

    def __setitem__(self, key: str, value: typing.Any):
        """Set (and validate) the value, then mirror it into the dict"""
        self._config[key].value = value
        super().__setitem__(key, value)

    def set_no_raise(self, key: str, value: typing.Any):
        """Set the value via `ConfigValue.set_no_raise`, then mirror it into the dict"""
        self._config[key].set_no_raise(value)
        super().__setitem__(key, value)

    def __getitem__(self, key: str) -> typing.Any:
        """Return the current value of the option, or `None` if it is unknown"""
        try:
            return self._config[key].value
        except KeyError:
            return None

    def reload(self):
        """Re-sync the dict contents from the underlying `ConfigValue` objects"""
        for key in self._config:
            super().__setitem__(key, self._config[key].value)


LibraryConfig = ModuleConfig


class _Placeholder:
    """Placeholder to determine if the default value is going to be set"""


async def wrap(func: typing.Callable[[], typing.Awaitable]):
    """Call `func`, await its result and return it; any `Exception` is suppressed"""
    with contextlib.suppress(Exception):
        return await func()


def syncwrap(func: typing.Callable):
    """Call `func` and return its result; any `Exception` is suppressed"""
    with contextlib.suppress(Exception):
        return func()


@dataclass(repr=True)
class ConfigValue:
    """
    Single config option

    Assigning to `value` goes through `__setattr__`, which normalizes and
    validates the new value and then invokes `on_change` if it is callable.
    """

    option: str
    default: typing.Any = None
    doc: typing.Union[callable, str] = "No description"
    value: typing.Any = field(default_factory=_Placeholder)
    validator: typing.Optional[callable] = None
    on_change: typing.Optional[typing.Union[typing.Awaitable, typing.Callable]] = None

    def __post_init__(self):
        # No explicit value was passed to the constructor: fall back to default
        if isinstance(self.value, _Placeholder):
            self.value = self.default

    def set_no_raise(self, value: typing.Any) -> None:
        """
        Sets the config value w/o ValidationError being raised
        Intended for internal use only
        """
        return self.__setattr__("value", value, ignore_validation=True)

    def __setattr__(
        self,
        key: str,
        value: typing.Any,
        *,
        ignore_validation: bool = False,
    ) -> None:
        """
        Set an attribute; assignments to `value` are normalized and validated

        :param key: Attribute name
        :param value: New value
        :param ignore_validation: If set, a failed validation resets the value
                                  to `default` instead of raising, and
                                  `on_change` is not invoked
        :raises validators.ValidationError: If validation fails and
                                            `ignore_validation` is not set
        """
        if key == "value":
            # Best-effort parsing of python literals; anything that can't be
            # parsed is kept untouched
            try:
                value = ast.literal_eval(value)
            except Exception:
                pass

            # Convert value to list if it's tuple just not to mess up
            # with json convertations
            if isinstance(value, (set, tuple)):
                value = list(value)

            if isinstance(value, list):
                value = [
                    item.strip() if isinstance(item, str) else item for item in value
                ]

            if self.validator is not None:
                if value is not None:
                    try:
                        value = self.validator.validate(value)
                    except validators.ValidationError:
                        if not ignore_validation:
                            raise

                        logger.debug(
                            "Config value was broken (%s), so it was reset to %s",
                            value,
                            self.default,
                        )

                        value = self.default
                else:
                    # `None` is replaced with an empty value of the kind
                    # named by the validator's `internal_id`
                    defaults = {
                        "String": "",
                        "Integer": 0,
                        "Boolean": False,
                        "Series": [],
                        "Float": 0.0,
                    }

                    if self.validator.internal_id in defaults:
                        logger.debug(
                            "Config value was None, so it was reset to %s",
                            defaults[self.validator.internal_id],
                        )
                        value = defaults[self.validator.internal_id]

            # This attribute will tell the `Loader` to save this value in db
            self._save_marker = True

        object.__setattr__(self, key, value)

        if key == "value" and not ignore_validation and callable(self.on_change):
            if inspect.iscoroutinefunction(self.on_change):
                asyncio.ensure_future(wrap(self.on_change))
            else:
                syncwrap(self.on_change)


def _get_members(
    mod: Module,
    ending: str,
    attribute: typing.Optional[str] = None,
    strict: bool = False,
) -> dict:
    """
    Get methods of module, which end with ending

    :param mod: Object to introspect
    :param ending: Name suffix to look for (or the exact name if `strict`)
    :param attribute: Optional marker attribute; callables having it truthy are
                      included regardless of their name
    :param strict: Require the name to be equal to `ending` instead of just
                   ending with it
    :return: Mapping of lowercased names to members. Names matched by
             `ending` have it cut off, names matched only by `attribute` are
             kept in full
    """
    members = {}

    for method_name in dir(mod):
        member = getattr(mod, method_name)
        if not callable(member):
            continue

        name_matches = (
            method_name == ending if strict else method_name.endswith(ending)
        )

        if not name_matches and not (
            attribute and getattr(member, attribute, False)
        ):
            continue

        key = (
            method_name.rsplit(ending, maxsplit=1)[0] if name_matches else method_name
        )
        members[key.lower()] = member

    return members


class CacheRecord:
    """Cached resolved entity with an expiration time"""

    def __init__(
        self,
        hashable_entity: "Hashable",  # type: ignore
        resolved_entity: EntityLike,
        exp: int,
    ):
        """
        :param hashable_entity: Key the record is hashed by
        :param resolved_entity: Entity to cache (a deep copy is stored)
        :param exp: Lifetime in seconds, counted from now
        """
        self.entity = copy.deepcopy(resolved_entity)
        self._hashable_entity = copy.deepcopy(hashable_entity)
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    def expired(self):
        """Tell whether the expiration time has passed"""
        return self._exp < time.time()

    def __eq__(self, record: "CacheRecord"):
        return hash(record) == hash(self)

    def __hash__(self):
        return hash(self._hashable_entity)

    def __str__(self):
        return f"CacheRecord of {self.entity}"

    def __repr__(self):
        return f"CacheRecord(entity={type(self.entity).__name__}(...), exp={self._exp})"


class CacheRecordPerms:
    """Cached resolved permissions for an (entity, user) pair with an expiration time"""

    def __init__(
        self,
        hashable_entity: "Hashable",  # type: ignore
        hashable_user: "Hashable",  # type: ignore
        resolved_perms: EntityLike,
        exp: int,
    ):
        """
        :param hashable_entity: First part of the key the record is hashed by
        :param hashable_user: Second part of the key the record is hashed by
        :param resolved_perms: Permissions to cache (a deep copy is stored)
        :param exp: Lifetime in seconds, counted from now
        """
        self.perms = copy.deepcopy(resolved_perms)
        self._hashable_entity = copy.deepcopy(hashable_entity)
        self._hashable_user = copy.deepcopy(hashable_user)
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    def expired(self):
        """Tell whether the expiration time has passed"""
        return self._exp < time.time()

    def __eq__(self, record: "CacheRecordPerms"):
        return hash(record) == hash(self)

    def __hash__(self):
        return hash((self._hashable_entity, self._hashable_user))

    def __str__(self):
        return f"CacheRecordPerms of {self.perms}"

    def __repr__(self):
        return (
            f"CacheRecordPerms(perms={type(self.perms).__name__}(...), exp={self._exp})"
        )


class CacheRecordFullChannel:
    """Cached full channel object with an expiration time"""

    def __init__(self, channel_id: int, full_channel: ChannelFull, exp: int):
        """
        :param channel_id: Id the record is hashed by
        :param full_channel: Object to cache (stored by reference)
        :param exp: Lifetime in seconds, counted from now
        """
        self.channel_id = channel_id
        self.full_channel = full_channel
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    def expired(self):
        """Tell whether the expiration time has passed"""
        return self._exp < time.time()

    def __eq__(self, record: "CacheRecordFullChannel"):
        return hash(record) == hash(self)

    def __hash__(self):
        # The record is identified by `channel_id`, the only key attribute
        # this class defines
        return hash(self.channel_id)

    def __str__(self):
        return f"CacheRecordFullChannel of {self.channel_id}"

    def __repr__(self):
        return (
            f"CacheRecordFullChannel(channel_id={self.channel_id}(...),"
            f" exp={self._exp})"
        )


class CacheRecordFullUser:
    """Cached full user object with an expiration time"""

    def __init__(self, user_id: int, full_user: UserFull, exp: int):
        """
        :param user_id: Id the record is hashed by
        :param full_user: Object to cache (stored by reference)
        :param exp: Lifetime in seconds, counted from now
        """
        self.user_id = user_id
        self.full_user = full_user
        self._exp = round(time.time() + exp)
        self.ts = time.time()

    def expired(self):
        """Tell whether the expiration time has passed"""
        return self._exp < time.time()

    def __eq__(self, record: "CacheRecordFullUser"):
        return hash(record) == hash(self)

    def __hash__(self):
        # The record is identified by `user_id`, the only key attribute this
        # class defines
        return hash(self.user_id)

    def __str__(self):
        return f"CacheRecordFullUser of {self.user_id}"

    def __repr__(self):
        return f"CacheRecordFullUser(user_id={self.user_id}(...), exp={self._exp})"


def get_commands(mod: Module) -> dict:
    """Introspect the module to get its commands"""
    return _get_members(mod, "cmd", "is_command")


def get_inline_handlers(mod: Module) -> dict:
    """Introspect the module to get its inline handlers"""
    return _get_members(mod, "_inline_handler", "is_inline_handler")


def get_callback_handlers(mod: Module) -> dict:
    """Introspect the module to get its callback handlers"""
    return _get_members(mod, "_callback_handler", "is_callback_handler")


def get_watchers(mod: Module) -> dict:
    """Introspect the module to get its watchers"""
    return _get_members(
        mod,
        "watcher",
        "is_watcher",
        strict=True,
    )