mirror of https://github.com/coddrago/Heroku
Attempt to fix library config dynamic reload regarding https://github.com/hikariatama/Hikka/issues/38#issuecomment-1210016533
parent
066f7a46a2
commit
b32a3cb497
|
@ -417,7 +417,7 @@ class Events(InlineUnit):
|
||||||
"text": "🏌️ Run command",
|
"text": "🏌️ Run command",
|
||||||
"switch_inline_query_current_chat": f"{name} ",
|
"switch_inline_query_current_chat": f"{name} ",
|
||||||
}
|
}
|
||||||
)
|
),
|
||||||
),
|
),
|
||||||
f"🎹 <code>@{self.bot_username} {utils.escape_html(name)}</code> -"
|
f"🎹 <code>@{self.bot_username} {utils.escape_html(name)}</code> -"
|
||||||
f" {utils.escape_html(doc)}\n",
|
f" {utils.escape_html(doc)}\n",
|
||||||
|
|
109
hikka/loader.py
109
hikka/loader.py
|
@ -44,7 +44,7 @@ from telethon.tl.functions.account import UpdateNotifySettingsRequest
|
||||||
from telethon.hints import EntityLike
|
from telethon.hints import EntityLike
|
||||||
|
|
||||||
from types import FunctionType
|
from types import FunctionType
|
||||||
from typing import Any, Optional, Union, List
|
import typing
|
||||||
|
|
||||||
from . import security, utils, validators, version
|
from . import security, utils, validators, version
|
||||||
from .types import (
|
from .types import (
|
||||||
|
@ -102,7 +102,7 @@ def proxy0(data):
|
||||||
_CELLTYPE = type(proxy0(None).__closure__[0])
|
_CELLTYPE = type(proxy0(None).__closure__[0])
|
||||||
|
|
||||||
|
|
||||||
def replace_all_refs(replace_from: Any, replace_to: Any) -> Any:
|
def replace_all_refs(replace_from: typing.Any, replace_to: typing.Any) -> typing.Any:
|
||||||
"""
|
"""
|
||||||
:summary: Uses the :mod:`gc` module to replace all references to obj
|
:summary: Uses the :mod:`gc` module to replace all references to obj
|
||||||
:attr:`replace_from` with :attr:`replace_to` (it tries it's best,
|
:attr:`replace_from` with :attr:`replace_to` (it tries it's best,
|
||||||
|
@ -196,7 +196,7 @@ def replace_all_refs(replace_from: Any, replace_to: Any) -> Any:
|
||||||
replace_all_refs(referrer, newfn)
|
replace_all_refs(referrer, newfn)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logging.debug(f"{referrer} is not supported.")
|
logger.debug("%s is not supported.", referrer)
|
||||||
|
|
||||||
if hit is False:
|
if hit is False:
|
||||||
raise AttributeError(f"Object '{replace_from}' not found")
|
raise AttributeError(f"Object '{replace_from}' not found")
|
||||||
|
@ -233,7 +233,7 @@ class InfiniteLoop:
|
||||||
interval: int,
|
interval: int,
|
||||||
autostart: bool,
|
autostart: bool,
|
||||||
wait_before: bool,
|
wait_before: bool,
|
||||||
stop_clause: Union[str, None],
|
stop_clause: typing.Union[str, None],
|
||||||
):
|
):
|
||||||
self.func = func
|
self.func = func
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
|
@ -314,9 +314,9 @@ class InfiniteLoop:
|
||||||
|
|
||||||
def loop(
|
def loop(
|
||||||
interval: int = 5,
|
interval: int = 5,
|
||||||
autostart: Optional[bool] = False,
|
autostart: typing.Optional[bool] = False,
|
||||||
wait_before: Optional[bool] = False,
|
wait_before: typing.Optional[bool] = False,
|
||||||
stop_clause: Optional[str] = None,
|
stop_clause: typing.Optional[str] = None,
|
||||||
) -> FunctionType:
|
) -> FunctionType:
|
||||||
"""
|
"""
|
||||||
Create new infinite loop from class method
|
Create new infinite loop from class method
|
||||||
|
@ -572,9 +572,9 @@ class Modules:
|
||||||
|
|
||||||
async def register_all(
|
async def register_all(
|
||||||
self,
|
self,
|
||||||
mods: Optional[List[str]] = None,
|
mods: typing.Optional[typing.List[str]] = None,
|
||||||
no_external: bool = False,
|
no_external: bool = False,
|
||||||
) -> List[Module]:
|
) -> typing.List[Module]:
|
||||||
"""Load all modules in the module directory"""
|
"""Load all modules in the module directory"""
|
||||||
external_mods = []
|
external_mods = []
|
||||||
|
|
||||||
|
@ -613,7 +613,7 @@ class Modules:
|
||||||
|
|
||||||
async def _register_modules(
|
async def _register_modules(
|
||||||
self, modules: list, origin: str = "<core>"
|
self, modules: list, origin: str = "<core>"
|
||||||
) -> List[Module]:
|
) -> typing.List[Module]:
|
||||||
with contextlib.suppress(AttributeError):
|
with contextlib.suppress(AttributeError):
|
||||||
_hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
|
_hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
|
||||||
|
|
||||||
|
@ -837,7 +837,7 @@ class Modules:
|
||||||
self,
|
self,
|
||||||
peer: EntityLike,
|
peer: EntityLike,
|
||||||
reason: str,
|
reason: str,
|
||||||
assure_joined: Optional[bool] = False,
|
assure_joined: typing.Optional[bool] = False,
|
||||||
_module: Module = None,
|
_module: Module = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -985,7 +985,7 @@ class Modules:
|
||||||
def _get(
|
def _get(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
default: Optional[JSONSerializable] = None,
|
default: typing.Optional[JSONSerializable] = None,
|
||||||
_owner: str = None,
|
_owner: str = None,
|
||||||
) -> JSONSerializable:
|
) -> JSONSerializable:
|
||||||
return self._db.get(_owner, key, default)
|
return self._db.get(_owner, key, default)
|
||||||
|
@ -996,7 +996,7 @@ class Modules:
|
||||||
def _pointer(
|
def _pointer(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
default: Optional[JSONSerializable] = None,
|
default: typing.Optional[JSONSerializable] = None,
|
||||||
_owner: str = None,
|
_owner: str = None,
|
||||||
) -> JSONSerializable:
|
) -> JSONSerializable:
|
||||||
return self._db.pointer(_owner, key, default)
|
return self._db.pointer(_owner, key, default)
|
||||||
|
@ -1005,7 +1005,7 @@ class Modules:
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
*,
|
*,
|
||||||
suspend_on_error: Optional[bool] = False,
|
suspend_on_error: typing.Optional[bool] = False,
|
||||||
_did_requirements: bool = False,
|
_did_requirements: bool = False,
|
||||||
) -> object:
|
) -> object:
|
||||||
"""
|
"""
|
||||||
|
@ -1055,7 +1055,9 @@ class Modules:
|
||||||
origin = f"<library {url}>"
|
origin = f"<library {url}>"
|
||||||
|
|
||||||
spec = importlib.machinery.ModuleSpec(
|
spec = importlib.machinery.ModuleSpec(
|
||||||
module, StringLoader(code, origin), origin=origin
|
module,
|
||||||
|
StringLoader(code, origin),
|
||||||
|
origin=origin,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
instance = importlib.util.module_from_spec(spec)
|
instance = importlib.util.module_from_spec(spec)
|
||||||
|
@ -1127,7 +1129,13 @@ class Modules:
|
||||||
_raise(ImportError("Invalid library. No class found"))
|
_raise(ImportError("Invalid library. No class found"))
|
||||||
|
|
||||||
if not lib_obj.__class__.__name__.endswith("Lib"):
|
if not lib_obj.__class__.__name__.endswith("Lib"):
|
||||||
_raise(ImportError("Invalid library. Class name must end with 'Lib'"))
|
_raise(
|
||||||
|
ImportError(
|
||||||
|
"Invalid library. Classname {} does not end with 'Lib'".format(
|
||||||
|
lib_obj.__class__.__name__
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
all(
|
all(
|
||||||
|
@ -1152,15 +1160,18 @@ class Modules:
|
||||||
lib_obj.tg_id = self.client.tg_id
|
lib_obj.tg_id = self.client.tg_id
|
||||||
lib_obj.allmodules = self
|
lib_obj.allmodules = self
|
||||||
lib_obj._lib_get = partial(
|
lib_obj._lib_get = partial(
|
||||||
self._get, _owner=lib_obj.__class__.__name__
|
self._get,
|
||||||
) # skipcq
|
_owner=lib_obj.__class__.__name__,
|
||||||
|
)
|
||||||
lib_obj._lib_set = partial(
|
lib_obj._lib_set = partial(
|
||||||
self._set, _owner=lib_obj.__class__.__name__
|
self._set,
|
||||||
) # skipcq
|
_owner=lib_obj.__class__.__name__,
|
||||||
|
)
|
||||||
lib_obj._lib_pointer = partial(
|
lib_obj._lib_pointer = partial(
|
||||||
self._pointer, _owner=lib_obj.__class__.__name__
|
self._pointer,
|
||||||
) # skipcq
|
_owner=lib_obj.__class__.__name__,
|
||||||
lib_obj.get_prefix = partial(self._db.get, "hikka.main", "command_prefix", ".")
|
)
|
||||||
|
lib_obj.get_prefix = self.get_prefix
|
||||||
|
|
||||||
for old_lib in self.libraries:
|
for old_lib in self.libraries:
|
||||||
if old_lib.name == lib_obj.name and (
|
if old_lib.name == lib_obj.name and (
|
||||||
|
@ -1168,26 +1179,15 @@ class Modules:
|
||||||
and not isinstance(getattr(lib_obj, "version", None), tuple)
|
and not isinstance(getattr(lib_obj, "version", None), tuple)
|
||||||
or old_lib.version >= lib_obj.version
|
or old_lib.version >= lib_obj.version
|
||||||
):
|
):
|
||||||
logging.debug(f"Using existing instance of library {old_lib.name}")
|
logger.debug("Using existing instance of library %s", old_lib.name)
|
||||||
return old_lib
|
return old_lib
|
||||||
|
|
||||||
new = True
|
new = True
|
||||||
|
|
||||||
for old_lib in self.libraries:
|
if hasattr(lib_obj, "init"):
|
||||||
if old_lib.name == lib_obj.name:
|
if not callable(lib_obj.init):
|
||||||
if hasattr(old_lib, "on_lib_update") and callable(
|
_raise(ValueError("Library init() must be callable"))
|
||||||
old_lib.on_lib_update
|
|
||||||
):
|
|
||||||
await old_lib.on_lib_update(lib_obj)
|
|
||||||
|
|
||||||
replace_all_refs(old_lib, lib_obj)
|
|
||||||
new = False
|
|
||||||
logging.debug(
|
|
||||||
"Replacing existing instance of library"
|
|
||||||
f" {lib_obj.name} with updated object"
|
|
||||||
)
|
|
||||||
|
|
||||||
if hasattr(lib_obj, "init") and callable(lib_obj.init):
|
|
||||||
try:
|
try:
|
||||||
await lib_obj.init()
|
await lib_obj.init()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -1222,9 +1222,21 @@ class Modules:
|
||||||
|
|
||||||
lib_obj.translator = self._translator
|
lib_obj.translator = self._translator
|
||||||
|
|
||||||
if new:
|
for old_lib in self.libraries:
|
||||||
self.libraries += [lib_obj]
|
if old_lib.name == lib_obj.name:
|
||||||
|
if hasattr(old_lib, "on_lib_update") and callable(
|
||||||
|
old_lib.on_lib_update
|
||||||
|
):
|
||||||
|
await old_lib.on_lib_update(lib_obj)
|
||||||
|
|
||||||
|
replace_all_refs(old_lib, lib_obj)
|
||||||
|
logger.debug(
|
||||||
|
"Replacing existing instance of library %s with updated object",
|
||||||
|
lib_obj.name,
|
||||||
|
)
|
||||||
|
return lib_obj
|
||||||
|
|
||||||
|
self.libraries += [lib_obj]
|
||||||
return lib_obj
|
return lib_obj
|
||||||
|
|
||||||
def dispatch(self, _command: str) -> tuple:
|
def dispatch(self, _command: str) -> tuple:
|
||||||
|
@ -1244,7 +1256,7 @@ class Modules:
|
||||||
for mod in self.modules:
|
for mod in self.modules:
|
||||||
self.send_config_one(mod, skip_hook)
|
self.send_config_one(mod, skip_hook)
|
||||||
|
|
||||||
def send_config_one(self, mod: "Module", skip_hook: bool = False):
|
def send_config_one(self, mod: Module, skip_hook: bool = False):
|
||||||
"""Send config to single instance"""
|
"""Send config to single instance"""
|
||||||
with contextlib.suppress(AttributeError):
|
with contextlib.suppress(AttributeError):
|
||||||
_hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
|
_hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
|
||||||
|
@ -1311,9 +1323,9 @@ class Modules:
|
||||||
|
|
||||||
async def _animate(
|
async def _animate(
|
||||||
self,
|
self,
|
||||||
message: Union[Message, InlineMessage],
|
message: typing.Union[Message, InlineMessage],
|
||||||
frames: List[str],
|
frames: typing.List[str],
|
||||||
interval: Union[float, int],
|
interval: typing.Union[float, int],
|
||||||
*,
|
*,
|
||||||
inline: bool = False,
|
inline: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -1425,7 +1437,7 @@ class Modules:
|
||||||
name,
|
name,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def unload_module(self, classname: str) -> bool:
|
async def unload_module(self, classname: str) -> typing.List[str]:
|
||||||
"""Remove module and all stuff from it"""
|
"""Remove module and all stuff from it"""
|
||||||
worked = []
|
worked = []
|
||||||
|
|
||||||
|
@ -1494,12 +1506,7 @@ class Modules:
|
||||||
|
|
||||||
def remove_alias(self, alias: str) -> bool:
|
def remove_alias(self, alias: str) -> bool:
|
||||||
"""Remove an alias"""
|
"""Remove an alias"""
|
||||||
try:
|
return bool(self.aliases.pop(alias.lower().strip(), None))
|
||||||
del self.aliases[alias.lower().strip()]
|
|
||||||
except KeyError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def log(self, *args, **kwargs):
|
async def log(self, *args, **kwargs):
|
||||||
"""Unnecessary placeholder for logging"""
|
"""Unnecessary placeholder for logging"""
|
||||||
|
|
Loading…
Reference in New Issue