Heroku/hikka/modules/loader.py

1344 lines
45 KiB
Python
Raw Blame History

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""Loads and registers modules"""
# ©️ Dan Gazizullin, 2021-2023
# This file is a part of Hikka Userbot
# 🌐 https://github.com/hikariatama/Hikka
# You can redistribute it and/or modify it under the terms of the GNU AGPLv3
# 🔑 https://www.gnu.org/licenses/agpl-3.0.html
import ast
import asyncio
import contextlib
import functools
import importlib
import difflib
import inspect
import io
import logging
import os
import re
import shutil
import sys
import time
import typing
import uuid
from collections import ChainMap
from importlib.machinery import ModuleSpec
from urllib.parse import urlparse
import requests
from hikkatl.errors.rpcerrorlist import MediaCaptionTooLongError
from hikkatl.tl.functions.channels import JoinChannelRequest
from hikkatl.tl.types import Channel, Message, PeerUser
from .. import loader, main, utils
from .._local_storage import RemoteStorage
from ..compat import geek
from ..inline.types import InlineCall
from ..types import CoreOverwriteError, CoreUnloadError
logger = logging.getLogger(__name__)


class FakeOne:
    """Falsy sentinel that compares equal to ``-1`` and to other ``FakeOne``s."""

    def __eq__(self, other):
        # Evaluate the ``-1`` comparison first and hand back its (truthy)
        # result unchanged; otherwise fall back to the type check.
        equals_minus_one = other == -1
        if equals_minus_one:
            return equals_minus_one
        return isinstance(other, FakeOne)

    def __bool__(self):
        # Always falsy, so ``if result:`` treats it like a failure.
        return False


# Result codes of ``download_and_install``. ``dlmod`` stops silently when the
# result compares equal to MODULE_LOADING_FORBIDDEN (which also matches -1).
MODULE_LOADING_FORBIDDEN = FakeOne()
MODULE_LOADING_FAILED = 0
MODULE_LOADING_SUCCESS = 1
# NOTE(review): loader.tds presumably makes docstrings translatable — confirm.
@loader.tds
class LoaderMod(loader.Module):
    """Loads modules"""

    # Base string table: only the display name is defined here. The other
    # keys used through self.strings(...) (e.g. "no_module", "installing")
    # are not defined in the visible source — presumably supplied by
    # language packs; TODO confirm.
    strings = {"name": "Loader"}
def __init__(self):
    # Flipped to True by _update_modules after the startup restore; dlmod
    # only persists the module list once this is set.
    self.fully_loaded = False
    # repo URL (slashes stripped) -> {"exp": expiry timestamp, "data": links}
    self._links_cache = {}
    # Created in client_ready(), once the client is available.
    self._storage: RemoteStorage = None

    # Primary repository; get_links_list() puts its links first.
    modules_repo = loader.ConfigValue(
        "MODULES_REPO",
        "https://raw.githubusercontent.com/coddrago/modules/main",
        lambda: self.strings("repo_config_doc"),
        validator=loader.validators.Link(),
    )
    # Extra repositories, managed by addrepo / delrepo.
    additional_repos = loader.ConfigValue(
        "ADDITIONAL_REPOS",
        [],
        lambda: self.strings("add_repo_config_doc"),
        validator=loader.validators.Series(validator=loader.validators.Link()),
    )
    # When enabled, the "loaded" message includes the module origin link.
    share_link = loader.ConfigValue(
        "share_link",
        doc=lambda: self.strings("share_link_doc"),
        validator=loader.validators.Boolean(),
    )
    # Optional "user:password" credentials sent along with repo requests.
    basic_auth = loader.ConfigValue(
        "basic_auth",
        None,
        lambda: self.strings("basic_auth_doc"),
        validator=loader.validators.Hidden(loader.validators.RegExp(r"^.*:.*$")),
    )

    self.config = loader.ModuleConfig(
        modules_repo,
        additional_repos,
        share_link,
        basic_auth,
    )
async def _async_init(self):
    # Background cache warm-up. Every known repo link that does not belong
    # to the default coddrago repo is passed to RemoteStorage.preload(), and
    # preload_main_repo() is scheduled as well.
    default_repo_prefix = "https://raw.githubusercontent.com/coddrago/modules/main"

    repo_list = await self.get_repo_list()
    links_per_repo = [list(links.values()) for links in repo_list.values()]
    modules = [
        link
        for link in utils.array_sum(links_per_repo)
        if not link.startswith(default_repo_prefix)
    ]

    logger.debug("Modules: %s", modules)
    asyncio.ensure_future(self._storage.preload(modules))
    asyncio.ensure_future(self._storage.preload_main_repo())
async def client_ready(self):
    # Startup hook. Poll every 0.5 s until the "settings" module is
    # registered. Then create the remote storage, restore saved aliases and
    # signal readiness. Finally, start the module restore and the cache
    # warm-up in the background.
    settings = self.lookup("settings")
    while not settings:
        await asyncio.sleep(0.5)
        settings = self.lookup("settings")

    self._storage = RemoteStorage(self._client)
    self.allmodules.add_aliases(settings.get("aliases", {}))
    main.hikka.ready.set()

    asyncio.ensure_future(self._update_modules())
    asyncio.ensure_future(self._async_init())
@loader.loop(interval=3, wait_before=True, autostart=True)
async def _config_autosaver(self):
    # Every 3 seconds, process modules first and then libraries. Each config
    # option flagged with ``_save_marker`` is copied into its owner's
    # "__config__" DB pointer and the flag is removed. The DB is saved at the
    # end. Modules use ``pointer``; libraries use ``_lib_pointer``.
    for group_name, pointer_attr in (
        ("modules", "pointer"),
        ("libraries", "_lib_pointer"),
    ):
        for owner in getattr(self.allmodules, group_name):
            cfg = getattr(owner, "config", None)
            if not cfg or not isinstance(cfg, loader.ModuleConfig):
                continue

            for option, value_holder in cfg._config.items():
                if not hasattr(value_holder, "_save_marker"):
                    continue

                delattr(cfg._config[option], "_save_marker")
                pointer = getattr(owner, pointer_attr)
                pointer("__config__", {})[option] = value_holder.value

    self._db.save()
def update_modules_in_db(self):
    # Persist {class name: origin URL} for every loaded module whose origin
    # starts with "http", so _update_modules can reinstall them later.
    # Nothing is written while secure boot is active.
    if self.allmodules.secure_boot:
        return

    remote_modules = {}
    for module in self.allmodules.modules:
        origin = module.__origin__
        if origin.startswith("http"):
            remote_modules[module.__class__.__name__] = origin

    self.set("loaded_modules", remote_modules)
@loader.command(alias="dlm")
async def dlmod(self, message: Message, force_pm: bool = False):
    # Command. With an argument, install that module (name or URL). Without
    # one, show an inline list with one page per repository, module names
    # sorted and laid out five per row.
    args = utils.get_args(message)

    if not args:
        pages = []
        for repo, mods in (await self.get_repo_list()).items():
            header = self.strings("avail_header") + f"\n☁️ {repo.strip('/')}\n\n"
            names = sorted(
                utils.escape_html(link.split("/")[-1].split(".")[0])
                for link in mods.values()
            )
            rows = utils.chunks([f"<code>{name}</code>" for name in names], 5)
            pages.append(header + "\n".join(" | ".join(row) for row in rows))

        await self.inline.list(message, pages)
        return

    result = await self.download_and_install(args[0], message, force_pm)
    if result == MODULE_LOADING_FORBIDDEN:
        return

    if self.fully_loaded:
        self.update_modules_in_db()
async def _get_modules_to_load(self):
    # Return the {class name: origin URL} mapping saved by
    # update_modules_in_db (empty dict when nothing was saved).
    remembered = self.get("loaded_modules", {})
    logger.debug("Loading modules: %s", remembered)
    return remembered
async def _get_repo(self, repo: str) -> typing.List[str]:
    """Return the module paths listed in ``<repo>/full.txt``.

    Successful results are cached per repository for 5 minutes.

    Non-2xx responses and request errors (connection problems, timeouts)
    are logged at debug level and return an empty list, which is not
    cached. One unreachable repository therefore does not break module
    listing or lookup in the others.

    :param repo: Repository base URL; surrounding slashes are ignored.
    :return: The non-empty, whitespace-stripped lines of ``full.txt``.
    """
    repo = repo.strip("/")

    cached = self._links_cache.get(repo, {})
    if cached.get("exp", 0) >= time.time():
        return cached["data"]

    basic_auth = self.config["basic_auth"]
    try:
        res = await utils.run_sync(
            requests.get,
            f"{repo}/full.txt",
            auth=tuple(basic_auth.split(":", 1)) if basic_auth else None,
            # Bound the request so an unresponsive host cannot stall callers
            # (dlmod, module lookup, startup warm-up) indefinitely.
            timeout=10,
        )
    except requests.exceptions.RequestException:
        logger.debug("Can't load repo %s contents", repo, exc_info=True)
        return []

    if not 200 <= res.status_code < 300:
        logger.debug(
            "Can't load repo %s contents because of %s status code",
            repo,
            res.status_code,
        )
        return []

    # Strip every line, not just the ends of the whole body, so stray spaces
    # never end up inside the generated module URLs.
    links = [line.strip() for line in res.text.splitlines() if line.strip()]
    self._links_cache[repo] = {"exp": time.time() + 5 * 60, "data": links}
    return links
async def get_repo_list(
    self,
    only_primary: bool = False,
) -> dict:
    """Map every configured http(s) repository to its module links.

    :param only_primary: Only include ``MODULES_REPO``; skip
        ``ADDITIONAL_REPOS``.
    :return: ``{repo: {"Mod/<index>/<n>": "<repo>/<path>.py"}}``.
        ``repo`` is the configured value as-is, and the link uses it with
        surrounding slashes stripped. ``<index>`` is the position in the
        configured list. Entries not starting with "http" are skipped
        without being fetched.
    """
    configured = [self.config["MODULES_REPO"]] + (
        [] if only_primary else self.config["ADDITIONAL_REPOS"]
    )

    result = {}
    for repo_id, repo in enumerate(configured):
        if not repo.startswith("http"):
            continue

        base_url = repo.strip("/")
        # dict.fromkeys() drops duplicates but, unlike set(), keeps the
        # full.txt order. The generated keys and the first match picked by
        # _find_link are therefore the same on every run.
        unique_paths = dict.fromkeys(await self._get_repo(repo))
        result[repo] = {
            f"Mod/{repo_id}/{i}": f"{base_url}/{path}.py"
            for i, path in enumerate(unique_paths)
        }

    return result
async def get_links_list(self) -> typing.List[str]:
    # Flatten get_repo_list() into one list. The primary repository's links
    # come first, followed by the links of every additional repository.
    repos = await self.get_repo_list()
    primary = list(repos.pop(self.config["MODULES_REPO"]).values())

    # Per the collections docs, ChainMap iteration scans its maps
    # last-to-first, so later additional repos are listed before earlier ones.
    others = ChainMap(*repos.values())
    return primary + [others[key] for key in others]
async def _find_link(self, module_name: str) -> typing.Union[str, bool]:
return next(
filter(
lambda link: link.lower().endswith(f"/{module_name.lower()}.py"),
await self.get_links_list(),
),
False,
)
async def download_and_install(
    self,
    module_name: str,
    message: typing.Optional[Message] = None,
    force_pm: bool = False,
) -> int:
    # Resolve ``module_name`` (a repo module name or a direct URL), fetch its
    # source through RemoteStorage and pass it to load_module.
    # Returns MODULE_LOADING_SUCCESS once load_module has been called, and
    # MODULE_LOADING_FAILED when the module is not found, the fetch fails
    # with an HTTP error, or anything raises (the error is logged).
    # ``force_pm`` is accepted but not used here.
    try:
        module_name = module_name.strip()
        blob_link = False

        if not urlparse(module_name).netloc:
            url = await self._find_link(module_name)
            if not url:
                if message is not None:
                    await utils.answer(message, self.strings("no_module"))
                return MODULE_LOADING_FAILED
        else:
            url = module_name
            # GitHub/GitLab "blob" pages are HTML views; use the raw file.
            is_blob_page = re.match(
                r"^(https:\/\/github\.com\/.*?\/.*?\/blob\/.*\.py)|"
                r"(https:\/\/gitlab\.com\/.*?\/.*?\/-\/blob\/.*\.py)$",
                url,
            )
            if is_blob_page:
                url = url.replace("/blob/", "/raw/")
                blob_link = True

        if message:
            message = await utils.answer(
                message,
                self.strings("installing").format(module_name),
            )

        try:
            source = await self._storage.fetch(url, auth=self.config["basic_auth"])
        except requests.exceptions.HTTPError:
            if message is not None:
                await utils.answer(message, self.strings("no_module"))
            return MODULE_LOADING_FAILED

        await self.load_module(
            source,
            message,
            module_name,
            url,
            blob_link=blob_link,
        )
        return MODULE_LOADING_SUCCESS
    except Exception:
        logger.exception("Failed to load %s", module_name)
        return MODULE_LOADING_FAILED
async def _inline__load(
    self,
    call: InlineCall,
    doc: str,
    path_: str,
    mode: str,
):
    # Button callback for loadmod's "save module to disk?" form.
    # "all_yes" / "all_no" also store the choice globally in the DB.
    # "once" saves only this module. Any other mode loads without saving.
    if mode == "all_yes":
        self._db.set(main.__name__, "permanent_modules_fs", True)
        self._db.set(main.__name__, "disable_modules_fs", False)
        await call.answer(self.strings("will_save_fs"))
        save_to_fs = True
    elif mode == "all_no":
        self._db.set(main.__name__, "disable_modules_fs", True)
        self._db.set(main.__name__, "permanent_modules_fs", False)
        save_to_fs = False
    else:
        save_to_fs = mode == "once"

    await self.load_module(
        doc,
        call,
        origin=path_ or "<string>",
        save_fs=save_to_fs,
    )
@loader.command(alias="lm")
async def loadmod(self, message: Message, force_pm: bool = False):
    # Command: load a module from the attached (or replied-to) file.
    # "-fs" in the arguments forces saving it to the filesystem. Otherwise,
    # unless a global save/don't-save preference exists in the DB, an
    # inline form asks what to do and _inline__load finishes the load.
    args = utils.get_args_raw(message)
    force_save = "-fs" in args

    msg = message if message.file else (await message.get_reply_message())
    if msg is None or msg.media is None:
        await utils.answer(message, self.strings("provide_module"))
        return

    # Uploaded files have no filesystem origin; the form callbacks receive
    # it as-is and fall back to "<string>".
    path_ = None
    doc = await msg.download_media(bytes)
    if doc is None:
        # download_media() returns None when the media has nothing to
        # download (presumably e.g. a link preview) — treat as "no module"
        # instead of crashing on ``None.decode()``.
        await utils.answer(message, self.strings("provide_module"))
        return

    try:
        doc = doc.decode()
    except UnicodeDecodeError:
        await utils.answer(message, self.strings("bad_unicode"))
        return

    if (
        not self._db.get(
            main.__name__,
            "disable_modules_fs",
            False,
        )
        and not self._db.get(main.__name__, "permanent_modules_fs", False)
        and not force_save
    ):
        if message.file:
            await message.edit("")
            message = await message.respond("🌘", reply_to=utils.get_topic(message))

        if await self.inline.form(
            self.strings("module_fs"),
            message=message,
            reply_markup=[
                [
                    {
                        "text": self.strings("save"),
                        "callback": self._inline__load,
                        "args": (doc, path_, "once"),
                    },
                    {
                        "text": self.strings("no_save"),
                        "callback": self._inline__load,
                        "args": (doc, path_, "no"),
                    },
                ],
                [
                    {
                        "text": self.strings("save_for_all"),
                        "callback": self._inline__load,
                        "args": (doc, path_, "all_yes"),
                    }
                ],
                [
                    {
                        "text": self.strings("never_save"),
                        "callback": self._inline__load,
                        "args": (doc, path_, "all_no"),
                    }
                ],
            ],
        ):
            return

    # Reached when a preference exists, "-fs" was given, or the form could
    # not be shown. path_ is always None here, so origin keeps its default.
    await self.load_module(
        doc,
        message,
        save_fs=(
            force_save
            or (
                self._db.get(main.__name__, "permanent_modules_fs", False)
                and not self._db.get(main.__name__, "disable_modules_fs", False)
            )
        ),
    )
async def approve_internal(
    self,
    call: InlineCall,
    channel: "hints.EntityLike",  # type: ignore # noqa
    event: asyncio.Event,
):
    """
    Join ``channel``, mark ``event.status = True`` and set the event.

    Then replace the inline message with a "Joined" notice. Internal
    callback: not meant to be called from outside this module.
    """
    await self._client(JoinChannelRequest(channel))

    event.status = True
    event.set()

    joined_text = (
        f'💫 <b>Joined <a href="https://t.me/{channel.username}">'
        f"{utils.escape_html(channel.title)}</a></b>"
    )
    await call.edit(
        joined_text,
        gif="https://data.whicdn.com/images/324445359/original.gif",
    )
async def load_module(
    self,
    doc: str,
    message: Message,
    name: typing.Optional[str] = None,
    origin: str = "<string>",
    did_requirements: bool = False,
    save_fs: bool = False,
    blob_link: bool = False,
):
    """Compile, register and announce a module from its source code.

    :param doc: Module source code.
    :param message: Message (or inline call, see ``_inline__load``) used for
        progress and result replies. When it is ``None`` the module is
        loaded without any reply.
    :param name: Name the module was requested by (repo name or URL). When
        omitted, the uid comes from the first top-level class deriving from
        ``Module``.
    :param origin: Where the code came from. Passed to ``register_module``
        and shown as a link when the ``share_link`` option is enabled.
    :param did_requirements: Set on the automatic retry after pip installed
        missing dependencies, so installation is attempted only once.
    :param save_fs: Forwarded to ``register_module``.
    :param blob_link: The URL was rewritten from a GitHub/GitLab "blob" page;
        adds the "blob_link" notice to the success message.
    :return: Always ``None``; failures are reported through ``message``.
    """
    # Modules declaring "# scope: ffmpeg" need a working ffmpeg binary.
    if any(
        line.replace(" ", "") == "#scope:ffmpeg" for line in doc.splitlines()
    ) and os.system("ffmpeg -version 1>/dev/null 2>/dev/null"):
        if isinstance(message, Message):
            await utils.answer(message, self.strings("ffmpeg_required"))
        return

    # Modules declaring "# scope: inline" need the inline bot to be ready.
    if (
        any(line.replace(" ", "") == "#scope:inline" for line in doc.splitlines())
        and not self.inline.init_complete
    ):
        if isinstance(message, Message):
            await utils.answer(message, self.strings("inline_init_failed"))
        return

    # "# scope: hikka_min X.Y.Z" declares the minimum core version. One regex
    # both detects and parses it, so a line without a valid version is
    # ignored instead of raising AttributeError on ``None.group``.
    if ver_match := re.search(r"# ?scope: ?hikka_min ((?:\d+\.){2}\d+)", doc):
        ver = ver_match.group(1)
        ver_ = tuple(map(int, ver.split(".")))
        if main.__version__ < ver_:
            if isinstance(message, Message):
                if getattr(message, "file", None):
                    m = utils.get_chat_id(message)
                    await message.edit("")
                else:
                    m = message

                await self.inline.form(
                    self.strings("version_incompatible").format(ver),
                    m,
                    reply_markup=[
                        {
                            "text": self.lookup("updater").strings("btn_update"),
                            "callback": self.lookup("updater").inline_update,
                        },
                        {
                            "text": self.lookup("updater").strings("cancel"),
                            "action": "close",
                        },
                    ],
                )
            return

    developer = re.search(r"# ?meta developer: ?(.+)", doc)
    developer = developer.group(1) if developer else False

    blob_link = self.strings("blob_link") if blob_link else ""

    if name is None:
        try:
            node = ast.parse(doc)
            # uid = name of the first top-level class whose base is either
            # ``Module`` (ast.Name) or ``<something>.Module`` (ast.Attribute,
            # e.g. ``loader.Module``). For the dotted form, "Module" is the
            # attribute's ``attr``; its ``value`` is the ``loader`` name.
            uid = next(
                n.name
                for n in node.body
                if isinstance(n, ast.ClassDef)
                and any(
                    isinstance(base, ast.Attribute)
                    and base.attr == "Module"
                    or isinstance(base, ast.Name)
                    and base.id == "Module"
                    for base in n.bases
                )
            )
        except Exception:
            logger.debug(
                "Can't parse classname from code, using legacy uid instead",
                exc_info=True,
            )
            uid = "__extmod_" + str(uuid.uuid4())
    else:
        if name.startswith(self.config["MODULES_REPO"]):
            name = name.split("/")[-1].split(".py")[0]

        # Escape "%" and "." so module_name below gains no extra dots.
        uid = name.replace("%", "%%").replace(".", "%d")

    module_name = f"hikka.modules.{uid}"

    doc = geek.compat(doc)

    async def core_overwrite(e: CoreOverwriteError):
        # The module tried to overwrite a core module/command: drop it
        # (``instance`` may be unbound here, hence the suppress) and report.
        nonlocal message

        with contextlib.suppress(Exception):
            self.allmodules.modules.remove(instance)

        if not message:
            return

        await utils.answer(
            message,
            self.strings(f"overwrite_{e.type}").format(
                *(
                    (e.target,)
                    if e.type == "module"
                    else (utils.escape_html(self.get_prefix()), e.target)
                )
            ),
        )

    try:
        try:
            spec = ModuleSpec(
                module_name,
                loader.StringLoader(doc, f"<external {module_name}>"),
                origin=f"<external {module_name}>",
            )
            instance = await self.allmodules.register_module(
                spec,
                module_name,
                origin,
                save_fs=save_fs,
            )
        except ImportError as e:
            logger.info(
                "Module loading failed, attemping dependency installation (%s)",
                e.name,
            )
            # Let's try to reinstall dependencies
            try:
                requirements = list(
                    filter(
                        lambda x: not x.startswith(("-", "_", ".")),
                        map(
                            str.strip,
                            loader.VALID_PIP_PACKAGES.search(doc)[1].split(),
                        ),
                    )
                )
            except TypeError:
                logger.warning(
                    "No valid pip packages specified in code, attemping"
                    " installation from error"
                )
                requirements = [
                    {
                        "sklearn": "scikit-learn",
                        "pil": "Pillow",
                        "hikkatl": "Hikka-TL-New",
                    }.get(e.name.lower(), e.name)
                ]

            if not requirements:
                raise Exception("Nothing to install") from e

            logger.debug("Installing requirements: %s", requirements)

            if did_requirements:
                if message is not None:
                    await utils.answer(
                        message,
                        self.strings("requirements_restart").format(e.name),
                    )

                return

            if message is not None:
                await utils.answer(
                    message,
                    self.strings("requirements_installing").format(
                        "\n".join(
                            "<emoji"
                            " document_id=4971987363145188045>▫️</emoji>"
                            f" {req}"
                            for req in requirements
                        )
                    ),
                )

            pip = await asyncio.create_subprocess_exec(
                sys.executable,
                "-m",
                "pip",
                "install",
                "--upgrade",
                "-q",
                "--disable-pip-version-check",
                "--no-warn-script-location",
                *["--user"] if loader.USER_INSTALL else [],
                *requirements,
            )

            rc = await pip.wait()

            if rc != 0:
                if message is not None:
                    if "com.termux" in os.environ.get("PREFIX", ""):
                        await utils.answer(
                            message,
                            self.strings("requirements_failed_termux"),
                        )
                    else:
                        await utils.answer(
                            message,
                            self.strings("requirements_failed"),
                        )

                return

            importlib.invalidate_caches()

            # NOTE(review): utils.get_kwargs() presumably reads the current
            # values of this frame's parameters (doc, name, blob_link, ...),
            # so they must keep their names — confirm.
            kwargs = utils.get_kwargs()
            kwargs["did_requirements"] = True

            return await self.load_module(**kwargs)  # Try again
        except CoreOverwriteError as e:
            await core_overwrite(e)
            return
        except loader.LoadError as e:
            with contextlib.suppress(Exception):
                await self.allmodules.unload_module(instance.__class__.__name__)

            with contextlib.suppress(Exception):
                self.allmodules.modules.remove(instance)

            if message:
                await utils.answer(
                    message,
                    (
                        "<emoji document_id=5454225457916420314>😖</emoji>"
                        f" <b>{utils.escape_html(str(e))}</b>"
                    ),
                )
            return
    except Exception as e:
        logger.exception("Loading external module failed due to %s", e)

        if message is not None:
            await utils.answer(message, self.strings("load_failed"))

        return

    if hasattr(instance, "__version__") and isinstance(instance.__version__, tuple):
        version = (
            "<b><i>"
            f" (v{'.'.join(list(map(str, list(instance.__version__))))})</i></b>"
        )
    else:
        version = ""

    try:
        try:
            self.allmodules.send_config_one(instance)

            async def inner_proxy():
                # While the module's ready hook runs, watch for a pending
                # channel-approval request and show it to the user.
                nonlocal instance, message
                while True:
                    if hasattr(instance, "hikka_wait_channel_approve"):
                        if message:
                            (
                                module,
                                channel,
                                reason,
                            ) = instance.hikka_wait_channel_approve
                            message = await utils.answer(
                                message,
                                self.strings("wait_channel_approve").format(
                                    module,
                                    channel.username,
                                    utils.escape_html(channel.title),
                                    utils.escape_html(reason),
                                    self.inline.bot_username,
                                ),
                            )
                            return

                    await asyncio.sleep(0.1)

            task = asyncio.ensure_future(inner_proxy())
            try:
                await self.allmodules.send_ready_one(
                    instance,
                    no_self_unload=True,
                    from_dlmod=bool(message),
                )
            finally:
                # Always stop the watcher. If send_ready_one raised, it would
                # otherwise keep polling every 0.1 s indefinitely.
                task.cancel()
        except CoreOverwriteError as e:
            await core_overwrite(e)
            return
        except loader.LoadError as e:
            with contextlib.suppress(Exception):
                await self.allmodules.unload_module(instance.__class__.__name__)

            with contextlib.suppress(Exception):
                self.allmodules.modules.remove(instance)

            if message:
                await utils.answer(
                    message,
                    (
                        "<emoji document_id=5454225457916420314>😖</emoji>"
                        f" <b>{utils.escape_html(str(e))}</b>"
                    ),
                )
            return
        except loader.SelfUnload as e:
            logger.debug("Unloading %s, because it raised `SelfUnload`", instance)
            with contextlib.suppress(Exception):
                await self.allmodules.unload_module(instance.__class__.__name__)

            with contextlib.suppress(Exception):
                self.allmodules.modules.remove(instance)

            if message:
                await utils.answer(
                    message,
                    (
                        "<emoji document_id=5454225457916420314>😖</emoji>"
                        f" <b>{utils.escape_html(str(e))}</b>"
                    ),
                )
            return
        except loader.SelfSuspend as e:
            logger.debug("Suspending %s, because it raised `SelfSuspend`", instance)
            if message:
                await utils.answer(
                    message,
                    (
                        "🥶 <b>Module suspended itself\nReason:"
                        f" {utils.escape_html(str(e))}</b>"
                    ),
                )
            return
    except Exception as e:
        logger.exception("Module threw because of %s", e)

        if message is not None:
            await utils.answer(message, self.strings("load_failed"))

        return

    # Optional "# meta pic: <url>" banner for the module.
    instance.hikka_meta_pic = next(
        (
            line.replace(" ", "").split("#metapic:", maxsplit=1)[1]
            for line in doc.splitlines()
            if line.replace(" ", "").startswith("#metapic:")
        ),
        None,
    )

    # Optional "# packurl: <url>" pointing at external translations.
    pack_url = next(
        (
            line.replace(" ", "").split("#packurl:", maxsplit=1)[1]
            for line in doc.splitlines()
            if line.replace(" ", "").startswith("#packurl:")
        ),
        None,
    )

    if pack_url and (
        translations := await self.allmodules.translator.load_module_translations(
            pack_url
        )
    ):
        instance.strings.external_strings = translations

    # Re-register saved aliases that point at this module's commands.
    for alias, cmd in self.lookup("settings").get("aliases", {}).items():
        if cmd in instance.commands:
            self.allmodules.add_alias(alias, cmd)

    try:
        modname = instance.strings("name")
    except (KeyError, AttributeError):
        modname = getattr(instance, "name", instance.__class__.__name__)

    try:
        developer_entity = await (
            self._client.force_get_entity
            if (
                developer in self._client.hikka_entity_cache
                and getattr(
                    await self._client.get_entity(developer),
                    "left",
                    True,
                )
            )
            else self._client.get_entity
        )(developer)
    except Exception:
        developer_entity = None

    if not isinstance(developer_entity, Channel):
        developer_entity = None

    if message is None:
        return

    modhelp = ""

    if instance.__doc__:
        modhelp += (
            "<i>\n<emoji document_id=5787544344906959608></emoji>"
            f" {utils.escape_html(inspect.getdoc(instance))}</i>\n"
        )

    subscribe = ""
    subscribe_markup = None

    # List the libraries this module holds references to.
    depends_from = []
    for key in dir(instance):
        value = getattr(instance, key)
        if isinstance(value, loader.Library):
            depends_from.append(
                "<emoji document_id=4971987363145188045>▫️</emoji>"
                " <code>{}</code> <b>{}</b> <code>{}</code>".format(
                    value.__class__.__name__,
                    self.strings("by"),
                    (
                        value.developer
                        if isinstance(getattr(value, "developer", None), str)
                        else "Unknown"
                    ),
                )
            )

    depends_from = (
        self.strings("depends_from").format("\n".join(depends_from))
        if depends_from
        else ""
    )

    def loaded_msg(use_subscribe: bool = True):
        # Closes over the enclosing locals and reads them at call time, so
        # later updates to developer / subscribe / modhelp are reflected.
        return self.strings("loaded").format(
            modname.strip(),
            version,
            utils.ascii_face(),
            modhelp,
            developer if not subscribe or not use_subscribe else "",
            depends_from,
            (
                self.strings("modlink").format(origin)
                if origin != "<string>" and self.config["share_link"]
                else ""
            ),
            blob_link,
            subscribe if use_subscribe else "",
        )

    if developer:
        if developer.startswith("@") and developer not in self.get(
            "do_not_subscribe", []
        ):
            if (
                developer_entity
                and getattr(developer_entity, "left", True)
                and self._db.get(main.__name__, "suggest_subscribe", True)
            ):
                subscribe = self.strings("suggest_subscribe").format(
                    f"@{utils.escape_html(developer_entity.username)}"
                )
                subscribe_markup = [
                    {
                        "text": self.strings("subscribe"),
                        "callback": self._inline__subscribe,
                        "args": (
                            developer_entity.id,
                            functools.partial(loaded_msg, use_subscribe=False),
                            True,
                        ),
                    },
                    {
                        "text": self.strings("no_subscribe"),
                        "callback": self._inline__subscribe,
                        "args": (
                            developer,
                            functools.partial(loaded_msg, use_subscribe=False),
                            False,
                        ),
                    },
                ]

        developer = self.strings("developer").format(
            utils.escape_html(developer)
            if isinstance(developer_entity, Channel)
            else f"<code>{utils.escape_html(developer)}</code>"
        )
    else:
        developer = ""

    # "# scope: disable_onload_docs" skips the per-command help listing.
    if any(
        line.replace(" ", "") == "#scope:disable_onload_docs"
        for line in doc.splitlines()
    ):
        await utils.answer(message, loaded_msg(), reply_markup=subscribe_markup)
        return

    for _name, fun in sorted(
        instance.commands.items(),
        key=lambda x: x[0],
    ):
        modhelp += "\n{} <code>{}{}</code> {}".format(
            "<emoji document_id=4971987363145188045>▫️</emoji>",
            utils.escape_html(self.get_prefix()),
            _name,
            (
                utils.escape_html(inspect.getdoc(fun))
                if fun.__doc__
                else self.strings("undoc")
            ),
        )

    if self.inline.init_complete:
        for _name, fun in sorted(
            instance.inline_handlers.items(),
            key=lambda x: x[0],
        ):
            modhelp += self.strings("ihandler").format(
                f"@{self.inline.bot_username} {_name}",
                (
                    utils.escape_html(inspect.getdoc(fun))
                    if fun.__doc__
                    else self.strings("undoc")
                ),
            )

    try:
        await utils.answer(message, loaded_msg(), reply_markup=subscribe_markup)
    except MediaCaptionTooLongError:
        await message.reply(loaded_msg(False))
async def _inline__subscribe(
    self,
    call: InlineCall,
    entity: typing.Union[int, str],
    msg: typing.Callable[[], str],
    subscribe: bool,
):
    # Button handler for the "subscribe to the developer?" prompt.
    # Accepting joins the channel (``entity`` is its id). Declining
    # remembers ``entity`` (the "@developer" string) in "do_not_subscribe"
    # so the prompt is not offered again. Either way the message is
    # re-rendered with ``msg()``.
    if subscribe:
        await self._client(JoinChannelRequest(entity))
        await utils.answer(call, msg())
        await call.answer(self.strings("subscribed"))
        return

    declined = self.get("do_not_subscribe", []) + [entity]
    self.set("do_not_subscribe", declined)
    await utils.answer(call, msg())
    await call.answer(self.strings("not_subscribed"))
@loader.command(alias="ulm")
async def unloadmod(self, message: Message):
    # Command: unload the named module and drop it from the persisted
    # "loaded_modules" mapping so it is not restored on the next start.
    # Libraries and core modules are refused.
    args = utils.get_args_raw(message)
    if not args:
        await utils.answer(message, self.strings("no_class"))
        return

    target = self.lookup(args)
    if issubclass(target.__class__, loader.Library):
        await utils.answer(message, self.strings("cannot_unload_lib"))
        return

    try:
        worked = await self.allmodules.unload_module(args)
    except CoreUnloadError as e:
        await utils.answer(
            message,
            self.strings("unload_core").format(e.module),
        )
        return

    if not self.allmodules.secure_boot:
        remaining = {
            mod: link
            for mod, link in self.get("loaded_modules", {}).items()
            if mod not in worked
        }
        self.set("loaded_modules", remaining)

    if not worked:
        await utils.answer(message, self.strings("not_unloaded"))
        return

    # Show class names without a trailing "Mod" suffix.
    pretty_names = [mod[:-3] if mod.endswith("Mod") else mod for mod in worked]
    await utils.answer(
        message,
        self.strings("unloaded").format(
            "<emoji document_id=5784993237412351403>✅</emoji>",
            ", ".join(pretty_names),
        ),
    )
@loader.command()
async def clearmodules(self, message: Message):
    # Command: ask for confirmation. The actual cleanup runs in
    # _inline__clearmodules when the confirm button is pressed.
    confirm_button = {
        "text": self.strings("clearmodules"),
        "callback": self._inline__clearmodules,
    }
    cancel_button = {
        "text": self.strings("cancel"),
        "action": "close",
    }
    await self.inline.form(
        self.strings("confirm_clearmodules"),
        message,
        reply_markup=[confirm_button, cancel_button],
    )
@loader.command()
async def addrepo(self, message: Message):
    # Command: add an extra module repository.
    # - The URL may omit the scheme (https:// is assumed).
    # - One trailing slash is dropped.
    # - The repo must serve a non-empty full.txt before it is stored.
    if not (args := utils.get_args_raw(message)) or (
        not utils.check_url(args) and not utils.check_url(f"https://{args}")
    ):
        await utils.answer(message, self.strings("no_repo"))
        return

    if args.endswith("/"):
        args = args[:-1]

    if not args.startswith("https://") and not args.startswith("http://"):
        args = f"https://{args}"

    try:
        r = await utils.run_sync(
            requests.get,
            f"{args}/full.txt",
            auth=(
                tuple(self.config["basic_auth"].split(":", 1))
                if self.config["basic_auth"]
                else None
            ),
            # Without a timeout an unresponsive host would leave this command
            # waiting indefinitely; a timeout is reported as "no_repo".
            timeout=10,
        )
        r.raise_for_status()
        if not r.text.strip():
            raise ValueError
    except Exception:
        await utils.answer(message, self.strings("no_repo"))
        return

    if args in self.config["ADDITIONAL_REPOS"]:
        await utils.answer(message, self.strings("repo_exists").format(args))
        return

    # Augmented assignment stores the list back through config item
    # assignment.
    self.config["ADDITIONAL_REPOS"] += [args]

    await utils.answer(message, self.strings("repo_added").format(args))
@loader.command()
async def delrepo(self, message: Message):
    # Command: remove a repository from ADDITIONAL_REPOS. The argument is
    # normalised like addrepo does (optional scheme, defaulting to https://,
    # and one trailing slash dropped), so the stored entry can be matched.
    if not (args := utils.get_args_raw(message)) or (
        not utils.check_url(args) and not utils.check_url(f"https://{args}")
    ):
        await utils.answer(message, self.strings("no_repo"))
        return

    if args.endswith("/"):
        args = args[:-1]

    if not args.startswith("https://") and not args.startswith("http://"):
        args = f"https://{args}"

    repos = list(self.config["ADDITIONAL_REPOS"])
    if args not in repos:
        await utils.answer(message, self.strings("repo_not_exists"))
        return

    repos.remove(args)
    # Assign a new list instead of mutating the returned one in place: only
    # item assignment goes through the config's __setitem__ (as the `+=` in
    # addrepo does), which is presumably what gets the option saved — the
    # in-place .remove() bypassed it.
    self.config["ADDITIONAL_REPOS"] = repos

    await utils.answer(message, self.strings("repo_deleted").format(args))
async def _inline__clearmodules(self, call: InlineCall):
    # Confirm handler for clearmodules. It forgets every remembered remote
    # module, empties LOADED_MODULES_DIR and restarts via the Updater module.
    self.set("loaded_modules", {})

    # Close the scandir iterator deterministically.
    with os.scandir(loader.LOADED_MODULES_DIR) as entries:
        for entry in entries:
            try:
                # shutil.rmtree() only accepts real directories: on a regular
                # file it raises, which previously left every saved module
                # file on disk. Unlink files and symlinks instead.
                if entry.is_dir(follow_symlinks=False):
                    shutil.rmtree(entry.path)
                else:
                    os.remove(entry.path)
            except Exception:
                # Best effort: one undeletable entry must not abort the reset.
                logger.debug("Failed to remove %s", entry.path, exc_info=True)

    await utils.answer(call, self.strings("all_modules_deleted"))
    await self.lookup("Updater").restart_common(call)
async def _update_modules(self):
    # Startup restore. Reinstall every remote module remembered in the DB,
    # unless a one-shot secure boot was requested. Then keep only the
    # aliases that add_alias accepts, mark loading as complete and notify
    # the Updater module (if present).
    todo = await self._get_modules_to_load()

    self._secure_boot = False
    if self._db.get(loader.__name__, "secure_boot", False):
        # Secure boot is one-shot: clear the flag and skip external modules.
        self._db.set(loader.__name__, "secure_boot", False)
        self._secure_boot = True
    else:
        for link in todo.values():
            await self.download_and_install(link)

        self.update_modules_in_db()

    settings = self.lookup("settings")
    surviving_aliases = {}
    for alias, cmd in settings.get("aliases", {}).items():
        if self.allmodules.add_alias(alias, cmd):
            surviving_aliases[alias] = cmd
    settings.set("aliases", surviving_aliases)

    self.fully_loaded = True

    with contextlib.suppress(AttributeError):
        await self.lookup("Updater").full_restart_complete(self._secure_boot)
def flush_cache(self) -> int:
    """Flush the cache of links to modules.

    :return: Number of module links that were cached before flushing.
    """
    # Each cache entry is {"exp": ..., "data": [links]}. Counting len() of
    # the entry itself always gave 2 per repo, so count the links instead.
    count = sum(len(entry.get("data", [])) for entry in self._links_cache.values())
    self._links_cache = {}
    return count


def inspect_cache(self) -> int:
    """Inspect the cache of links to modules.

    :return: Number of module links currently held in the cache.
    """
    return sum(len(entry.get("data", [])) for entry in self._links_cache.values())
async def reload_core(self) -> int:
    """Re-register every core module and return how many were loaded.

    Unless "remove_core_protection" is set in the DB, the currently loaded
    modules whose origin starts with "<core" are relabelled "<reload-core>"
    first. Each freshly registered module then receives its config and
    ready hooks.
    """
    self.fully_loaded = False

    # Re-arm the secure_boot DB flag consumed by _update_modules (which is
    # also where self._secure_boot is assigned).
    if self._secure_boot:
        self._db.set(loader.__name__, "secure_boot", True)

    if not self._db.get(main.__name__, "remove_core_protection", False):
        core_modules = (
            module
            for module in self.allmodules.modules
            if module.__origin__.startswith("<core")
        )
        for module in core_modules:
            module.__origin__ = "<reload-core>"

    loaded = await self.allmodules.register_all(no_external=True)

    for instance in loaded:
        self.allmodules.send_config_one(instance)
        await self.allmodules.send_ready_one(
            instance,
            no_self_unload=False,
            from_dlmod=False,
        )

    self.fully_loaded = True
    return len(loaded)
@loader.command()
async def mlcmd(self, message: Message):
    """| send module via file"""
    # Command: send the source of a loaded module as a .py file, with its
    # origin link when it came from a URL.
    if not (args := utils.get_args_raw(message)):
        await utils.answer(message, self.strings("args"))
        return

    exact = True
    # 1) Exact, case-insensitive match on display name or class name.
    if not (
        class_name := next(
            (
                module.strings("name")
                for module in self.allmodules.modules
                if args.lower()
                in {
                    module.strings("name").lower(),
                    module.__class__.__name__.lower(),
                }
            ),
            None,
        )
    ):
        # 2) Otherwise the most similar lowercased name (SequenceMatcher
        #    ratio). The same strings("name") accessor as above is used.
        if not (
            class_name := next(
                reversed(
                    sorted(
                        [
                            module.strings("name").lower()
                            for module in self.allmodules.modules
                        ]
                        + [
                            module.__class__.__name__.lower()
                            for module in self.allmodules.modules
                        ],
                        key=lambda x: difflib.SequenceMatcher(
                            None,
                            args.lower(),
                            x,
                        ).ratio(),
                    )
                ),
                None,
            )
        ):
            await utils.answer(message, self.strings("404"))
            return

        exact = False

    try:
        module = self.lookup(class_name)
        sys_module = inspect.getmodule(module)
    except Exception:
        await utils.answer(message, self.strings("404"))
        return

    # lookup() / getmodule() may come back empty, and the source is read
    # from ``__loader__.data``, which not every loader provides. Answer
    # "not found" instead of crashing with AttributeError.
    source = getattr(getattr(sys_module, "__loader__", None), "data", None)
    if module is None or source is None:
        await utils.answer(message, self.strings("404"))
        return

    link = module.__origin__

    text = (
        self.strings("link").format(
            class_name=utils.escape_html(class_name),
            url=link,
            not_exact=self.strings("not_exact") if not exact else "",
            prefix=utils.escape_html(self.get_prefix()),
        )
        if utils.check_url(link)
        else self.strings("file").format(
            class_name=utils.escape_html(class_name),
            not_exact=self.strings("not_exact") if not exact else "",
            prefix=utils.escape_html(self.get_prefix()),
        )
    )

    file = io.BytesIO(source)
    file.name = f"{class_name}.py"
    file.seek(0)

    await utils.answer_file(
        message,
        file,
        caption=text,
        reply_to=getattr(message, "reply_to_msg_id", None),
    )
def _format_result(
    self,
    result: dict,
    query: str,
    no_translate: bool = False,
) -> str:
    # Render one search hit (``result["module"]``) with the "result" string
    # template. ``self.strings.get("result", "en")`` is used when the
    # "translate" option is truthy and ``no_translate`` is False. If the text
    # exceeds 1980 characters, the command list is replaced by "...".
    info = result["module"]
    prefix = utils.escape_html(self.get_prefix())

    command_lines = []
    for cmd, cmd_doc in info["commands"].items():
        command_lines.append(
            f"▫️ <code>{prefix}{utils.escape_html(cmd)}</code>:"
            f" <b>{utils.escape_html(cmd_doc)}</b>"
        )

    fields = {
        "name": utils.escape_html(info["name"]),
        "dev": utils.escape_html(info["dev"]),
        "commands": "\n".join(command_lines),
        "cls_doc": utils.escape_html(info["cls_doc"]),
        "mhash": info["hash"],
        "query": utils.escape_html(query),
        "prefix": prefix,
    }

    if self.config["translate"] and not no_translate:
        template = self.strings.get("result", "en")
    else:
        template = self.strings("result")

    text = template.format(**fields)
    if len(text) <= 1980:
        return text

    fields["commands"] = "..."
    return template.format(**fields)