mirror of https://github.com/coddrago/Heroku
466 lines
16 KiB
Python
466 lines
16 KiB
Python
# █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
|
|
# █▀█ █ █ █ █▀█ █▀▄ █
|
|
# © Copyright 2022
|
|
# https://t.me/hikariatama
|
|
#
|
|
# 🔒 Licensed under the GNU AGPLv3
|
|
# 🌐 https://www.gnu.org/licenses/agpl-3.0.html
|
|
|
|
import copy
|
|
import inspect
|
|
import time
|
|
import logging
|
|
import typing
|
|
|
|
from telethon import TelegramClient
|
|
from telethon.hints import EntityLike
|
|
from telethon.utils import is_list_like
|
|
from telethon.network import MTProtoSender
|
|
from telethon.tl.tlobject import TLRequest
|
|
from telethon.tl.functions.channels import GetFullChannelRequest
|
|
from telethon.tl.functions.users import GetFullUserRequest
|
|
from telethon.tl.types import (
|
|
ChannelFull,
|
|
UserFull,
|
|
Updates,
|
|
UpdatesCombined,
|
|
UpdateShort,
|
|
)
|
|
|
|
from .types import (
|
|
CacheRecord,
|
|
CacheRecordPerms,
|
|
CacheRecordFullChannel,
|
|
CacheRecordFullUser,
|
|
Module,
|
|
)
|
|
|
|
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def hashable(value: typing.Any) -> bool:
    """
    Tell whether ``value`` can be hashed, i.e. used as a dict key.

    The check really calls :func:`hash`. Unlike
    ``isinstance(value, collections.abc.Hashable)``, this means a container
    that holds unhashable items, such as ``(1, [])``, is reported as not
    hashable.

    :param value: Any object
    :return: ``True`` if ``hash(value)`` succeeds, ``False`` if it raises
        :class:`TypeError`
    """
    try:
        hash(value)
        can_hash = True
    except TypeError:
        can_hash = False
    return can_hash
