mirror of https://github.com/coddrago/Heroku
1.4.0
- Fix grep for messages bigger than 4096 UTF-8 characters - Rename `func` tag to `filter` due to internal python conflict with dynamically generated methods - Partially rework security unit - Internal refactoring and typehintspull/1/head
parent
ab8130ed60
commit
89040b6e2f
|
@ -10,12 +10,16 @@
|
|||
- Fix `utils.find_caller` for :method:`hikka.inline.utils.Utils._find_caller_sec_map`
|
||||
- Fix `.eval`
|
||||
- Fix: use old lib if its version is higher than new one
|
||||
- Fix grep for messages bigger than 4096 UTF-8 characters
|
||||
- Add more animated emojis to modules
|
||||
- Add targeted security for users and chats (`.tsec`)
|
||||
- Add support for `tg_level` in `.config Tester`
|
||||
- Add `-f` param to `.restart` and `.update`
|
||||
- Add platform-specific Hikka emojis to premium users
|
||||
- Add codepaces to `utils.get_named_platform`
|
||||
- Rename `func` tag to `filter` due to internal python conflict with dynamically generated methods
|
||||
- Partially rework security unit
|
||||
- Internal refactoring and typehints
|
||||
|
||||
## 🌑 Hikka 1.3.3
|
||||
|
||||
|
|
|
@ -27,8 +27,9 @@ except ImportError as e:
|
|||
raise e
|
||||
|
||||
|
||||
from typing import Any, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from telethon import TelegramClient
|
||||
from telethon.tl.types import Message
|
||||
from telethon.errors.rpcerrorlist import ChannelsTooMuchError
|
||||
|
||||
|
@ -37,6 +38,7 @@ from .pointers import (
|
|||
PointerList,
|
||||
PointerDict,
|
||||
)
|
||||
from .types import JSONSerializable
|
||||
|
||||
DATA_DIR = (
|
||||
os.path.normpath(os.path.join(utils.get_base_dir(), ".."))
|
||||
|
@ -60,7 +62,7 @@ class Database(dict):
|
|||
_redis = None
|
||||
_saving_task = None
|
||||
|
||||
def __init__(self, client):
|
||||
def __init__(self, client: TelegramClient):
|
||||
super().__init__()
|
||||
self._client = client
|
||||
|
||||
|
@ -347,14 +349,19 @@ class Database(dict):
|
|||
|
||||
return asset[0] if asset else None
|
||||
|
||||
def get(self, owner: str, key: str, default: Any = None) -> Any:
|
||||
def get(
|
||||
self,
|
||||
owner: str,
|
||||
key: str,
|
||||
default: Optional[JSONSerializable] = None,
|
||||
) -> JSONSerializable:
|
||||
"""Get database key"""
|
||||
try:
|
||||
return self[owner][key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def set(self, owner: str, key: str, value: Any) -> bool:
|
||||
def set(self, owner: str, key: str, value: JSONSerializable) -> bool:
|
||||
"""Set database key"""
|
||||
if not utils.is_serializable(owner):
|
||||
raise RuntimeError(
|
||||
|
@ -380,7 +387,12 @@ class Database(dict):
|
|||
super().setdefault(owner, {})[key] = value
|
||||
return self.save()
|
||||
|
||||
def pointer(self, owner: str, key: str, default: Any = None) -> Any:
|
||||
def pointer(
|
||||
self,
|
||||
owner: str,
|
||||
key: str,
|
||||
default: Optional[JSONSerializable] = None,
|
||||
) -> JSONSerializable:
|
||||
"""Get a pointer to database key"""
|
||||
value = self.get(owner, key, default)
|
||||
mapping = {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
"""Obviously, dispatches stuff"""
|
||||
"""Processes incoming events and dispatches them to appropriate handlers"""
|
||||
|
||||
# Friendly Telegram (telegram userbot)
|
||||
# Copyright (C) 2018-2022 The Authors
|
||||
|
@ -33,7 +33,7 @@ import re
|
|||
import traceback
|
||||
from typing import Tuple, Union
|
||||
|
||||
from telethon import types
|
||||
from telethon import TelegramClient
|
||||
from telethon.tl.types import Message
|
||||
|
||||
from . import main, security, utils
|
||||
|
@ -64,7 +64,7 @@ ALL_TAGS = [
|
|||
"startswith",
|
||||
"endswith",
|
||||
"contains",
|
||||
"func",
|
||||
"filter",
|
||||
"from_id",
|
||||
"chat_id",
|
||||
"regex",
|
||||
|
@ -79,25 +79,33 @@ def _decrement_ratelimit(delay, data, key, severity):
|
|||
|
||||
|
||||
class CommandDispatcher:
|
||||
def __init__(self, modules: Modules, db: Database, no_nickname: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
modules: Modules,
|
||||
client: TelegramClient,
|
||||
db: Database,
|
||||
no_nickname: bool = False,
|
||||
):
|
||||
self._modules = modules
|
||||
self._client = client
|
||||
self.client = client
|
||||
self._db = db
|
||||
self.security = security.SecurityManager(db)
|
||||
self.no_nickname = no_nickname
|
||||
|
||||
self._ratelimit_storage_user = collections.defaultdict(int)
|
||||
self._ratelimit_storage_chat = collections.defaultdict(int)
|
||||
self._ratelimit_max_user = db.get(__name__, "ratelimit_max_user", 30)
|
||||
self._ratelimit_max_chat = db.get(__name__, "ratelimit_max_chat", 100)
|
||||
|
||||
self.security = security.SecurityManager(client, db)
|
||||
|
||||
self.check_security = self.security.check
|
||||
|
||||
async def init(self, client: "TelegramClient"): # type: ignore
|
||||
await self.security.init(client)
|
||||
me = await client.get_me()
|
||||
|
||||
self.client = client # Intended to be used to track user in logging
|
||||
|
||||
self._me = me.id
|
||||
self._cached_username = me.username.lower() if me.username else str(me.id)
|
||||
self._me = self._client.hikka_me.id
|
||||
self._cached_username = (
|
||||
self._client.hikka_me.username.lower()
|
||||
if self._client.hikka_me.username
|
||||
else str(self._client.hikka_me.id)
|
||||
)
|
||||
|
||||
async def _handle_ratelimit(self, message: Message, func: callable) -> bool:
|
||||
if await self.security.check(
|
||||
|
@ -227,6 +235,7 @@ class CommandDispatcher:
|
|||
message.edit = my_edit
|
||||
message.reply = my_reply
|
||||
message.respond = my_respond
|
||||
message.hikka_grepped = True
|
||||
|
||||
return message
|
||||
|
||||
|
@ -422,7 +431,7 @@ class CommandDispatcher:
|
|||
or (getattr(func, "in", False) and getattr(message, "out", True))
|
||||
or (
|
||||
getattr(func, "only_messages", False)
|
||||
and not isinstance(message, types.Message)
|
||||
and not isinstance(message, Message)
|
||||
)
|
||||
or (
|
||||
getattr(func, "editable", False)
|
||||
|
@ -435,13 +444,13 @@ class CommandDispatcher:
|
|||
)
|
||||
or (
|
||||
getattr(func, "no_media", False)
|
||||
and isinstance(message, types.Message)
|
||||
and isinstance(message, Message)
|
||||
and getattr(message, "media", False)
|
||||
)
|
||||
or (
|
||||
getattr(func, "only_media", False)
|
||||
and (
|
||||
not isinstance(message, types.Message)
|
||||
not isinstance(message, Message)
|
||||
or not getattr(message, "media", False)
|
||||
)
|
||||
)
|
||||
|
@ -510,9 +519,9 @@ class CommandDispatcher:
|
|||
)
|
||||
)
|
||||
or (
|
||||
getattr(func, "func", False)
|
||||
and callable(func.func)
|
||||
and not func.func(message)
|
||||
getattr(func, "filter", False)
|
||||
and callable(func.filter)
|
||||
and not func.filter(message)
|
||||
)
|
||||
or (
|
||||
getattr(func, "from_id", False)
|
||||
|
@ -556,7 +565,7 @@ class CommandDispatcher:
|
|||
|
||||
if (
|
||||
modname in bl
|
||||
and isinstance(message, types.Message)
|
||||
and isinstance(message, Message)
|
||||
and (
|
||||
"*" in bl[modname]
|
||||
or utils.get_chat_id(message) in bl[modname]
|
||||
|
@ -584,7 +593,7 @@ class CommandDispatcher:
|
|||
|
||||
# Avoid weird AttributeErrors in weird dochub modules by settings placeholder
|
||||
# of attributes
|
||||
for placeholder in {"text", "raw_text"}:
|
||||
for placeholder in {"text", "raw_text", "out"}:
|
||||
try:
|
||||
if not hasattr(message, placeholder):
|
||||
setattr(message, placeholder, "")
|
||||
|
|
|
@ -26,32 +26,35 @@
|
|||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import copy
|
||||
from functools import partial, wraps
|
||||
import importlib
|
||||
import importlib.util
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from importlib.machinery import ModuleSpec
|
||||
from types import FunctionType
|
||||
from typing import Any, Awaitable, Hashable, Optional, Union, List
|
||||
import requests
|
||||
import copy
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import importlib.machinery
|
||||
from functools import partial, wraps
|
||||
|
||||
from telethon import TelegramClient
|
||||
from telethon.tl.types import Message, InputPeerNotifySettings, Channel
|
||||
from telethon.tl.functions.account import UpdateNotifySettingsRequest
|
||||
from telethon.hints import EntityLike
|
||||
|
||||
from types import FunctionType
|
||||
from typing import Any, Optional, Union, List
|
||||
|
||||
from . import security, utils, validators, version
|
||||
from .types import (
|
||||
ConfigValue, # skipcq
|
||||
LoadError, # skipcq
|
||||
LoadError,
|
||||
Module,
|
||||
Library, # skipcq
|
||||
Library,
|
||||
ModuleConfig, # skipcq
|
||||
LibraryConfig, # skipcq
|
||||
LibraryConfig,
|
||||
SelfUnload,
|
||||
SelfSuspend,
|
||||
StopLoop,
|
||||
|
@ -61,6 +64,7 @@ from .types import (
|
|||
StringLoader,
|
||||
get_commands,
|
||||
get_inline_handlers,
|
||||
JSONSerializable,
|
||||
)
|
||||
from .inline.core import InlineManager
|
||||
from .inline.types import InlineCall
|
||||
|
@ -431,7 +435,7 @@ def tag(*tags, **kwarg_tags):
|
|||
• `endswith` - Capture only messages that end with given text
|
||||
• `contains` - Capture only messages that contain given text
|
||||
• `regex` - Capture only messages that match given regex
|
||||
• `func` - Capture only messages that pass given function
|
||||
• `filter` - Capture only messages that pass given function
|
||||
• `from_id` - Capture only messages from given user
|
||||
• `chat_id` - Capture only messages from given chat
|
||||
|
||||
|
@ -612,7 +616,7 @@ class Modules:
|
|||
logger.debug(f"Loading {module_name} from filesystem")
|
||||
|
||||
with open(mod, "r") as file:
|
||||
spec = ModuleSpec(
|
||||
spec = importlib.machinery.ModuleSpec(
|
||||
module_name,
|
||||
StringLoader(file.read(), user_friendly_origin),
|
||||
origin=user_friendly_origin,
|
||||
|
@ -624,7 +628,7 @@ class Modules:
|
|||
|
||||
async def register_module(
|
||||
self,
|
||||
spec: ModuleSpec,
|
||||
spec: importlib.machinery.ModuleSpec,
|
||||
module_name: str,
|
||||
origin: str = "<core>",
|
||||
save_fs: bool = False,
|
||||
|
@ -917,10 +921,10 @@ class Modules:
|
|||
instance.allclients = self.allclients
|
||||
instance.allmodules = self
|
||||
instance.hikka = True
|
||||
instance.get = partial(self._mod_get, _modname=instance.__class__.__name__)
|
||||
instance.set = partial(self._mod_set, _modname=instance.__class__.__name__)
|
||||
instance.get = partial(self._get, _owner=instance.__class__.__name__)
|
||||
instance.set = partial(self._set, _owner=instance.__class__.__name__)
|
||||
instance.pointer = partial(
|
||||
self._mod_pointer, _modname=instance.__class__.__name__
|
||||
self._pointer, _owner=instance.__class__.__name__
|
||||
)
|
||||
instance.get_prefix = partial(self._db.get, "hikka.main", "command_prefix", ".")
|
||||
instance.client = self.client
|
||||
|
@ -955,43 +959,24 @@ class Modules:
|
|||
|
||||
self.modules += [instance]
|
||||
|
||||
def _mod_get(
|
||||
def _get(
|
||||
self,
|
||||
key: str,
|
||||
default: Optional[Hashable] = None,
|
||||
_modname: str = None,
|
||||
) -> Hashable:
|
||||
return self._db.get(_modname, key, default)
|
||||
default: Optional[JSONSerializable] = None,
|
||||
_owner: str = None,
|
||||
) -> JSONSerializable:
|
||||
return self._db.get(_owner, key, default)
|
||||
|
||||
def _mod_set(self, key: str, value: Hashable, _modname: str = None) -> bool:
|
||||
return self._db.set(_modname, key, value)
|
||||
def _set(self, key: str, value: JSONSerializable, _owner: str = None) -> bool:
|
||||
return self._db.set(_owner, key, value)
|
||||
|
||||
def _mod_pointer(
|
||||
def _pointer(
|
||||
self,
|
||||
key: str,
|
||||
default: Optional[Hashable] = None,
|
||||
_modname: str = None,
|
||||
) -> Any:
|
||||
return self._db.pointer(_modname, key, default)
|
||||
|
||||
def _lib_get(
|
||||
self,
|
||||
key: str,
|
||||
default: Optional[Hashable] = None,
|
||||
_lib: Library = None,
|
||||
) -> Hashable:
|
||||
return self._db.get(_lib.__class__.__name__, key, default)
|
||||
|
||||
def _lib_set(self, key: str, value: Hashable, _lib: Library = None) -> bool:
|
||||
return self._db.set(_lib.__class__.__name__, key, value)
|
||||
|
||||
def _lib_pointer(
|
||||
self,
|
||||
key: str,
|
||||
default: Optional[Hashable] = None,
|
||||
_lib: Library = None,
|
||||
) -> Any:
|
||||
return self._db.pointer(_lib.__class__.__name__, key, default)
|
||||
default: Optional[JSONSerializable] = None,
|
||||
_owner: str = None,
|
||||
) -> JSONSerializable:
|
||||
return self._db.pointer(_owner, key, default)
|
||||
|
||||
async def _mod_import_lib(
|
||||
self,
|
||||
|
@ -1046,7 +1031,7 @@ class Modules:
|
|||
module = f"hikka.libraries.{url.replace('%', '%%').replace('.', '%d')}"
|
||||
origin = f"<library {url}>"
|
||||
|
||||
spec = ModuleSpec(module, StringLoader(code, origin), origin=origin)
|
||||
spec = importlib.machinery.ModuleSpec(module, StringLoader(code, origin), origin=origin)
|
||||
try:
|
||||
instance = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module] = instance
|
||||
|
@ -1140,9 +1125,9 @@ class Modules:
|
|||
lib_obj.inline = self.inline
|
||||
lib_obj.tg_id = self.client.tg_id
|
||||
lib_obj.allmodules = self
|
||||
lib_obj._lib_get = partial(self._lib_get, _lib=lib_obj) # skipcq
|
||||
lib_obj._lib_set = partial(self._lib_set, _lib=lib_obj) # skipcq
|
||||
lib_obj._lib_pointer = partial(self._lib_pointer, _lib=lib_obj) # skipcq
|
||||
lib_obj._lib_get = partial(self._get, _owner=lib_obj) # skipcq
|
||||
lib_obj._lib_set = partial(self._set, _owner=lib_obj) # skipcq
|
||||
lib_obj._lib_pointer = partial(self._pointer, _owner=lib_obj) # skipcq
|
||||
lib_obj.get_prefix = partial(self._db.get, "hikka.main", "command_prefix", ".")
|
||||
|
||||
for old_lib in self.libraries:
|
||||
|
|
|
@ -625,9 +625,8 @@ class Hikka:
|
|||
|
||||
async def _add_dispatcher(self, client, modules, db):
|
||||
"""Inits and adds dispatcher instance to client"""
|
||||
dispatcher = CommandDispatcher(modules, db, self.arguments.no_nickname)
|
||||
dispatcher = CommandDispatcher(modules, client, db, self.arguments.no_nickname)
|
||||
client.dispatcher = dispatcher
|
||||
await dispatcher.init(client)
|
||||
modules.check_security = dispatcher.check_security
|
||||
|
||||
client.add_event_handler(
|
||||
|
|
|
@ -543,7 +543,7 @@ class HikkaSecurityMod(loader.Module):
|
|||
) -> dict:
|
||||
config = self._db.get(security.__name__, "masks", {}).get(
|
||||
f"{command.__module__}.{command.__name__}",
|
||||
getattr(command, "security", self._client.dispatcher.security._default),
|
||||
getattr(command, "security", self._client.dispatcher.security.default),
|
||||
)
|
||||
|
||||
return self._perms_map(config, is_inline)
|
||||
|
@ -938,7 +938,7 @@ class HikkaSecurityMod(loader.Module):
|
|||
try:
|
||||
if not args[1].isdigit() and not args[1].startswith("@"):
|
||||
raise ValueError
|
||||
|
||||
|
||||
target = await self._client.get_entity(
|
||||
int(args[1]) if args[1].isdigit() else args[1]
|
||||
)
|
||||
|
@ -1011,7 +1011,7 @@ class HikkaSecurityMod(loader.Module):
|
|||
await utils.answer(message, self.strings("no_target"))
|
||||
return
|
||||
|
||||
if target.id in self._client.dispatcher.security._owner:
|
||||
if target.id in self._client.dispatcher.security.owner:
|
||||
await utils.answer(message, self.strings("owner_target"))
|
||||
return
|
||||
|
||||
|
@ -1059,8 +1059,8 @@ class HikkaSecurityMod(loader.Module):
|
|||
async def tsecrm(self, message: Message):
|
||||
"""<"user"/"chat"> - Remove targeted security rule"""
|
||||
if (
|
||||
not self._client.dispatcher.security._tsec_chat
|
||||
and not self._client.dispatcher.security._tsec_user
|
||||
not self._client.dispatcher.security.tsec_chat
|
||||
and not self._client.dispatcher.security.tsec_user
|
||||
):
|
||||
await utils.answer(message, self.strings("no_rules"))
|
||||
return
|
||||
|
@ -1084,18 +1084,10 @@ class HikkaSecurityMod(loader.Module):
|
|||
await utils.answer(message, self.strings("no_target"))
|
||||
return
|
||||
|
||||
if not any(
|
||||
rule["target"] == target.id
|
||||
for rule in self._client.dispatcher.security._tsec_user
|
||||
):
|
||||
if not self._client.dispatcher.security.remove_rules("user", target.id):
|
||||
await utils.answer(message, self.strings("no_rules"))
|
||||
return
|
||||
|
||||
self._client.dispatcher.security._tsec_user = [
|
||||
rule
|
||||
for rule in self._client.dispatcher.security._tsec_user
|
||||
if rule["target"] != target.id
|
||||
]
|
||||
await utils.answer(
|
||||
message,
|
||||
self.strings("rules_removed").format(
|
||||
|
@ -1111,18 +1103,10 @@ class HikkaSecurityMod(loader.Module):
|
|||
|
||||
target = await self._client.get_entity(message.peer_id)
|
||||
|
||||
if not any(
|
||||
rule["target"] == target.id
|
||||
for rule in self._client.dispatcher.security._tsec_chat
|
||||
):
|
||||
if not self._client.dispatcher.security.remove_rules("chat", target.id):
|
||||
await utils.answer(message, self.strings("no_rules"))
|
||||
return
|
||||
|
||||
self._client.dispatcher.security._tsec_chat = [
|
||||
rule
|
||||
for rule in self._client.dispatcher.security._tsec_chat
|
||||
if rule["target"] != target.id
|
||||
]
|
||||
await utils.answer(
|
||||
message,
|
||||
self.strings("rules_removed").format(
|
||||
|
@ -1143,8 +1127,8 @@ class HikkaSecurityMod(loader.Module):
|
|||
args = utils.get_args(message)
|
||||
if not args:
|
||||
if (
|
||||
not self._client.dispatcher.security._tsec_chat
|
||||
and not self._client.dispatcher.security._tsec_user
|
||||
not self._client.dispatcher.security.tsec_chat
|
||||
and not self._client.dispatcher.security.tsec_user
|
||||
):
|
||||
await utils.answer(message, self.strings("no_rules"))
|
||||
return
|
||||
|
@ -1158,14 +1142,14 @@ class HikkaSecurityMod(loader.Module):
|
|||
f" href='{rule['entity_url']}'>{utils.escape_html(rule['entity_name'])}</a>"
|
||||
f" {self._convert_time(int(rule['expires'] - time.time()))} {self.strings('for')} {self.strings(rule['rule_type'])}</b>"
|
||||
f" <code>{rule['rule']}</code>"
|
||||
for rule in self._client.dispatcher.security._tsec_chat
|
||||
for rule in self._client.dispatcher.security.tsec_chat
|
||||
]
|
||||
+ [
|
||||
"<emoji document_id='6037122016849432064'>👤</emoji> <b><a"
|
||||
f" href='{rule['entity_url']}'>{utils.escape_html(rule['entity_name'])}</a>"
|
||||
f" {self._convert_time(int(rule['expires'] - time.time()))} {self.strings('for')} {self.strings(rule['rule_type'])}</b>"
|
||||
f" <code>{rule['rule']}</code>"
|
||||
for rule in self._client.dispatcher.security._tsec_user
|
||||
for rule in self._client.dispatcher.security.tsec_user
|
||||
]
|
||||
)
|
||||
),
|
||||
|
|
|
@ -17,6 +17,38 @@ class PointerList(list):
|
|||
self._default = default
|
||||
super().__init__(db.get(module, key, default))
|
||||
|
||||
def sync(self):
|
||||
super().__init__(self._db.get(self._module, self._key, self._default))
|
||||
|
||||
def __getitem__(self, index: int) -> Any:
|
||||
self.sync()
|
||||
return super().__getitem__(index)
|
||||
|
||||
def __iter__(self) -> Iterable:
|
||||
self.sync()
|
||||
return super().__iter__()
|
||||
|
||||
def __reversed__(self) -> Iterable:
|
||||
self.sync()
|
||||
return super().__reversed__()
|
||||
|
||||
def __contains__(self, item: Any) -> bool:
|
||||
self.sync()
|
||||
return super().__contains__(item)
|
||||
|
||||
def __len__(self) -> int:
|
||||
self.sync()
|
||||
return super().__len__()
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._db.get(self._module, self._key, self._default))
|
||||
|
||||
def __repr__(self):
|
||||
return f"PointerList({list(self)})"
|
||||
|
||||
def __str__(self):
|
||||
return f"PointerList({list(self)})"
|
||||
|
||||
def __delitem__(self, __i: Union[SupportsIndex, slice]) -> None:
|
||||
a = super().__delitem__(__i)
|
||||
self._save()
|
||||
|
@ -37,9 +69,6 @@ class PointerList(list):
|
|||
self._save()
|
||||
return a
|
||||
|
||||
def __str__(self):
|
||||
return f"PointerList({list(self)})"
|
||||
|
||||
def append(self, value: Any):
|
||||
super().append(value)
|
||||
self._save()
|
||||
|
@ -85,6 +114,43 @@ class PointerDict(dict):
|
|||
self._default = default
|
||||
super().__init__(db.get(module, key, default))
|
||||
|
||||
def sync(self):
|
||||
super().__init__(self._db.get(self._module, self._key, self._default))
|
||||
|
||||
def __repr__(self):
|
||||
return f"PointerDict({dict(self)})"
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._db.get(self._module, self._key, self._default))
|
||||
|
||||
def __reversed__(self) -> Iterable:
|
||||
self.sync()
|
||||
return super().__reversed__()
|
||||
|
||||
def __contains__(self, item: Any) -> bool:
|
||||
self.sync()
|
||||
return super().__contains__(item)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
self.sync()
|
||||
return super().__getitem__(key)
|
||||
|
||||
def __iter__(self) -> Iterable:
|
||||
self.sync()
|
||||
return super().__iter__()
|
||||
|
||||
def items(self) -> Iterable:
|
||||
self.sync()
|
||||
return super().items()
|
||||
|
||||
def keys(self) -> Iterable:
|
||||
self.sync()
|
||||
return super().keys()
|
||||
|
||||
def values(self) -> Iterable:
|
||||
self.sync()
|
||||
return super().values()
|
||||
|
||||
def __setitem__(self, key: str, value: Any):
|
||||
super().__setitem__(key, value)
|
||||
self._save()
|
||||
|
|
|
@ -28,12 +28,14 @@ import logging
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from telethon import TelegramClient
|
||||
from telethon.hints import EntityLike
|
||||
from telethon.utils import get_display_name
|
||||
from telethon.tl.functions.messages import GetFullChatRequest
|
||||
from telethon.tl.types import ChatParticipantAdmin, ChatParticipantCreator, Message
|
||||
|
||||
from . import main, utils
|
||||
from .database import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -153,24 +155,33 @@ def _sec(func: callable, flags: int) -> callable:
|
|||
|
||||
|
||||
class SecurityManager:
|
||||
def __init__(self, db):
|
||||
self._any_admin = db.get(__name__, "any_admin", False)
|
||||
self._default = db.get(__name__, "default", DEFAULT_PERMISSIONS)
|
||||
def __init__(self, client: TelegramClient, db: Database):
|
||||
self._client = client
|
||||
self._db = db
|
||||
self._cache = {}
|
||||
self._tsec_chat = self._db.pointer(__name__, "tsec_chat", [])
|
||||
self._tsec_user = self._db.pointer(__name__, "tsec_user", [])
|
||||
|
||||
self._any_admin = db.get(__name__, "any_admin", False)
|
||||
self._default = db.get(__name__, "default", DEFAULT_PERMISSIONS)
|
||||
self._tsec_chat = db.pointer(__name__, "tsec_chat", [])
|
||||
self._tsec_user = db.pointer(__name__, "tsec_user", [])
|
||||
self._owner = db.pointer(__name__, "owner", [])
|
||||
self._sudo = db.pointer(__name__, "sudo", [])
|
||||
self._support = db.pointer(__name__, "support", [])
|
||||
|
||||
self._reload_rights()
|
||||
|
||||
self.any_admin = self._any_admin
|
||||
self.default = self._default
|
||||
self.tsec_chat = self._tsec_chat
|
||||
self.tsec_user = self._tsec_user
|
||||
self.owner = self._owner
|
||||
self.sudo = self._sudo
|
||||
self.support = self._support
|
||||
|
||||
def _reload_rights(self):
|
||||
self._owner = list(
|
||||
set(
|
||||
self._db.get(__name__, "owner", []).copy()
|
||||
+ ([self._client.tg_id] if hasattr(self, "_client") else [])
|
||||
)
|
||||
)
|
||||
self._sudo = list(set(self._db.get(__name__, "sudo", []).copy()))
|
||||
self._support = list(set(self._db.get(__name__, "support", []).copy()))
|
||||
if self._client.tg_id not in self._owner:
|
||||
self._owner.append(self._client.tg_id)
|
||||
|
||||
for info in self._tsec_user.copy():
|
||||
if info["expires"] < time.time():
|
||||
self._tsec_user.remove(info)
|
||||
|
@ -179,9 +190,6 @@ class SecurityManager:
|
|||
if info["expires"] < time.time():
|
||||
self._tsec_chat.remove(info)
|
||||
|
||||
async def init(self, client):
|
||||
self._client = client
|
||||
|
||||
def add_rule(
|
||||
self,
|
||||
target_type: str,
|
||||
|
@ -209,6 +217,22 @@ class SecurityManager:
|
|||
}
|
||||
)
|
||||
|
||||
def remove_rules(self, target_type: str, target_id: int) -> bool:
|
||||
any_ = False
|
||||
|
||||
if target_type == "user":
|
||||
for rule in self.tsec_user.copy():
|
||||
if rule["target"] == target_id:
|
||||
self.tsec_user.remove(rule)
|
||||
any_ = True
|
||||
elif target_type == "chat":
|
||||
for rule in self.tsec_chat.copy():
|
||||
if rule["target"] == target_id:
|
||||
self.tsec_chat.remove(rule)
|
||||
any_ = True
|
||||
|
||||
return any_
|
||||
|
||||
def get_flags(self, func: callable) -> int:
|
||||
if isinstance(func, int):
|
||||
config = func
|
||||
|
@ -297,7 +321,7 @@ class SecurityManager:
|
|||
cmd = message.raw_text[1:].split()[0].strip()
|
||||
except Exception:
|
||||
cmd = None
|
||||
|
||||
|
||||
if callable(func):
|
||||
for info in self._tsec_user.copy():
|
||||
if info["target"] == user:
|
||||
|
|
|
@ -35,6 +35,9 @@ from .pointers import ( # skipcq: PY-W2000
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
JSONSerializable = Union[str, int, float, bool, list, dict, None]
|
||||
|
||||
|
||||
class StringLoader(SourceLoader):
|
||||
"""Load a python module/file from a string"""
|
||||
|
||||
|
|
|
@ -335,7 +335,7 @@ async def answer(
|
|||
if isinstance(response, str) and not kwargs.pop("asfile", False):
|
||||
text, entities = parse_mode.parse(response)
|
||||
|
||||
if len(text) >= 4096:
|
||||
if len(text) >= 4096 and not hasattr(message, "hikka_grepped"):
|
||||
try:
|
||||
if not message.client.loader.inline.init_complete:
|
||||
raise
|
||||
|
@ -404,7 +404,9 @@ async def answer(
|
|||
"reply_to",
|
||||
getattr(message, "reply_to_msg_id", None),
|
||||
)
|
||||
result = await message.client.send_file(message.chat_id, response, **kwargs)
|
||||
result = await message.client.send_file(message.peer_id, response, **kwargs)
|
||||
if message.out:
|
||||
await message.delete()
|
||||
|
||||
return result
|
||||
|
||||
|
@ -670,7 +672,7 @@ def get_named_platform() -> str:
|
|||
|
||||
if is_okteto:
|
||||
return "☁️ Okteto"
|
||||
|
||||
|
||||
if is_codespaces:
|
||||
return "🐈⬛ Codespaces"
|
||||
|
||||
|
|
Loading…
Reference in New Issue