diff --git a/CHANGELOG.md b/CHANGELOG.md index fddaf9f..67ec89a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,12 +10,16 @@ - Fix `utils.find_caller` for :method:`hikka.inline.utils.Utils._find_caller_sec_map` - Fix `.eval` - Fix: use old lib if its version is higher than new one +- Fix grep for messages bigger than 4096 UTF-8 characters - Add more animated emojis to modules - Add targeted security for users and chats (`.tsec`) - Add support for `tg_level` in `.config Tester` - Add `-f` param to `.restart` and `.update` - Add platform-specific Hikka emojis to premium users - Add codepaces to `utils.get_named_platform` +- Rename `func` tag to `filter` due to internal python conflict with dynamically generated methods +- Partially rework security unit +- Internal refactoring and typehints ## 🌑 Hikka 1.3.3 diff --git a/hikka/database.py b/hikka/database.py index 3c7dda0..5c13d11 100755 --- a/hikka/database.py +++ b/hikka/database.py @@ -27,8 +27,9 @@ except ImportError as e: raise e -from typing import Any, Union +from typing import Optional, Union +from telethon import TelegramClient from telethon.tl.types import Message from telethon.errors.rpcerrorlist import ChannelsTooMuchError @@ -37,6 +38,7 @@ from .pointers import ( PointerList, PointerDict, ) +from .types import JSONSerializable DATA_DIR = ( os.path.normpath(os.path.join(utils.get_base_dir(), "..")) @@ -60,7 +62,7 @@ class Database(dict): _redis = None _saving_task = None - def __init__(self, client): + def __init__(self, client: TelegramClient): super().__init__() self._client = client @@ -347,14 +349,19 @@ class Database(dict): return asset[0] if asset else None - def get(self, owner: str, key: str, default: Any = None) -> Any: + def get( + self, + owner: str, + key: str, + default: Optional[JSONSerializable] = None, + ) -> JSONSerializable: """Get database key""" try: return self[owner][key] except KeyError: return default - def set(self, owner: str, key: str, value: Any) -> bool: + def set(self, owner: str, key: str, value: JSONSerializable) -> bool: """Set database key""" if not utils.is_serializable(owner): raise RuntimeError( @@ -380,7 +387,12 @@ class Database(dict): super().setdefault(owner, {})[key] = value return self.save() - def pointer(self, owner: str, key: str, default: Any = None) -> Any: + def pointer( + self, + owner: str, + key: str, + default: Optional[JSONSerializable] = None, + ) -> JSONSerializable: """Get a pointer to database key""" value = self.get(owner, key, default) mapping = { diff --git a/hikka/dispatcher.py b/hikka/dispatcher.py index 1d47071..8939c5d 100755 --- a/hikka/dispatcher.py +++ b/hikka/dispatcher.py @@ -1,4 +1,4 @@ -"""Obviously, dispatches stuff""" +"""Processes incoming events and dispatches them to appropriate handlers""" # Friendly Telegram (telegram userbot) # Copyright (C) 2018-2022 The Authors @@ -33,7 +33,7 @@ import re import traceback from typing import Tuple, Union -from telethon import types +from telethon import TelegramClient from telethon.tl.types import Message from . import main, security, utils @@ -64,7 +64,7 @@ ALL_TAGS = [ "startswith", "endswith", "contains", - "func", + "filter", "from_id", "chat_id", "regex", @@ -79,25 +79,33 @@ def _decrement_ratelimit(delay, data, key, severity): class CommandDispatcher: - def __init__(self, modules: Modules, db: Database, no_nickname: bool = False): + def __init__( + self, + modules: Modules, + client: TelegramClient, + db: Database, + no_nickname: bool = False, + ): self._modules = modules + self._client = client + self.client = client self._db = db - self.security = security.SecurityManager(db) self.no_nickname = no_nickname + self._ratelimit_storage_user = collections.defaultdict(int) self._ratelimit_storage_chat = collections.defaultdict(int) self._ratelimit_max_user = db.get(__name__, "ratelimit_max_user", 30) self._ratelimit_max_chat = db.get(__name__, "ratelimit_max_chat", 100) + + self.security = security.SecurityManager(client, db) + self.check_security = self.security.check - - async def init(self, client: "TelegramClient"): # type: ignore - await self.security.init(client) - me = await client.get_me() - - self.client = client # Intended to be used to track user in logging - - self._me = me.id - self._cached_username = me.username.lower() if me.username else str(me.id) + self._me = self._client.hikka_me.id + self._cached_username = ( + self._client.hikka_me.username.lower() + if self._client.hikka_me.username + else str(self._client.hikka_me.id) + ) async def _handle_ratelimit(self, message: Message, func: callable) -> bool: if await self.security.check( @@ -227,6 +235,7 @@ class CommandDispatcher: message.edit = my_edit message.reply = my_reply message.respond = my_respond + message.hikka_grepped = True return message @@ -422,7 +431,7 @@ class CommandDispatcher: or (getattr(func, "in", False) and getattr(message, "out", True)) or ( getattr(func, "only_messages", False) - and not isinstance(message, types.Message) + and not isinstance(message, Message) ) or ( getattr(func, "editable", False) @@ -435,13 +444,13 @@ class CommandDispatcher: ) or ( getattr(func, "no_media", False) - and isinstance(message, types.Message) + and isinstance(message, Message) and getattr(message, "media", False) ) or ( getattr(func, "only_media", False) and ( - not isinstance(message, types.Message) + not isinstance(message, Message) or not getattr(message, "media", False) ) ) @@ -510,9 +519,9 @@ class CommandDispatcher: ) ) or ( - getattr(func, "func", False) - and callable(func.func) - and not func.func(message) + getattr(func, "filter", False) + and callable(func.filter) + and not func.filter(message) ) or ( getattr(func, "from_id", False) @@ -556,7 +565,7 @@ class CommandDispatcher: if ( modname in bl - and isinstance(message, types.Message) + and isinstance(message, Message) and ( "*" in bl[modname] or utils.get_chat_id(message) in bl[modname] @@ -584,7 +593,7 @@ class CommandDispatcher: # Avoid weird AttributeErrors in weird dochub modules by settings placeholder # of attributes - for placeholder in {"text", "raw_text"}: + for placeholder in {"text", "raw_text", "out"}: try: if not hasattr(message, placeholder): setattr(message, placeholder, "") diff --git a/hikka/loader.py b/hikka/loader.py index 4a4399b..5f9be8d 100644 --- a/hikka/loader.py +++ b/hikka/loader.py @@ -26,32 +26,35 @@ import asyncio import contextlib -import copy -from functools import partial, wraps -import importlib -import importlib.util import inspect import logging import os import re import sys -from importlib.machinery import ModuleSpec -from types import FunctionType -from typing import Any, Awaitable, Hashable, Optional, Union, List import requests +import copy + +import importlib +import importlib.util +import importlib.machinery +from functools import partial, wraps + from telethon import TelegramClient from telethon.tl.types import Message, InputPeerNotifySettings, Channel from telethon.tl.functions.account import UpdateNotifySettingsRequest from telethon.hints import EntityLike +from types import FunctionType +from typing import Any, Optional, Union, List + from . import security, utils, validators, version from .types import ( ConfigValue, # skipcq - LoadError, # skipcq + LoadError, Module, - Library, # skipcq + Library, ModuleConfig, # skipcq - LibraryConfig, # skipcq + LibraryConfig, SelfUnload, SelfSuspend, StopLoop, @@ -61,6 +64,7 @@ from .types import ( StringLoader, get_commands, get_inline_handlers, + JSONSerializable, ) from .inline.core import InlineManager from .inline.types import InlineCall @@ -431,7 +435,7 @@ def tag(*tags, **kwarg_tags): • `endswith` - Capture only messages that end with given text • `contains` - Capture only messages that contain given text • `regex` - Capture only messages that match given regex - • `func` - Capture only messages that pass given function + • `filter` - Capture only messages that pass given function • `from_id` - Capture only messages from given user • `chat_id` - Capture only messages from given chat @@ -612,7 +616,7 @@ class Modules: logger.debug(f"Loading {module_name} from filesystem") with open(mod, "r") as file: - spec = ModuleSpec( + spec = importlib.machinery.ModuleSpec( module_name, StringLoader(file.read(), user_friendly_origin), origin=user_friendly_origin, @@ -624,7 +628,7 @@ class Modules: async def register_module( self, - spec: ModuleSpec, + spec: importlib.machinery.ModuleSpec, module_name: str, origin: str = "", save_fs: bool = False, @@ -917,10 +921,10 @@ class Modules: instance.allclients = self.allclients instance.allmodules = self instance.hikka = True - instance.get = partial(self._mod_get, _modname=instance.__class__.__name__) - instance.set = partial(self._mod_set, _modname=instance.__class__.__name__) + instance.get = partial(self._get, _owner=instance.__class__.__name__) + instance.set = partial(self._set, _owner=instance.__class__.__name__) instance.pointer = partial( - self._mod_pointer, _modname=instance.__class__.__name__ + self._pointer, _owner=instance.__class__.__name__ ) instance.get_prefix = partial(self._db.get, "hikka.main", "command_prefix", ".") instance.client = self.client @@ -955,43 +959,24 @@ class Modules: self.modules += [instance] - def _mod_get( + def _get( self, key: str, - default: Optional[Hashable] = None, - _modname: str = None, - ) -> Hashable: - return self._db.get(_modname, key, default) + default: Optional[JSONSerializable] = None, + _owner: str = None, + ) -> JSONSerializable: + return self._db.get(_owner, key, default) - def _mod_set(self, key: str, value: Hashable, _modname: str = None) -> bool: - return self._db.set(_modname, key, value) + def _set(self, key: str, value: JSONSerializable, _owner: str = None) -> bool: + return self._db.set(_owner, key, value) - def _mod_pointer( + def _pointer( self, key: str, - default: Optional[Hashable] = None, - _modname: str = None, - ) -> Any: - return self._db.pointer(_modname, key, default) - - def _lib_get( - self, - key: str, - default: Optional[Hashable] = None, - _lib: Library = None, - ) -> Hashable: - return self._db.get(_lib.__class__.__name__, key, default) - - def _lib_set(self, key: str, value: Hashable, _lib: Library = None) -> bool: - return self._db.set(_lib.__class__.__name__, key, value) - - def _lib_pointer( - self, - key: str, - default: Optional[Hashable] = None, - _lib: Library = None, - ) -> Any: - return self._db.pointer(_lib.__class__.__name__, key, default) + default: Optional[JSONSerializable] = None, + _owner: str = None, + ) -> JSONSerializable: + return self._db.pointer(_owner, key, default) async def _mod_import_lib( self, @@ -1046,7 +1031,7 @@ class Modules: module = f"hikka.libraries.{url.replace('%', '%%').replace('.', '%d')}" origin = f"" - spec = ModuleSpec(module, StringLoader(code, origin), origin=origin) + spec = importlib.machinery.ModuleSpec(module, StringLoader(code, origin), origin=origin) try: instance = importlib.util.module_from_spec(spec) sys.modules[module] = instance @@ -1140,9 +1125,9 @@ class Modules: lib_obj.inline = self.inline lib_obj.tg_id = self.client.tg_id lib_obj.allmodules = self - lib_obj._lib_get = partial(self._lib_get, _lib=lib_obj) # skipcq - lib_obj._lib_set = partial(self._lib_set, _lib=lib_obj) # skipcq - lib_obj._lib_pointer = partial(self._lib_pointer, _lib=lib_obj) # skipcq + lib_obj._lib_get = partial(self._get, _owner=lib_obj) # skipcq + lib_obj._lib_set = partial(self._set, _owner=lib_obj) # skipcq + lib_obj._lib_pointer = partial(self._pointer, _owner=lib_obj) # skipcq lib_obj.get_prefix = partial(self._db.get, "hikka.main", "command_prefix", ".") for old_lib in self.libraries: diff --git a/hikka/main.py b/hikka/main.py index e3aef47..cd5a7dc 100755 --- a/hikka/main.py +++ b/hikka/main.py @@ -625,9 +625,8 @@ class Hikka: async def _add_dispatcher(self, client, modules, db): """Inits and adds dispatcher instance to client""" - dispatcher = CommandDispatcher(modules, db, self.arguments.no_nickname) + dispatcher = CommandDispatcher(modules, client, db, self.arguments.no_nickname) client.dispatcher = dispatcher - await dispatcher.init(client) modules.check_security = dispatcher.check_security client.add_event_handler( diff --git a/hikka/modules/hikka_security.py b/hikka/modules/hikka_security.py index 34839d6..f6b4cf3 100755 --- a/hikka/modules/hikka_security.py +++ b/hikka/modules/hikka_security.py @@ -543,7 +543,7 @@ class HikkaSecurityMod(loader.Module): ) -> dict: config = self._db.get(security.__name__, "masks", {}).get( f"{command.__module__}.{command.__name__}", - getattr(command, "security", self._client.dispatcher.security._default), + getattr(command, "security", self._client.dispatcher.security.default), ) return self._perms_map(config, is_inline) @@ -938,7 +938,7 @@ class HikkaSecurityMod(loader.Module): try: if not args[1].isdigit() and not args[1].startswith("@"): raise ValueError - + target = await self._client.get_entity( int(args[1]) if args[1].isdigit() else args[1] ) @@ -1011,7 +1011,7 @@ class HikkaSecurityMod(loader.Module): await utils.answer(message, self.strings("no_target")) return - if target.id in self._client.dispatcher.security._owner: + if target.id in self._client.dispatcher.security.owner: await utils.answer(message, self.strings("owner_target")) return @@ -1059,8 +1059,8 @@ class HikkaSecurityMod(loader.Module): async def tsecrm(self, message: Message): """<"user"/"chat"> - Remove targeted security rule""" if ( - not self._client.dispatcher.security._tsec_chat - and not self._client.dispatcher.security._tsec_user + not self._client.dispatcher.security.tsec_chat + and not self._client.dispatcher.security.tsec_user ): await utils.answer(message, self.strings("no_rules")) return @@ -1084,18 +1084,10 @@ class HikkaSecurityMod(loader.Module): await utils.answer(message, self.strings("no_target")) return - if not any( - rule["target"] == target.id - for rule in self._client.dispatcher.security._tsec_user - ): + if not self._client.dispatcher.security.remove_rules("user", target.id): await utils.answer(message, self.strings("no_rules")) return - self._client.dispatcher.security._tsec_user = [ - rule - for rule in self._client.dispatcher.security._tsec_user - if rule["target"] != target.id - ] await utils.answer( message, self.strings("rules_removed").format( @@ -1111,18 +1103,10 @@ class HikkaSecurityMod(loader.Module): target = await self._client.get_entity(message.peer_id) - if not any( - rule["target"] == target.id - for rule in self._client.dispatcher.security._tsec_chat - ): + if not self._client.dispatcher.security.remove_rules("chat", target.id): await utils.answer(message, self.strings("no_rules")) return - self._client.dispatcher.security._tsec_chat = [ - rule - for rule in self._client.dispatcher.security._tsec_chat - if rule["target"] != target.id - ] await utils.answer( message, self.strings("rules_removed").format( @@ -1143,8 +1127,8 @@ class HikkaSecurityMod(loader.Module): args = utils.get_args(message) if not args: if ( - not self._client.dispatcher.security._tsec_chat - and not self._client.dispatcher.security._tsec_user + not self._client.dispatcher.security.tsec_chat + and not self._client.dispatcher.security.tsec_user ): await utils.answer(message, self.strings("no_rules")) return @@ -1158,14 +1142,14 @@ class HikkaSecurityMod(loader.Module): f" href='{rule['entity_url']}'>{utils.escape_html(rule['entity_name'])}" f" {self._convert_time(int(rule['expires'] - time.time()))} {self.strings('for')} {self.strings(rule['rule_type'])}" f" {rule['rule']}" - for rule in self._client.dispatcher.security._tsec_chat + for rule in self._client.dispatcher.security.tsec_chat ] + [ "👤 {utils.escape_html(rule['entity_name'])}" f" {self._convert_time(int(rule['expires'] - time.time()))} {self.strings('for')} {self.strings(rule['rule_type'])}" f" {rule['rule']}" - for rule in self._client.dispatcher.security._tsec_user + for rule in self._client.dispatcher.security.tsec_user ] ) ), diff --git a/hikka/pointers.py b/hikka/pointers.py index e4ced1f..c6bdceb 100644 --- a/hikka/pointers.py +++ b/hikka/pointers.py @@ -17,6 +17,38 @@ class PointerList(list): self._default = default super().__init__(db.get(module, key, default)) + def sync(self): + super().__init__(self._db.get(self._module, self._key, self._default)) + + def __getitem__(self, index: int) -> Any: + self.sync() + return super().__getitem__(index) + + def __iter__(self) -> Iterable: + self.sync() + return super().__iter__() + + def __reversed__(self) -> Iterable: + self.sync() + return super().__reversed__() + + def __contains__(self, item: Any) -> bool: + self.sync() + return super().__contains__(item) + + def __len__(self) -> int: + self.sync() + return super().__len__() + + def __bool__(self) -> bool: + return bool(self._db.get(self._module, self._key, self._default)) + + def __repr__(self): + return f"PointerList({list(self)})" + + def __str__(self): + return f"PointerList({list(self)})" + def __delitem__(self, __i: Union[SupportsIndex, slice]) -> None: a = super().__delitem__(__i) self._save() @@ -37,9 +69,6 @@ class PointerList(list): self._save() return a - def __str__(self): - return f"PointerList({list(self)})" - def append(self, value: Any): super().append(value) self._save() @@ -85,6 +114,43 @@ class PointerDict(dict): self._default = default super().__init__(db.get(module, key, default)) + def sync(self): + super().__init__(self._db.get(self._module, self._key, self._default)) + + def __repr__(self): + return f"PointerDict({dict(self)})" + + def __bool__(self) -> bool: + return bool(self._db.get(self._module, self._key, self._default)) + + def __reversed__(self) -> Iterable: + self.sync() + return super().__reversed__() + + def __contains__(self, item: Any) -> bool: + self.sync() + return super().__contains__(item) + + def __getitem__(self, key: str) -> Any: + self.sync() + return super().__getitem__(key) + + def __iter__(self) -> Iterable: + self.sync() + return super().__iter__() + + def items(self) -> Iterable: + self.sync() + return super().items() + + def keys(self) -> Iterable: + self.sync() + return super().keys() + + def values(self) -> Iterable: + self.sync() + return super().values() + def __setitem__(self, key: str, value: Any): super().__setitem__(key, value) self._save() diff --git a/hikka/security.py b/hikka/security.py index 616df6a..23a022d 100755 --- a/hikka/security.py +++ b/hikka/security.py @@ -28,12 +28,14 @@ import logging import time from typing import Optional +from telethon import TelegramClient from telethon.hints import EntityLike from telethon.utils import get_display_name from telethon.tl.functions.messages import GetFullChatRequest from telethon.tl.types import ChatParticipantAdmin, ChatParticipantCreator, Message from . import main, utils +from .database import Database logger = logging.getLogger(__name__) @@ -153,24 +155,33 @@ def _sec(func: callable, flags: int) -> callable: class SecurityManager: - def __init__(self, db): - self._any_admin = db.get(__name__, "any_admin", False) - self._default = db.get(__name__, "default", DEFAULT_PERMISSIONS) + def __init__(self, client: TelegramClient, db: Database): + self._client = client self._db = db self._cache = {} - self._tsec_chat = self._db.pointer(__name__, "tsec_chat", []) - self._tsec_user = self._db.pointer(__name__, "tsec_user", []) + + self._any_admin = db.get(__name__, "any_admin", False) + self._default = db.get(__name__, "default", DEFAULT_PERMISSIONS) + self._tsec_chat = db.pointer(__name__, "tsec_chat", []) + self._tsec_user = db.pointer(__name__, "tsec_user", []) + self._owner = db.pointer(__name__, "owner", []) + self._sudo = db.pointer(__name__, "sudo", []) + self._support = db.pointer(__name__, "support", []) + self._reload_rights() + self.any_admin = self._any_admin + self.default = self._default + self.tsec_chat = self._tsec_chat + self.tsec_user = self._tsec_user + self.owner = self._owner + self.sudo = self._sudo + self.support = self._support + def _reload_rights(self): - self._owner = list( - set( - self._db.get(__name__, "owner", []).copy() - + ([self._client.tg_id] if hasattr(self, "_client") else []) - ) - ) - self._sudo = list(set(self._db.get(__name__, "sudo", []).copy())) - self._support = list(set(self._db.get(__name__, "support", []).copy())) + if self._client.tg_id not in self._owner: + self._owner.append(self._client.tg_id) + for info in self._tsec_user.copy(): if info["expires"] < time.time(): self._tsec_user.remove(info) @@ -179,9 +190,6 @@ class SecurityManager: if info["expires"] < time.time(): self._tsec_chat.remove(info) - async def init(self, client): - self._client = client - def add_rule( self, target_type: str, @@ -209,6 +217,22 @@ class SecurityManager: } ) + def remove_rules(self, target_type: str, target_id: int) -> bool: + any_ = False + + if target_type == "user": + for rule in self.tsec_user.copy(): + if rule["target"] == target_id: + self.tsec_user.remove(rule) + any_ = True + elif target_type == "chat": + for rule in self.tsec_chat.copy(): + if rule["target"] == target_id: + self.tsec_chat.remove(rule) + any_ = True + + return any_ + def get_flags(self, func: callable) -> int: if isinstance(func, int): config = func @@ -297,7 +321,7 @@ class SecurityManager: cmd = message.raw_text[1:].split()[0].strip() except Exception: cmd = None - + if callable(func): for info in self._tsec_user.copy(): if info["target"] == user: diff --git a/hikka/types.py b/hikka/types.py index 5d0052c..de83333 100644 --- a/hikka/types.py +++ b/hikka/types.py @@ -35,6 +35,9 @@ from .pointers import ( # skipcq: PY-W2000 logger = logging.getLogger(__name__) +JSONSerializable = Union[str, int, float, bool, list, dict, None] + + class StringLoader(SourceLoader): """Load a python module/file from a string""" diff --git a/hikka/utils.py b/hikka/utils.py index 5b9119a..829c60a 100755 --- a/hikka/utils.py +++ b/hikka/utils.py @@ -335,7 +335,7 @@ async def answer( if isinstance(response, str) and not kwargs.pop("asfile", False): text, entities = parse_mode.parse(response) - if len(text) >= 4096: + if len(text) >= 4096 and not hasattr(message, "hikka_grepped"): try: if not message.client.loader.inline.init_complete: raise @@ -404,7 +404,9 @@ async def answer( "reply_to", getattr(message, "reply_to_msg_id", None), ) - result = await message.client.send_file(message.chat_id, response, **kwargs) + result = await message.client.send_file(message.peer_id, response, **kwargs) + if message.out: + await message.delete() return result @@ -670,7 +672,7 @@ def get_named_platform() -> str: if is_okteto: return "☁️ Okteto" - + if is_codespaces: return "🐈‍⬛ Codespaces"