|
|
|
|
|
|
class CustomTelegramClient(TelegramClient):
    """
    :class:`TelegramClient` with per-client in-memory caches for entities,
    permissions and full channel/user requests, plus a guard that drops
    selected request constructors when they are sent from a non-core module.
    """

    # Attributes tried, in this order, to derive a hashable cache key from an
    # unhashable entity-like object (e.g. a TL object). These are tuples on
    # purpose: iterating a *set* of strings follows the per-process hash seed,
    # so an object exposing several of these attributes (say both ``chat_id``
    # and ``id``) would be keyed by an arbitrary one.
    _HIKKA_ENTITY_KEY_ATTRS = ("user_id", "channel_id", "chat_id", "id")
    _HIKKA_FULLCHANNEL_KEY_ATTRS = ("channel_id", "chat_id", "id")
    _HIKKA_FULLUSER_KEY_ATTRS = ("user_id", "chat_id", "id")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # key (hashable entity / id / username) -> CacheRecord
        self._hikka_entity_cache = {}
        # entity key -> {user key -> CacheRecordPerms}
        self._hikka_perms_cache = {}
        # entity key -> CacheRecordFullChannel
        self._hikka_fullchannel_cache = {}
        # entity key -> CacheRecordFullUser
        self._hikka_fulluser_cache = {}
        # Constructor IDs that non-core modules may not send (see `_call`)
        self.__forbidden_constructors = []
        self.raw_updates_processor = None  # Will be monkeypatched by pyro proxy

    @staticmethod
    def _hikka_key_from_attrs(
        obj: typing.Any,
        attrs: typing.Iterable[str],
    ) -> typing.Any:
        """
        Return the first truthy ``getattr(obj, attr)`` for ``attr`` in ``attrs``.

        :param obj: Unhashable entity-like object to derive a cache key from
        :param attrs: Attribute names to try, in priority order
        :return: The first truthy attribute value, or ``None`` if none is set
        """
        for attr in attrs:
            value = getattr(obj, attr, None)
            if value:
                return value
        return None

    async def force_get_entity(self, *args, **kwargs):
        """Forcefully makes a request to Telegram to get the entity."""
        return await self.get_entity(*args, force=True, **kwargs)

    async def get_entity(
        self,
        entity: EntityLike,
        exp: int = 5 * 60,
        force: bool = False,
    ):
        """
        Gets the entity and cache it

        :param entity: Entity to fetch
        :param exp: Expiration time of the cache record and maximum time of already cached record
            (a falsy value accepts a cached record of any age)
        :param force: Whether to force refresh the cache (make API request)
        :return: :obj:`Entity` (a deep copy, so callers may mutate it freely)
        """
        # Will be used to determine, which client caused logging messages
        # parsed via inspect.stack()
        _hikka_client_id_logging_tag = copy.copy(self.tg_id)  # skipcq

        if hashable(entity):
            hashable_entity = entity
        else:
            hashable_entity = self._hikka_key_from_attrs(
                entity, self._HIKKA_ENTITY_KEY_ATTRS
            )
            if hashable_entity is None:
                logger.debug(
                    "Can't parse hashable from entity %s, using legacy resolve",
                    entity,
                )
                return await TelegramClient.get_entity(self, entity)

        # Marked IDs (``-100…`` for channels, ``-…`` for basic groups) are used
        # as cache keys verbatim. Slicing off a ``-100`` prefix would mangle
        # basic-group IDs, and the resolved entity's own ``id`` is stored as an
        # extra key below anyway.

        if (
            not force
            and hashable_entity
            and hashable_entity in self._hikka_entity_cache
            and (
                not exp
                or self._hikka_entity_cache[hashable_entity].ts + exp > time.time()
            )
        ):
            logger.debug(
                "Using cached entity %s (%s)",
                entity,
                type(self._hikka_entity_cache[hashable_entity].entity).__name__,
            )
            return copy.deepcopy(self._hikka_entity_cache[hashable_entity].entity)

        resolved_entity = await TelegramClient.get_entity(self, entity)

        if resolved_entity:
            cache_record = CacheRecord(hashable_entity, resolved_entity, exp)
            self._hikka_entity_cache[hashable_entity] = cache_record
            logger.debug("Saved hashable_entity %s to cache", hashable_entity)

            # Index the same record by the entity's id and username, so later
            # lookups through any of these aliases hit the cache.
            if getattr(resolved_entity, "id", None):
                logger.debug("Saved resolved_entity id %s to cache", resolved_entity.id)
                self._hikka_entity_cache[resolved_entity.id] = cache_record

            if getattr(resolved_entity, "username", None):
                logger.debug(
                    "Saved resolved_entity username @%s to cache",
                    resolved_entity.username,
                )
                self._hikka_entity_cache[f"@{resolved_entity.username}"] = cache_record
                self._hikka_entity_cache[resolved_entity.username] = cache_record

        return copy.deepcopy(resolved_entity)

    async def get_perms_cached(
        self,
        entity: EntityLike,
        user: typing.Optional[EntityLike] = None,
        exp: int = 5 * 60,
        force: bool = False,
    ):
        """
        Gets the permissions of the user in the entity and cache it

        If no cache key can be derived for ``entity`` or ``user`` (e.g. ``user``
        is omitted), :meth:`get_permissions` is called directly without caching.

        :param entity: Entity to fetch
        :param user: User to fetch
        :param exp: Expiration time of the cache record and maximum time of already cached record
            (a falsy value accepts a cached record of any age)
        :param force: Whether to force refresh the cache (make API request)
        :return: :obj:`ChatPermissions` (a deep copy)
        """
        # Will be used to determine, which client caused logging messages
        # parsed via inspect.stack()
        _hikka_client_id_logging_tag = copy.copy(self.tg_id)  # skipcq

        entity = await self.get_entity(entity)
        user = await self.get_entity(user) if user else None

        if hashable(entity) and hashable(user):
            hashable_entity = entity
            hashable_user = user
        else:
            hashable_entity = self._hikka_key_from_attrs(
                entity, self._HIKKA_ENTITY_KEY_ATTRS
            )
            if hashable_entity is None:
                logger.debug(
                    "Can't parse hashable from entity %s, using legacy method",
                    entity,
                )
                return await self.get_permissions(entity, user)

            hashable_user = self._hikka_key_from_attrs(
                user, self._HIKKA_ENTITY_KEY_ATTRS
            )
            if hashable_user is None:
                logger.debug(
                    "Can't parse hashable from user %s, using legacy method",
                    user,
                )
                return await self.get_permissions(entity, user)

        # Marked IDs are kept verbatim as keys (see the note in `get_entity`).

        if (
            not force
            and hashable_entity
            and hashable_user
            and hashable_user in self._hikka_perms_cache.get(hashable_entity, {})
            and (
                not exp
                or self._hikka_perms_cache[hashable_entity][hashable_user].ts + exp
                > time.time()
            )
        ):
            logger.debug("Using cached perms %s (%s)", hashable_entity, hashable_user)
            return copy.deepcopy(
                self._hikka_perms_cache[hashable_entity][hashable_user].perms
            )

        resolved_perms = await self.get_permissions(entity, user)

        if resolved_perms:
            cache_record = CacheRecordPerms(
                hashable_entity,
                hashable_user,
                resolved_perms,
                exp,
            )
            self._hikka_perms_cache.setdefault(hashable_entity, {})[
                hashable_user
            ] = cache_record
            logger.debug("Saved hashable_entity %s perms to cache", hashable_entity)

            # Also index the record under every (entity alias, user alias)
            # pair, where aliases are the id and both username spellings.
            user_keys = []
            if getattr(user, "id", None):
                user_keys.append(user.id)
            if getattr(user, "username", None):
                user_keys.extend((f"@{user.username}", user.username))

            entity_keys = []
            if getattr(entity, "id", None):
                logger.debug("Saved resolved_entity id %s perms to cache", entity.id)
                entity_keys.append(entity.id)
            if getattr(entity, "username", None):
                logger.debug(
                    "Saved resolved_entity username @%s perms to cache",
                    entity.username,
                )
                entity_keys.extend((f"@{entity.username}", entity.username))

            for entity_key in entity_keys:
                for user_key in user_keys:
                    self._hikka_perms_cache.setdefault(entity_key, {})[
                        user_key
                    ] = cache_record

        return copy.deepcopy(resolved_perms)

    async def get_fullchannel(
        self,
        entity: EntityLike,
        exp: int = 300,
        force: bool = False,
    ) -> ChannelFull:
        """
        Gets the FullChannelRequest and cache it

        :param entity: Channel to fetch ChannelFull of
        :param exp: Expiration time of the cache record and maximum time of already cached record
        :param force: Whether to force refresh the cache (make API request)
        :return: The raw result of :class:`GetFullChannelRequest`. NOTE(review):
            in current layers this is presumably ``messages.ChatFull``, with the
            :obj:`ChannelFull` in ``full_chat``; confirm. A cache hit returns
            the cached object itself (not a copy), so do not mutate it.
        """
        if hashable(entity):
            hashable_entity = entity
        else:
            hashable_entity = self._hikka_key_from_attrs(
                entity, self._HIKKA_FULLCHANNEL_KEY_ATTRS
            )
            if hashable_entity is None:
                logger.debug(
                    "Can't parse hashable from entity %s, using legacy fullchannel"
                    " request",
                    entity,
                )
                return await self(GetFullChannelRequest(channel=entity))

        # Marked IDs are kept verbatim as keys (see the note in `get_entity`).

        cached = self._hikka_fullchannel_cache.get(hashable_entity)
        if (
            not force
            and cached
            and not cached.expired()
            and cached.ts + exp > time.time()
        ):
            return cached.full_channel

        result = await self(GetFullChannelRequest(channel=entity))
        self._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel(
            hashable_entity,
            result,
            exp,
        )
        return result

    async def get_fulluser(
        self,
        entity: EntityLike,
        exp: int = 300,
        force: bool = False,
    ) -> UserFull:
        """
        Gets the FullUserRequest and cache it

        :param entity: User to fetch UserFull of
        :param exp: Expiration time of the cache record and maximum time of already cached record
        :param force: Whether to force refresh the cache (make API request)
        :return: The raw result of :class:`GetFullUserRequest`. NOTE(review):
            in recent layers this is presumably ``users.UserFull``, with the
            :obj:`UserFull` in ``full_user``; confirm. A cache hit returns the
            cached object itself (not a copy), so do not mutate it.
        """
        if hashable(entity):
            hashable_entity = entity
        else:
            hashable_entity = self._hikka_key_from_attrs(
                entity, self._HIKKA_FULLUSER_KEY_ATTRS
            )
            if hashable_entity is None:
                logger.debug(
                    "Can't parse hashable from entity %s, using legacy fulluser"
                    " request",
                    entity,
                )
                return await self(GetFullUserRequest(entity))

        # Marked IDs are kept verbatim as keys (see the note in `get_entity`).

        cached = self._hikka_fulluser_cache.get(hashable_entity)
        if (
            not force
            and cached
            and not cached.expired()
            and cached.ts + exp > time.time()
        ):
            return cached.full_user

        result = await self(GetFullUserRequest(entity))
        self._hikka_fulluser_cache[hashable_entity] = CacheRecordFullUser(
            hashable_entity,
            result,
            exp,
        )
        return result

    @staticmethod
    def _hikka_called_from_external_module() -> bool:
        """
        Tell whether the current call stack passes through a non-core module.

        Walks frames outwards from here and returns ``True`` for the first one
        whose local ``self`` is a :class:`Module` instance whose ``__origin__``
        does not start with ``"<core"``.

        :return: ``True`` if such a frame exists, else ``False``
        """
        frame = inspect.currentframe()
        try:
            while frame is not None:
                try:
                    # Before Python 3.13 ``f_locals`` is a plain dict. From 3.13
                    # (PEP 667) function frames return a FrameLocalsProxy, which
                    # is *not* a dict, so an ``isinstance(..., dict)`` check would
                    # skip every module method. Plain item access works for both.
                    caller = frame.f_locals["self"]
                except (KeyError, TypeError):
                    caller = None

                if isinstance(caller, Module) and not getattr(
                    caller, "__origin__", ""
                ).startswith("<core"):
                    return True

                frame = frame.f_back
        finally:
            # Drop the frame reference to avoid a frame <-> local reference cycle
            del frame

        return False

    async def _call(
        self,
        sender: MTProtoSender,
        request: TLRequest,
        ordered: bool = False,
        flood_sleep_threshold: typing.Optional[int] = None,
    ):
        """
        Calls the given request and handles user-side forbidden constructors

        Requests whose ``CONSTRUCTOR_ID`` is forbidden are silently dropped when
        they originate from a non-core module. If nothing is left to send,
        ``None`` is returned without calling Telegram.

        :param sender: Sender to use
        :param request: Request to send (a single request or a list-like of them)
        :param ordered: Whether to send the request ordered
        :param flood_sleep_threshold: Flood sleep threshold
        :return: The result of the request
        """

        # ⚠️⚠️ WARNING! ⚠️⚠️
        # If you are a module developer, and you'll try to bypass this protection to
        # force user join your channel, you will be added to SCAM modules
        # list and you will be banned from Hikka federation.
        # Let USER decide, which channel he will follow. Do not be so petty
        # I hope, you understood me.
        # Thank you

        if not self.__forbidden_constructors:
            return await TelegramClient._call(
                self,
                sender,
                request,
                ordered,
                flood_sleep_threshold,
            )

        single_request = not is_list_like(request)
        requests = (request,) if single_request else request

        allowed_requests = []
        for item in requests:
            # Check the (cheap) constructor first; only then inspect the stack.
            if (
                item.CONSTRUCTOR_ID in self.__forbidden_constructors
                and self._hikka_called_from_external_module()
            ):
                logger.debug(
                    "🎉 I protected you from unintented %s (%s)!",
                    item.__class__.__name__,
                    item,
                )
                continue

            allowed_requests.append(item)

        if not allowed_requests:
            return None

        return await TelegramClient._call(
            self,
            sender,
            allowed_requests[0] if single_request else tuple(allowed_requests),
            ordered,
            flood_sleep_threshold,
        )

    def forbid_constructor(self, constructor: int):
        """
        Forbids the given constructor to be called

        :param constructor: Constructor id to forbid
        """
        # Stored as a de-duplicated list; order is not significant.
        self.__forbidden_constructors = list(
            {*self.__forbidden_constructors, constructor}
        )

    def forbid_constructors(self, constructors: list):
        """
        Forbids the given constructors to be called.
        All existing forbidden constructors will be removed

        :param constructors: Constructor ids to forbid
        """
        self.__forbidden_constructors = list(set(constructors))

    def _handle_update(
        self: "CustomTelegramClient",
        update: typing.Union[Updates, UpdatesCombined, UpdateShort],
    ):
        """
        Pass ``update`` to ``raw_updates_processor`` (if one is set), then to
        Telethon's regular update handling.
        """
        if self.raw_updates_processor is not None:
            self.raw_updates_processor(update)

        super()._handle_update(update)
|