- Fix grep for messages bigger than 4096 UTF-8 characters
- Rename `func` tag to `filter` due to internal python conflict with dynamically generated methods
- Partially rework security unit
- Internal refactoring and typehints
pull/1/head
hikariatama 2022-08-21 21:48:19 +00:00
parent ab8130ed60
commit 89040b6e2f
10 changed files with 218 additions and 130 deletions

View File

@ -10,12 +10,16 @@
- Fix `utils.find_caller` for :method:`hikka.inline.utils.Utils._find_caller_sec_map`
- Fix `.eval`
- Fix: use old lib if its version is higher than new one
- Fix grep for messages bigger than 4096 UTF-8 characters
- Add more animated emojis to modules
- Add targeted security for users and chats (`.tsec`)
- Add support for `tg_level` in `.config Tester`
- Add `-f` param to `.restart` and `.update`
- Add platform-specific Hikka emojis to premium users
- Add codepaces to `utils.get_named_platform`
- Rename `func` tag to `filter` due to internal python conflict with dynamically generated methods
- Partially rework security unit
- Internal refactoring and typehints
## 🌑 Hikka 1.3.3

View File

@ -27,8 +27,9 @@ except ImportError as e:
raise e
from typing import Any, Union
from typing import Optional, Union
from telethon import TelegramClient
from telethon.tl.types import Message
from telethon.errors.rpcerrorlist import ChannelsTooMuchError
@ -37,6 +38,7 @@ from .pointers import (
PointerList,
PointerDict,
)
from .types import JSONSerializable
DATA_DIR = (
os.path.normpath(os.path.join(utils.get_base_dir(), ".."))
@ -60,7 +62,7 @@ class Database(dict):
_redis = None
_saving_task = None
def __init__(self, client):
def __init__(self, client: TelegramClient):
super().__init__()
self._client = client
@ -347,14 +349,19 @@ class Database(dict):
return asset[0] if asset else None
def get(self, owner: str, key: str, default: Any = None) -> Any:
def get(
self,
owner: str,
key: str,
default: Optional[JSONSerializable] = None,
) -> JSONSerializable:
"""Get database key"""
try:
return self[owner][key]
except KeyError:
return default
def set(self, owner: str, key: str, value: Any) -> bool:
def set(self, owner: str, key: str, value: JSONSerializable) -> bool:
"""Set database key"""
if not utils.is_serializable(owner):
raise RuntimeError(
@ -380,7 +387,12 @@ class Database(dict):
super().setdefault(owner, {})[key] = value
return self.save()
def pointer(self, owner: str, key: str, default: Any = None) -> Any:
def pointer(
self,
owner: str,
key: str,
default: Optional[JSONSerializable] = None,
) -> JSONSerializable:
"""Get a pointer to database key"""
value = self.get(owner, key, default)
mapping = {

View File

@ -1,4 +1,4 @@
"""Obviously, dispatches stuff"""
"""Processes incoming events and dispatches them to appropriate handlers"""
# Friendly Telegram (telegram userbot)
# Copyright (C) 2018-2022 The Authors
@ -33,7 +33,7 @@ import re
import traceback
from typing import Tuple, Union
from telethon import types
from telethon import TelegramClient
from telethon.tl.types import Message
from . import main, security, utils
@ -64,7 +64,7 @@ ALL_TAGS = [
"startswith",
"endswith",
"contains",
"func",
"filter",
"from_id",
"chat_id",
"regex",
@ -79,25 +79,33 @@ def _decrement_ratelimit(delay, data, key, severity):
class CommandDispatcher:
def __init__(self, modules: Modules, db: Database, no_nickname: bool = False):
def __init__(
self,
modules: Modules,
client: TelegramClient,
db: Database,
no_nickname: bool = False,
):
self._modules = modules
self._client = client
self.client = client
self._db = db
self.security = security.SecurityManager(db)
self.no_nickname = no_nickname
self._ratelimit_storage_user = collections.defaultdict(int)
self._ratelimit_storage_chat = collections.defaultdict(int)
self._ratelimit_max_user = db.get(__name__, "ratelimit_max_user", 30)
self._ratelimit_max_chat = db.get(__name__, "ratelimit_max_chat", 100)
self.security = security.SecurityManager(client, db)
self.check_security = self.security.check
async def init(self, client: "TelegramClient"): # type: ignore
await self.security.init(client)
me = await client.get_me()
self.client = client # Intended to be used to track user in logging
self._me = me.id
self._cached_username = me.username.lower() if me.username else str(me.id)
self._me = self._client.hikka_me.id
self._cached_username = (
self._client.hikka_me.username.lower()
if self._client.hikka_me.username
else str(self._client.hikka_me.id)
)
async def _handle_ratelimit(self, message: Message, func: callable) -> bool:
if await self.security.check(
@ -227,6 +235,7 @@ class CommandDispatcher:
message.edit = my_edit
message.reply = my_reply
message.respond = my_respond
message.hikka_grepped = True
return message
@ -422,7 +431,7 @@ class CommandDispatcher:
or (getattr(func, "in", False) and getattr(message, "out", True))
or (
getattr(func, "only_messages", False)
and not isinstance(message, types.Message)
and not isinstance(message, Message)
)
or (
getattr(func, "editable", False)
@ -435,13 +444,13 @@ class CommandDispatcher:
)
or (
getattr(func, "no_media", False)
and isinstance(message, types.Message)
and isinstance(message, Message)
and getattr(message, "media", False)
)
or (
getattr(func, "only_media", False)
and (
not isinstance(message, types.Message)
not isinstance(message, Message)
or not getattr(message, "media", False)
)
)
@ -510,9 +519,9 @@ class CommandDispatcher:
)
)
or (
getattr(func, "func", False)
and callable(func.func)
and not func.func(message)
getattr(func, "filter", False)
and callable(func.filter)
and not func.filter(message)
)
or (
getattr(func, "from_id", False)
@ -556,7 +565,7 @@ class CommandDispatcher:
if (
modname in bl
and isinstance(message, types.Message)
and isinstance(message, Message)
and (
"*" in bl[modname]
or utils.get_chat_id(message) in bl[modname]
@ -584,7 +593,7 @@ class CommandDispatcher:
# Avoid weird AttributeErrors in weird dochub modules by settings placeholder
# of attributes
for placeholder in {"text", "raw_text"}:
for placeholder in {"text", "raw_text", "out"}:
try:
if not hasattr(message, placeholder):
setattr(message, placeholder, "")

View File

@ -26,32 +26,35 @@
import asyncio
import contextlib
import copy
from functools import partial, wraps
import importlib
import importlib.util
import inspect
import logging
import os
import re
import sys
from importlib.machinery import ModuleSpec
from types import FunctionType
from typing import Any, Awaitable, Hashable, Optional, Union, List
import requests
import copy
import importlib
import importlib.util
import importlib.machinery
from functools import partial, wraps
from telethon import TelegramClient
from telethon.tl.types import Message, InputPeerNotifySettings, Channel
from telethon.tl.functions.account import UpdateNotifySettingsRequest
from telethon.hints import EntityLike
from types import FunctionType
from typing import Any, Optional, Union, List
from . import security, utils, validators, version
from .types import (
ConfigValue, # skipcq
LoadError, # skipcq
LoadError,
Module,
Library, # skipcq
Library,
ModuleConfig, # skipcq
LibraryConfig, # skipcq
LibraryConfig,
SelfUnload,
SelfSuspend,
StopLoop,
@ -61,6 +64,7 @@ from .types import (
StringLoader,
get_commands,
get_inline_handlers,
JSONSerializable,
)
from .inline.core import InlineManager
from .inline.types import InlineCall
@ -431,7 +435,7 @@ def tag(*tags, **kwarg_tags):
`endswith` - Capture only messages that end with given text
`contains` - Capture only messages that contain given text
`regex` - Capture only messages that match given regex
`func` - Capture only messages that pass given function
`filter` - Capture only messages that pass given function
`from_id` - Capture only messages from given user
`chat_id` - Capture only messages from given chat
@ -612,7 +616,7 @@ class Modules:
logger.debug(f"Loading {module_name} from filesystem")
with open(mod, "r") as file:
spec = ModuleSpec(
spec = importlib.machinery.ModuleSpec(
module_name,
StringLoader(file.read(), user_friendly_origin),
origin=user_friendly_origin,
@ -624,7 +628,7 @@ class Modules:
async def register_module(
self,
spec: ModuleSpec,
spec: importlib.machinery.ModuleSpec,
module_name: str,
origin: str = "<core>",
save_fs: bool = False,
@ -917,10 +921,10 @@ class Modules:
instance.allclients = self.allclients
instance.allmodules = self
instance.hikka = True
instance.get = partial(self._mod_get, _modname=instance.__class__.__name__)
instance.set = partial(self._mod_set, _modname=instance.__class__.__name__)
instance.get = partial(self._get, _owner=instance.__class__.__name__)
instance.set = partial(self._set, _owner=instance.__class__.__name__)
instance.pointer = partial(
self._mod_pointer, _modname=instance.__class__.__name__
self._pointer, _owner=instance.__class__.__name__
)
instance.get_prefix = partial(self._db.get, "hikka.main", "command_prefix", ".")
instance.client = self.client
@ -955,43 +959,24 @@ class Modules:
self.modules += [instance]
def _mod_get(
def _get(
self,
key: str,
default: Optional[Hashable] = None,
_modname: str = None,
) -> Hashable:
return self._db.get(_modname, key, default)
default: Optional[JSONSerializable] = None,
_owner: str = None,
) -> JSONSerializable:
return self._db.get(_owner, key, default)
def _mod_set(self, key: str, value: Hashable, _modname: str = None) -> bool:
return self._db.set(_modname, key, value)
def _set(self, key: str, value: JSONSerializable, _owner: str = None) -> bool:
return self._db.set(_owner, key, value)
def _mod_pointer(
def _pointer(
self,
key: str,
default: Optional[Hashable] = None,
_modname: str = None,
) -> Any:
return self._db.pointer(_modname, key, default)
def _lib_get(
self,
key: str,
default: Optional[Hashable] = None,
_lib: Library = None,
) -> Hashable:
return self._db.get(_lib.__class__.__name__, key, default)
def _lib_set(self, key: str, value: Hashable, _lib: Library = None) -> bool:
return self._db.set(_lib.__class__.__name__, key, value)
def _lib_pointer(
self,
key: str,
default: Optional[Hashable] = None,
_lib: Library = None,
) -> Any:
return self._db.pointer(_lib.__class__.__name__, key, default)
default: Optional[JSONSerializable] = None,
_owner: str = None,
) -> JSONSerializable:
return self._db.pointer(_owner, key, default)
async def _mod_import_lib(
self,
@ -1046,7 +1031,7 @@ class Modules:
module = f"hikka.libraries.{url.replace('%', '%%').replace('.', '%d')}"
origin = f"<library {url}>"
spec = ModuleSpec(module, StringLoader(code, origin), origin=origin)
spec = importlib.machinery.ModuleSpec(module, StringLoader(code, origin), origin=origin)
try:
instance = importlib.util.module_from_spec(spec)
sys.modules[module] = instance
@ -1140,9 +1125,9 @@ class Modules:
lib_obj.inline = self.inline
lib_obj.tg_id = self.client.tg_id
lib_obj.allmodules = self
lib_obj._lib_get = partial(self._lib_get, _lib=lib_obj) # skipcq
lib_obj._lib_set = partial(self._lib_set, _lib=lib_obj) # skipcq
lib_obj._lib_pointer = partial(self._lib_pointer, _lib=lib_obj) # skipcq
lib_obj._lib_get = partial(self._get, _owner=lib_obj) # skipcq
lib_obj._lib_set = partial(self._set, _owner=lib_obj) # skipcq
lib_obj._lib_pointer = partial(self._pointer, _owner=lib_obj) # skipcq
lib_obj.get_prefix = partial(self._db.get, "hikka.main", "command_prefix", ".")
for old_lib in self.libraries:

View File

@ -625,9 +625,8 @@ class Hikka:
async def _add_dispatcher(self, client, modules, db):
"""Inits and adds dispatcher instance to client"""
dispatcher = CommandDispatcher(modules, db, self.arguments.no_nickname)
dispatcher = CommandDispatcher(modules, client, db, self.arguments.no_nickname)
client.dispatcher = dispatcher
await dispatcher.init(client)
modules.check_security = dispatcher.check_security
client.add_event_handler(

View File

@ -543,7 +543,7 @@ class HikkaSecurityMod(loader.Module):
) -> dict:
config = self._db.get(security.__name__, "masks", {}).get(
f"{command.__module__}.{command.__name__}",
getattr(command, "security", self._client.dispatcher.security._default),
getattr(command, "security", self._client.dispatcher.security.default),
)
return self._perms_map(config, is_inline)
@ -938,7 +938,7 @@ class HikkaSecurityMod(loader.Module):
try:
if not args[1].isdigit() and not args[1].startswith("@"):
raise ValueError
target = await self._client.get_entity(
int(args[1]) if args[1].isdigit() else args[1]
)
@ -1011,7 +1011,7 @@ class HikkaSecurityMod(loader.Module):
await utils.answer(message, self.strings("no_target"))
return
if target.id in self._client.dispatcher.security._owner:
if target.id in self._client.dispatcher.security.owner:
await utils.answer(message, self.strings("owner_target"))
return
@ -1059,8 +1059,8 @@ class HikkaSecurityMod(loader.Module):
async def tsecrm(self, message: Message):
"""<"user"/"chat"> - Remove targeted security rule"""
if (
not self._client.dispatcher.security._tsec_chat
and not self._client.dispatcher.security._tsec_user
not self._client.dispatcher.security.tsec_chat
and not self._client.dispatcher.security.tsec_user
):
await utils.answer(message, self.strings("no_rules"))
return
@ -1084,18 +1084,10 @@ class HikkaSecurityMod(loader.Module):
await utils.answer(message, self.strings("no_target"))
return
if not any(
rule["target"] == target.id
for rule in self._client.dispatcher.security._tsec_user
):
if not self._client.dispatcher.security.remove_rules("user", target.id):
await utils.answer(message, self.strings("no_rules"))
return
self._client.dispatcher.security._tsec_user = [
rule
for rule in self._client.dispatcher.security._tsec_user
if rule["target"] != target.id
]
await utils.answer(
message,
self.strings("rules_removed").format(
@ -1111,18 +1103,10 @@ class HikkaSecurityMod(loader.Module):
target = await self._client.get_entity(message.peer_id)
if not any(
rule["target"] == target.id
for rule in self._client.dispatcher.security._tsec_chat
):
if not self._client.dispatcher.security.remove_rules("chat", target.id):
await utils.answer(message, self.strings("no_rules"))
return
self._client.dispatcher.security._tsec_chat = [
rule
for rule in self._client.dispatcher.security._tsec_chat
if rule["target"] != target.id
]
await utils.answer(
message,
self.strings("rules_removed").format(
@ -1143,8 +1127,8 @@ class HikkaSecurityMod(loader.Module):
args = utils.get_args(message)
if not args:
if (
not self._client.dispatcher.security._tsec_chat
and not self._client.dispatcher.security._tsec_user
not self._client.dispatcher.security.tsec_chat
and not self._client.dispatcher.security.tsec_user
):
await utils.answer(message, self.strings("no_rules"))
return
@ -1158,14 +1142,14 @@ class HikkaSecurityMod(loader.Module):
f" href='{rule['entity_url']}'>{utils.escape_html(rule['entity_name'])}</a>"
f" {self._convert_time(int(rule['expires'] - time.time()))} {self.strings('for')} {self.strings(rule['rule_type'])}</b>"
f" <code>{rule['rule']}</code>"
for rule in self._client.dispatcher.security._tsec_chat
for rule in self._client.dispatcher.security.tsec_chat
]
+ [
"<emoji document_id='6037122016849432064'>👤</emoji> <b><a"
f" href='{rule['entity_url']}'>{utils.escape_html(rule['entity_name'])}</a>"
f" {self._convert_time(int(rule['expires'] - time.time()))} {self.strings('for')} {self.strings(rule['rule_type'])}</b>"
f" <code>{rule['rule']}</code>"
for rule in self._client.dispatcher.security._tsec_user
for rule in self._client.dispatcher.security.tsec_user
]
)
),

View File

@ -17,6 +17,38 @@ class PointerList(list):
self._default = default
super().__init__(db.get(module, key, default))
def sync(self):
super().__init__(self._db.get(self._module, self._key, self._default))
def __getitem__(self, index: int) -> Any:
self.sync()
return super().__getitem__(index)
def __iter__(self) -> Iterable:
self.sync()
return super().__iter__()
def __reversed__(self) -> Iterable:
self.sync()
return super().__reversed__()
def __contains__(self, item: Any) -> bool:
self.sync()
return super().__contains__(item)
def __len__(self) -> int:
self.sync()
return super().__len__()
def __bool__(self) -> bool:
return bool(self._db.get(self._module, self._key, self._default))
def __repr__(self):
return f"PointerList({list(self)})"
def __str__(self):
return f"PointerList({list(self)})"
def __delitem__(self, __i: Union[SupportsIndex, slice]) -> None:
a = super().__delitem__(__i)
self._save()
@ -37,9 +69,6 @@ class PointerList(list):
self._save()
return a
def __str__(self):
return f"PointerList({list(self)})"
def append(self, value: Any):
super().append(value)
self._save()
@ -85,6 +114,43 @@ class PointerDict(dict):
self._default = default
super().__init__(db.get(module, key, default))
def sync(self):
super().__init__(self._db.get(self._module, self._key, self._default))
def __repr__(self):
return f"PointerDict({dict(self)})"
def __bool__(self) -> bool:
return bool(self._db.get(self._module, self._key, self._default))
def __reversed__(self) -> Iterable:
self.sync()
return super().__reversed__()
def __contains__(self, item: Any) -> bool:
self.sync()
return super().__contains__(item)
def __getitem__(self, key: str) -> Any:
self.sync()
return super().__getitem__(key)
def __iter__(self) -> Iterable:
self.sync()
return super().__iter__()
def items(self) -> Iterable:
self.sync()
return super().items()
def keys(self) -> Iterable:
self.sync()
return super().keys()
def values(self) -> Iterable:
self.sync()
return super().values()
def __setitem__(self, key: str, value: Any):
super().__setitem__(key, value)
self._save()

View File

@ -28,12 +28,14 @@ import logging
import time
from typing import Optional
from telethon import TelegramClient
from telethon.hints import EntityLike
from telethon.utils import get_display_name
from telethon.tl.functions.messages import GetFullChatRequest
from telethon.tl.types import ChatParticipantAdmin, ChatParticipantCreator, Message
from . import main, utils
from .database import Database
logger = logging.getLogger(__name__)
@ -153,24 +155,33 @@ def _sec(func: callable, flags: int) -> callable:
class SecurityManager:
def __init__(self, db):
self._any_admin = db.get(__name__, "any_admin", False)
self._default = db.get(__name__, "default", DEFAULT_PERMISSIONS)
def __init__(self, client: TelegramClient, db: Database):
self._client = client
self._db = db
self._cache = {}
self._tsec_chat = self._db.pointer(__name__, "tsec_chat", [])
self._tsec_user = self._db.pointer(__name__, "tsec_user", [])
self._any_admin = db.get(__name__, "any_admin", False)
self._default = db.get(__name__, "default", DEFAULT_PERMISSIONS)
self._tsec_chat = db.pointer(__name__, "tsec_chat", [])
self._tsec_user = db.pointer(__name__, "tsec_user", [])
self._owner = db.pointer(__name__, "owner", [])
self._sudo = db.pointer(__name__, "sudo", [])
self._support = db.pointer(__name__, "support", [])
self._reload_rights()
self.any_admin = self._any_admin
self.default = self._default
self.tsec_chat = self._tsec_chat
self.tsec_user = self._tsec_user
self.owner = self._owner
self.sudo = self._sudo
self.support = self._support
def _reload_rights(self):
self._owner = list(
set(
self._db.get(__name__, "owner", []).copy()
+ ([self._client.tg_id] if hasattr(self, "_client") else [])
)
)
self._sudo = list(set(self._db.get(__name__, "sudo", []).copy()))
self._support = list(set(self._db.get(__name__, "support", []).copy()))
if self._client.tg_id not in self._owner:
self._owner.append(self._client.tg_id)
for info in self._tsec_user.copy():
if info["expires"] < time.time():
self._tsec_user.remove(info)
@ -179,9 +190,6 @@ class SecurityManager:
if info["expires"] < time.time():
self._tsec_chat.remove(info)
async def init(self, client):
self._client = client
def add_rule(
self,
target_type: str,
@ -209,6 +217,22 @@ class SecurityManager:
}
)
def remove_rules(self, target_type: str, target_id: int) -> bool:
any_ = False
if target_type == "user":
for rule in self.tsec_user.copy():
if rule["target"] == target_id:
self.tsec_user.remove(rule)
any_ = True
elif target_type == "chat":
for rule in self.tsec_chat.copy():
if rule["target"] == target_id:
self.tsec_chat.remove(rule)
any_ = True
return any_
def get_flags(self, func: callable) -> int:
if isinstance(func, int):
config = func
@ -297,7 +321,7 @@ class SecurityManager:
cmd = message.raw_text[1:].split()[0].strip()
except Exception:
cmd = None
if callable(func):
for info in self._tsec_user.copy():
if info["target"] == user:

View File

@ -35,6 +35,9 @@ from .pointers import ( # skipcq: PY-W2000
logger = logging.getLogger(__name__)
JSONSerializable = Union[str, int, float, bool, list, dict, None]
class StringLoader(SourceLoader):
"""Load a python module/file from a string"""

View File

@ -335,7 +335,7 @@ async def answer(
if isinstance(response, str) and not kwargs.pop("asfile", False):
text, entities = parse_mode.parse(response)
if len(text) >= 4096:
if len(text) >= 4096 and not hasattr(message, "hikka_grepped"):
try:
if not message.client.loader.inline.init_complete:
raise
@ -404,7 +404,9 @@ async def answer(
"reply_to",
getattr(message, "reply_to_msg_id", None),
)
result = await message.client.send_file(message.chat_id, response, **kwargs)
result = await message.client.send_file(message.peer_id, response, **kwargs)
if message.out:
await message.delete()
return result
@ -670,7 +672,7 @@ def get_named_platform() -> str:
if is_okteto:
return "☁️ Okteto"
if is_codespaces:
return "🐈‍⬛ Codespaces"