pull/1/head
hikariatama 2022-09-19 06:12:41 +00:00
parent 066f7a46a2
commit b32a3cb497
2 changed files with 59 additions and 52 deletions

View File

@ -417,7 +417,7 @@ class Events(InlineUnit):
"text": "🏌️ Run command", "text": "🏌️ Run command",
"switch_inline_query_current_chat": f"{name} ", "switch_inline_query_current_chat": f"{name} ",
} }
) ),
), ),
f"🎹 <code>@{self.bot_username} {utils.escape_html(name)}</code> -" f"🎹 <code>@{self.bot_username} {utils.escape_html(name)}</code> -"
f" {utils.escape_html(doc)}\n", f" {utils.escape_html(doc)}\n",

View File

@ -44,7 +44,7 @@ from telethon.tl.functions.account import UpdateNotifySettingsRequest
from telethon.hints import EntityLike from telethon.hints import EntityLike
from types import FunctionType from types import FunctionType
from typing import Any, Optional, Union, List import typing
from . import security, utils, validators, version from . import security, utils, validators, version
from .types import ( from .types import (
@ -102,7 +102,7 @@ def proxy0(data):
_CELLTYPE = type(proxy0(None).__closure__[0]) _CELLTYPE = type(proxy0(None).__closure__[0])
def replace_all_refs(replace_from: Any, replace_to: Any) -> Any: def replace_all_refs(replace_from: typing.Any, replace_to: typing.Any) -> typing.Any:
""" """
:summary: Uses the :mod:`gc` module to replace all references to obj :summary: Uses the :mod:`gc` module to replace all references to obj
:attr:`replace_from` with :attr:`replace_to` (it tries it's best, :attr:`replace_from` with :attr:`replace_to` (it tries it's best,
@ -196,7 +196,7 @@ def replace_all_refs(replace_from: Any, replace_to: Any) -> Any:
replace_all_refs(referrer, newfn) replace_all_refs(referrer, newfn)
else: else:
logging.debug(f"{referrer} is not supported.") logger.debug("%s is not supported.", referrer)
if hit is False: if hit is False:
raise AttributeError(f"Object '{replace_from}' not found") raise AttributeError(f"Object '{replace_from}' not found")
@ -233,7 +233,7 @@ class InfiniteLoop:
interval: int, interval: int,
autostart: bool, autostart: bool,
wait_before: bool, wait_before: bool,
stop_clause: Union[str, None], stop_clause: typing.Union[str, None],
): ):
self.func = func self.func = func
self.interval = interval self.interval = interval
@ -314,9 +314,9 @@ class InfiniteLoop:
def loop( def loop(
interval: int = 5, interval: int = 5,
autostart: Optional[bool] = False, autostart: typing.Optional[bool] = False,
wait_before: Optional[bool] = False, wait_before: typing.Optional[bool] = False,
stop_clause: Optional[str] = None, stop_clause: typing.Optional[str] = None,
) -> FunctionType: ) -> FunctionType:
""" """
Create new infinite loop from class method Create new infinite loop from class method
@ -572,9 +572,9 @@ class Modules:
async def register_all( async def register_all(
self, self,
mods: Optional[List[str]] = None, mods: typing.Optional[typing.List[str]] = None,
no_external: bool = False, no_external: bool = False,
) -> List[Module]: ) -> typing.List[Module]:
"""Load all modules in the module directory""" """Load all modules in the module directory"""
external_mods = [] external_mods = []
@ -613,7 +613,7 @@ class Modules:
async def _register_modules( async def _register_modules(
self, modules: list, origin: str = "<core>" self, modules: list, origin: str = "<core>"
) -> List[Module]: ) -> typing.List[Module]:
with contextlib.suppress(AttributeError): with contextlib.suppress(AttributeError):
_hikka_client_id_logging_tag = copy.copy(self.client.tg_id) _hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
@ -837,7 +837,7 @@ class Modules:
self, self,
peer: EntityLike, peer: EntityLike,
reason: str, reason: str,
assure_joined: Optional[bool] = False, assure_joined: typing.Optional[bool] = False,
_module: Module = None, _module: Module = None,
) -> bool: ) -> bool:
""" """
@ -985,7 +985,7 @@ class Modules:
def _get( def _get(
self, self,
key: str, key: str,
default: Optional[JSONSerializable] = None, default: typing.Optional[JSONSerializable] = None,
_owner: str = None, _owner: str = None,
) -> JSONSerializable: ) -> JSONSerializable:
return self._db.get(_owner, key, default) return self._db.get(_owner, key, default)
@ -996,7 +996,7 @@ class Modules:
def _pointer( def _pointer(
self, self,
key: str, key: str,
default: Optional[JSONSerializable] = None, default: typing.Optional[JSONSerializable] = None,
_owner: str = None, _owner: str = None,
) -> JSONSerializable: ) -> JSONSerializable:
return self._db.pointer(_owner, key, default) return self._db.pointer(_owner, key, default)
@ -1005,7 +1005,7 @@ class Modules:
self, self,
url: str, url: str,
*, *,
suspend_on_error: Optional[bool] = False, suspend_on_error: typing.Optional[bool] = False,
_did_requirements: bool = False, _did_requirements: bool = False,
) -> object: ) -> object:
""" """
@ -1055,7 +1055,9 @@ class Modules:
origin = f"<library {url}>" origin = f"<library {url}>"
spec = importlib.machinery.ModuleSpec( spec = importlib.machinery.ModuleSpec(
module, StringLoader(code, origin), origin=origin module,
StringLoader(code, origin),
origin=origin,
) )
try: try:
instance = importlib.util.module_from_spec(spec) instance = importlib.util.module_from_spec(spec)
@ -1127,7 +1129,13 @@ class Modules:
_raise(ImportError("Invalid library. No class found")) _raise(ImportError("Invalid library. No class found"))
if not lib_obj.__class__.__name__.endswith("Lib"): if not lib_obj.__class__.__name__.endswith("Lib"):
_raise(ImportError("Invalid library. Class name must end with 'Lib'")) _raise(
ImportError(
"Invalid library. Classname {} does not end with 'Lib'".format(
lib_obj.__class__.__name__
)
)
)
if ( if (
all( all(
@ -1152,15 +1160,18 @@ class Modules:
lib_obj.tg_id = self.client.tg_id lib_obj.tg_id = self.client.tg_id
lib_obj.allmodules = self lib_obj.allmodules = self
lib_obj._lib_get = partial( lib_obj._lib_get = partial(
self._get, _owner=lib_obj.__class__.__name__ self._get,
) # skipcq _owner=lib_obj.__class__.__name__,
)
lib_obj._lib_set = partial( lib_obj._lib_set = partial(
self._set, _owner=lib_obj.__class__.__name__ self._set,
) # skipcq _owner=lib_obj.__class__.__name__,
)
lib_obj._lib_pointer = partial( lib_obj._lib_pointer = partial(
self._pointer, _owner=lib_obj.__class__.__name__ self._pointer,
) # skipcq _owner=lib_obj.__class__.__name__,
lib_obj.get_prefix = partial(self._db.get, "hikka.main", "command_prefix", ".") )
lib_obj.get_prefix = self.get_prefix
for old_lib in self.libraries: for old_lib in self.libraries:
if old_lib.name == lib_obj.name and ( if old_lib.name == lib_obj.name and (
@ -1168,26 +1179,15 @@ class Modules:
and not isinstance(getattr(lib_obj, "version", None), tuple) and not isinstance(getattr(lib_obj, "version", None), tuple)
or old_lib.version >= lib_obj.version or old_lib.version >= lib_obj.version
): ):
logging.debug(f"Using existing instance of library {old_lib.name}") logger.debug("Using existing instance of library %s", old_lib.name)
return old_lib return old_lib
new = True new = True
for old_lib in self.libraries: if hasattr(lib_obj, "init"):
if old_lib.name == lib_obj.name: if not callable(lib_obj.init):
if hasattr(old_lib, "on_lib_update") and callable( _raise(ValueError("Library init() must be callable"))
old_lib.on_lib_update
):
await old_lib.on_lib_update(lib_obj)
replace_all_refs(old_lib, lib_obj)
new = False
logging.debug(
"Replacing existing instance of library"
f" {lib_obj.name} with updated object"
)
if hasattr(lib_obj, "init") and callable(lib_obj.init):
try: try:
await lib_obj.init() await lib_obj.init()
except Exception: except Exception:
@ -1222,9 +1222,21 @@ class Modules:
lib_obj.translator = self._translator lib_obj.translator = self._translator
if new: for old_lib in self.libraries:
self.libraries += [lib_obj] if old_lib.name == lib_obj.name:
if hasattr(old_lib, "on_lib_update") and callable(
old_lib.on_lib_update
):
await old_lib.on_lib_update(lib_obj)
replace_all_refs(old_lib, lib_obj)
logger.debug(
"Replacing existing instance of library %s with updated object",
lib_obj.name,
)
return lib_obj
self.libraries += [lib_obj]
return lib_obj return lib_obj
def dispatch(self, _command: str) -> tuple: def dispatch(self, _command: str) -> tuple:
@ -1244,7 +1256,7 @@ class Modules:
for mod in self.modules: for mod in self.modules:
self.send_config_one(mod, skip_hook) self.send_config_one(mod, skip_hook)
def send_config_one(self, mod: "Module", skip_hook: bool = False): def send_config_one(self, mod: Module, skip_hook: bool = False):
"""Send config to single instance""" """Send config to single instance"""
with contextlib.suppress(AttributeError): with contextlib.suppress(AttributeError):
_hikka_client_id_logging_tag = copy.copy(self.client.tg_id) _hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
@ -1311,9 +1323,9 @@ class Modules:
async def _animate( async def _animate(
self, self,
message: Union[Message, InlineMessage], message: typing.Union[Message, InlineMessage],
frames: List[str], frames: typing.List[str],
interval: Union[float, int], interval: typing.Union[float, int],
*, *,
inline: bool = False, inline: bool = False,
) -> None: ) -> None:
@ -1425,7 +1437,7 @@ class Modules:
name, name,
) )
async def unload_module(self, classname: str) -> bool: async def unload_module(self, classname: str) -> typing.List[str]:
"""Remove module and all stuff from it""" """Remove module and all stuff from it"""
worked = [] worked = []
@ -1494,12 +1506,7 @@ class Modules:
def remove_alias(self, alias: str) -> bool: def remove_alias(self, alias: str) -> bool:
"""Remove an alias""" """Remove an alias"""
try: return bool(self.aliases.pop(alias.lower().strip(), None))
del self.aliases[alias.lower().strip()]
except KeyError:
return False
return True
async def log(self, *args, **kwargs): async def log(self, *args, **kwargs):
"""Unnecessary placeholder for logging""" """Unnecessary placeholder for logging"""