From b32a3cb4979937b8a4d47f3eb1d13802e2bada16 Mon Sep 17 00:00:00 2001 From: hikariatama Date: Mon, 19 Sep 2022 06:12:41 +0000 Subject: [PATCH] Attempt to fix library config dynamic reload regarding https://github.com/hikariatama/Hikka/issues/38#issuecomment-1210016533 --- hikka/inline/events.py | 2 +- hikka/loader.py | 109 ++++++++++++++++++++++------------------- 2 files changed, 59 insertions(+), 52 deletions(-) diff --git a/hikka/inline/events.py b/hikka/inline/events.py index 0cf0d48..f1d4dc6 100644 --- a/hikka/inline/events.py +++ b/hikka/inline/events.py @@ -417,7 +417,7 @@ class Events(InlineUnit): "text": "🏌️ Run command", "switch_inline_query_current_chat": f"{name} ", } - ) + ), ), f"🎹 @{self.bot_username} {utils.escape_html(name)} -" f" {utils.escape_html(doc)}\n", diff --git a/hikka/loader.py b/hikka/loader.py index 6d8dee2..ce7fc04 100644 --- a/hikka/loader.py +++ b/hikka/loader.py @@ -44,7 +44,7 @@ from telethon.tl.functions.account import UpdateNotifySettingsRequest from telethon.hints import EntityLike from types import FunctionType -from typing import Any, Optional, Union, List +import typing from . import security, utils, validators, version from .types import ( @@ -102,7 +102,7 @@ def proxy0(data): _CELLTYPE = type(proxy0(None).__closure__[0]) -def replace_all_refs(replace_from: Any, replace_to: Any) -> Any: +def replace_all_refs(replace_from: typing.Any, replace_to: typing.Any) -> typing.Any: """ :summary: Uses the :mod:`gc` module to replace all references to obj :attr:`replace_from` with :attr:`replace_to` (it tries it's best, @@ -196,7 +196,7 @@ def replace_all_refs(replace_from: Any, replace_to: Any) -> Any: replace_all_refs(referrer, newfn) else: - logging.debug(f"{referrer} is not supported.") + logger.debug("%s is not supported.", referrer) if hit is False: raise AttributeError(f"Object '{replace_from}' not found") @@ -233,7 +233,7 @@ class InfiniteLoop: interval: int, autostart: bool, wait_before: bool, - stop_clause: Union[str, None], + stop_clause: typing.Union[str, None], ): self.func = func self.interval = interval @@ -314,9 +314,9 @@ class InfiniteLoop: def loop( interval: int = 5, - autostart: Optional[bool] = False, - wait_before: Optional[bool] = False, - stop_clause: Optional[str] = None, + autostart: typing.Optional[bool] = False, + wait_before: typing.Optional[bool] = False, + stop_clause: typing.Optional[str] = None, ) -> FunctionType: """ Create new infinite loop from class method @@ -572,9 +572,9 @@ class Modules: async def register_all( self, - mods: Optional[List[str]] = None, + mods: typing.Optional[typing.List[str]] = None, no_external: bool = False, - ) -> List[Module]: + ) -> typing.List[Module]: """Load all modules in the module directory""" external_mods = [] @@ -613,7 +613,7 @@ class Modules: async def _register_modules( self, modules: list, origin: str = "" - ) -> List[Module]: + ) -> typing.List[Module]: with contextlib.suppress(AttributeError): _hikka_client_id_logging_tag = copy.copy(self.client.tg_id) @@ -837,7 +837,7 @@ class Modules: self, peer: EntityLike, reason: str, - assure_joined: Optional[bool] = False, + assure_joined: typing.Optional[bool] = False, _module: Module = None, ) -> bool: """ @@ -985,7 +985,7 @@ class Modules: def _get( self, key: str, - default: Optional[JSONSerializable] = None, + default: typing.Optional[JSONSerializable] = None, _owner: str = None, ) -> JSONSerializable: return self._db.get(_owner, key, default) @@ -996,7 +996,7 @@ class Modules: def _pointer( self, key: str, - default: Optional[JSONSerializable] = None, + default: typing.Optional[JSONSerializable] = None, _owner: str = None, ) -> JSONSerializable: return self._db.pointer(_owner, key, default) @@ -1005,7 +1005,7 @@ class Modules: self, url: str, *, - suspend_on_error: Optional[bool] = False, + suspend_on_error: typing.Optional[bool] = False, _did_requirements: bool = False, ) -> object: """ @@ -1055,7 +1055,9 @@ class Modules: origin = f"" spec = importlib.machinery.ModuleSpec( - module, StringLoader(code, origin), origin=origin + module, + StringLoader(code, origin), + origin=origin, ) try: instance = importlib.util.module_from_spec(spec) @@ -1127,7 +1129,13 @@ class Modules: _raise(ImportError("Invalid library. No class found")) if not lib_obj.__class__.__name__.endswith("Lib"): - _raise(ImportError("Invalid library. Class name must end with 'Lib'")) + _raise( + ImportError( + "Invalid library. Classname {} does not end with 'Lib'".format( + lib_obj.__class__.__name__ + ) + ) + ) if ( all( @@ -1152,15 +1160,18 @@ class Modules: lib_obj.tg_id = self.client.tg_id lib_obj.allmodules = self lib_obj._lib_get = partial( - self._get, _owner=lib_obj.__class__.__name__ - ) # skipcq + self._get, + _owner=lib_obj.__class__.__name__, + ) lib_obj._lib_set = partial( - self._set, _owner=lib_obj.__class__.__name__ - ) # skipcq + self._set, + _owner=lib_obj.__class__.__name__, + ) lib_obj._lib_pointer = partial( - self._pointer, _owner=lib_obj.__class__.__name__ - ) # skipcq - lib_obj.get_prefix = partial(self._db.get, "hikka.main", "command_prefix", ".") + self._pointer, + _owner=lib_obj.__class__.__name__, + ) + lib_obj.get_prefix = self.get_prefix for old_lib in self.libraries: if old_lib.name == lib_obj.name and ( @@ -1168,26 +1179,15 @@ class Modules: and not isinstance(getattr(lib_obj, "version", None), tuple) or old_lib.version >= lib_obj.version ): - logging.debug(f"Using existing instance of library {old_lib.name}") + logger.debug("Using existing instance of library %s", old_lib.name) return old_lib new = True - for old_lib in self.libraries: - if old_lib.name == lib_obj.name: - if hasattr(old_lib, "on_lib_update") and callable( - old_lib.on_lib_update - ): - await old_lib.on_lib_update(lib_obj) + if hasattr(lib_obj, "init"): + if not callable(lib_obj.init): + _raise(ValueError("Library init() must be callable")) - replace_all_refs(old_lib, lib_obj) - new = False - logging.debug( - "Replacing existing instance of library" - f" {lib_obj.name} with updated object" - ) - - if hasattr(lib_obj, "init") and callable(lib_obj.init): try: await lib_obj.init() except Exception: @@ -1222,9 +1222,21 @@ class Modules: lib_obj.translator = self._translator - if new: - self.libraries += [lib_obj] + for old_lib in self.libraries: + if old_lib.name == lib_obj.name: + if hasattr(old_lib, "on_lib_update") and callable( + old_lib.on_lib_update + ): + await old_lib.on_lib_update(lib_obj) + replace_all_refs(old_lib, lib_obj) + logger.debug( + "Replacing existing instance of library %s with updated object", + lib_obj.name, + ) + return lib_obj + + self.libraries += [lib_obj] return lib_obj def dispatch(self, _command: str) -> tuple: @@ -1244,7 +1256,7 @@ class Modules: for mod in self.modules: self.send_config_one(mod, skip_hook) - def send_config_one(self, mod: "Module", skip_hook: bool = False): + def send_config_one(self, mod: Module, skip_hook: bool = False): """Send config to single instance""" with contextlib.suppress(AttributeError): _hikka_client_id_logging_tag = copy.copy(self.client.tg_id) @@ -1311,9 +1323,9 @@ class Modules: async def _animate( self, - message: Union[Message, InlineMessage], - frames: List[str], - interval: Union[float, int], + message: typing.Union[Message, InlineMessage], + frames: typing.List[str], + interval: typing.Union[float, int], *, inline: bool = False, ) -> None: @@ -1425,7 +1437,7 @@ class Modules: name, ) - async def unload_module(self, classname: str) -> bool: + async def unload_module(self, classname: str) -> typing.List[str]: """Remove module and all stuff from it""" worked = [] @@ -1494,12 +1506,7 @@ class Modules: def remove_alias(self, alias: str) -> bool: """Remove an alias""" - try: - del self.aliases[alias.lower().strip()] - except KeyError: - return False - - return True + return bool(self.aliases.pop(alias.lower().strip(), None)) async def log(self, *args, **kwargs): """Unnecessary placeholder for logging"""