mirror of https://github.com/coddrago/Heroku
476 lines
16 KiB
Python
476 lines
16 KiB
Python
# █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
|
|
# █▀█ █ █ █ █▀█ █▀▄ █
|
|
# © Copyright 2022
|
|
# https://t.me/hikariatama
|
|
#
|
|
# 🔒 Licensed under the GNU AGPLv3
|
|
# 🌐 https://www.gnu.org/licenses/agpl-3.0.html
|
|
|
|
import copy
|
|
import time
|
|
import asyncio
|
|
import logging
|
|
from typing import Optional, Union
|
|
from telethon.hints import EntityLike
|
|
from telethon import TelegramClient
|
|
from telethon.tl.functions.channels import GetFullChannelRequest
|
|
from telethon.tl.functions.users import GetFullUserRequest
|
|
from telethon.tl.types import ChannelFull, UserFull
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def hashable(value):
|
|
"""Determine whether `value` can be hashed."""
|
|
try:
|
|
hash(value)
|
|
except TypeError:
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
class CacheRecord:
|
|
def __init__(
|
|
self,
|
|
hashable_entity: "Hashable", # type: ignore
|
|
resolved_entity: EntityLike,
|
|
exp: int,
|
|
):
|
|
self.entity = copy.deepcopy(resolved_entity)
|
|
self._hashable_entity = copy.deepcopy(hashable_entity)
|
|
self._exp = round(time.time() + exp)
|
|
self.ts = time.time()
|
|
|
|
def expired(self):
|
|
return self._exp < time.time()
|
|
|
|
def __eq__(self, record: "CacheRecord"):
|
|
return hash(record) == hash(self)
|
|
|
|
def __hash__(self):
|
|
return hash(self._hashable_entity)
|
|
|
|
def __str__(self):
|
|
return f"CacheRecord of {self.entity}"
|
|
|
|
def __repr__(self):
|
|
return f"CacheRecord(entity={type(self.entity).__name__}(...), exp={self._exp})"
|
|
|
|
|
|
class CacheRecordPerms:
|
|
def __init__(
|
|
self,
|
|
hashable_entity: "Hashable", # type: ignore
|
|
hashable_user: "Hashable", # type: ignore
|
|
resolved_perms: EntityLike,
|
|
exp: int,
|
|
):
|
|
self.perms = copy.deepcopy(resolved_perms)
|
|
self._hashable_entity = copy.deepcopy(hashable_entity)
|
|
self._hashable_user = copy.deepcopy(hashable_user)
|
|
self._exp = round(time.time() + exp)
|
|
self.ts = time.time()
|
|
|
|
def expired(self):
|
|
return self._exp < time.time()
|
|
|
|
def __eq__(self, record: "CacheRecordPerms"):
|
|
return hash(record) == hash(self)
|
|
|
|
def __hash__(self):
|
|
return hash((self._hashable_entity, self._hashable_user))
|
|
|
|
def __str__(self):
|
|
return f"CacheRecordPerms of {self.perms}"
|
|
|
|
def __repr__(self):
|
|
return (
|
|
f"CacheRecordPerms(perms={type(self.perms).__name__}(...), exp={self._exp})"
|
|
)
|
|
|
|
|
|
class CacheRecordFullChannel:
|
|
def __init__(self, channel_id: int, full_channel: ChannelFull, exp: int):
|
|
self.channel_id = channel_id
|
|
self.full_channel = full_channel
|
|
self._exp = round(time.time() + exp)
|
|
self.ts = time.time()
|
|
|
|
def expired(self):
|
|
return self._exp < time.time()
|
|
|
|
def __eq__(self, record: "CacheRecordFullChannel"):
|
|
return hash(record) == hash(self)
|
|
|
|
def __hash__(self):
|
|
return hash((self._hashable_entity, self._hashable_user))
|
|
|
|
def __str__(self):
|
|
return f"CacheRecordFullChannel of {self.channel_id}"
|
|
|
|
def __repr__(self):
|
|
return (
|
|
f"CacheRecordFullChannel(channel_id={self.channel_id}(...),"
|
|
f" exp={self._exp})"
|
|
)
|
|
|
|
|
|
class CacheRecordFullUser:
|
|
def __init__(self, user_id: int, full_user: UserFull, exp: int):
|
|
self.user_id = user_id
|
|
self.full_user = full_user
|
|
self._exp = round(time.time() + exp)
|
|
self.ts = time.time()
|
|
|
|
def expired(self):
|
|
return self._exp < time.time()
|
|
|
|
def __eq__(self, record: "CacheRecordFullUser"):
|
|
return hash(record) == hash(self)
|
|
|
|
def __hash__(self):
|
|
return hash((self._hashable_entity, self._hashable_user))
|
|
|
|
def __str__(self):
|
|
return f"CacheRecordFullUser of {self.user_id}"
|
|
|
|
def __repr__(self):
|
|
return f"CacheRecordFullUser(channel_id={self.user_id}(...), exp={self._exp})"
|
|
|
|
|
|
def install_entity_caching(client: TelegramClient):
|
|
client._hikka_entity_cache = {}
|
|
|
|
old = client.get_entity
|
|
|
|
async def new(
|
|
entity: EntityLike,
|
|
exp: Optional[int] = 5 * 60,
|
|
force: Optional[bool] = False,
|
|
):
|
|
# Will be used to determine, which client caused logging messages
|
|
# parsed via inspect.stack()
|
|
_hikka_client_id_logging_tag = copy.copy(client.tg_id) # skipcq
|
|
|
|
if not hashable(entity):
|
|
try:
|
|
hashable_entity = next(
|
|
getattr(entity, attr)
|
|
for attr in {"user_id", "channel_id", "chat_id", "id"}
|
|
if getattr(entity, attr, None)
|
|
)
|
|
except StopIteration:
|
|
logger.debug(
|
|
f"Can't parse hashable from {entity=}, using legacy resolve"
|
|
)
|
|
return await old(entity)
|
|
else:
|
|
hashable_entity = entity
|
|
|
|
if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
|
|
hashable_entity = int(str(hashable_entity)[4:])
|
|
|
|
if (
|
|
not force
|
|
and hashable_entity
|
|
and hashable_entity in client._hikka_entity_cache
|
|
and (
|
|
not exp
|
|
or client._hikka_entity_cache[hashable_entity].ts + exp > time.time()
|
|
)
|
|
):
|
|
logger.debug(
|
|
"Using cached entity"
|
|
f" {entity} ({type(client._hikka_entity_cache[hashable_entity].entity).__name__})"
|
|
)
|
|
return copy.deepcopy(client._hikka_entity_cache[hashable_entity].entity)
|
|
|
|
resolved_entity = await old(entity)
|
|
|
|
if resolved_entity:
|
|
cache_record = CacheRecord(hashable_entity, resolved_entity, exp)
|
|
client._hikka_entity_cache[hashable_entity] = cache_record
|
|
logger.debug(f"Saved hashable_entity {hashable_entity} to cache")
|
|
|
|
if getattr(resolved_entity, "id", None):
|
|
logger.debug(f"Saved resolved_entity id {resolved_entity.id} to cache")
|
|
client._hikka_entity_cache[resolved_entity.id] = cache_record
|
|
|
|
if getattr(resolved_entity, "username", None):
|
|
logger.debug(
|
|
f"Saved resolved_entity username @{resolved_entity.username} to"
|
|
" cache"
|
|
)
|
|
client._hikka_entity_cache[
|
|
f"@{resolved_entity.username}"
|
|
] = cache_record
|
|
client._hikka_entity_cache[resolved_entity.username] = cache_record
|
|
|
|
return copy.deepcopy(resolved_entity)
|
|
|
|
async def cleaner(client: TelegramClient):
|
|
while True:
|
|
for record, record_data in client._hikka_entity_cache.copy().items():
|
|
if record_data.expired():
|
|
del client._hikka_entity_cache[record]
|
|
logger.debug(f"Cleaned outdated cache {record=}")
|
|
|
|
await asyncio.sleep(3)
|
|
|
|
client.get_entity = new
|
|
client.force_get_entity = old
|
|
asyncio.ensure_future(cleaner(client))
|
|
logger.debug("Monkeypatched client with entity cacher")
|
|
|
|
|
|
def install_perms_caching(client: TelegramClient):
|
|
client._hikka_perms_cache = {}
|
|
|
|
old = client.get_permissions
|
|
|
|
async def new(
|
|
entity: EntityLike,
|
|
user: Optional[EntityLike] = None,
|
|
exp: Optional[int] = 5 * 60,
|
|
force: Optional[bool] = False,
|
|
):
|
|
# Will be used to determine, which client caused logging messages
|
|
# parsed via inspect.stack()
|
|
_hikka_client_id_logging_tag = copy.copy(client.tg_id) # skipcq
|
|
|
|
entity = await client.get_entity(entity)
|
|
user = await client.get_entity(user) if user else None
|
|
|
|
if not hashable(entity) or not hashable(user):
|
|
try:
|
|
hashable_entity = next(
|
|
getattr(entity, attr)
|
|
for attr in {"user_id", "channel_id", "chat_id", "id"}
|
|
if getattr(entity, attr, None)
|
|
)
|
|
except StopIteration:
|
|
logger.debug(
|
|
f"Can't parse hashable from {entity=}, using legacy method"
|
|
)
|
|
return await old(entity, user)
|
|
|
|
try:
|
|
hashable_user = next(
|
|
getattr(user, attr)
|
|
for attr in {"user_id", "channel_id", "chat_id", "id"}
|
|
if getattr(user, attr, None)
|
|
)
|
|
except StopIteration:
|
|
logger.debug(f"Can't parse hashable from {user=}, using legacy method")
|
|
return await old(entity, user)
|
|
else:
|
|
hashable_entity = entity
|
|
hashable_user = user
|
|
|
|
if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
|
|
hashable_entity = int(str(hashable_entity)[4:])
|
|
|
|
if str(hashable_user).isdigit() and int(hashable_user) < 0:
|
|
hashable_user = int(str(hashable_user)[4:])
|
|
|
|
if (
|
|
not force
|
|
and hashable_entity
|
|
and hashable_user
|
|
and hashable_user in client._hikka_perms_cache.get(hashable_entity, {})
|
|
and (
|
|
not exp
|
|
or client._hikka_perms_cache[hashable_entity][hashable_user].ts + exp
|
|
> time.time()
|
|
)
|
|
):
|
|
logger.debug(f"Using cached perms {hashable_entity} ({hashable_user})")
|
|
return copy.deepcopy(
|
|
client._hikka_perms_cache[hashable_entity][hashable_user].perms
|
|
)
|
|
|
|
resolved_perms = await old(entity, user)
|
|
|
|
if resolved_perms:
|
|
cache_record = CacheRecordPerms(
|
|
hashable_entity,
|
|
hashable_user,
|
|
resolved_perms,
|
|
exp,
|
|
)
|
|
client._hikka_perms_cache.setdefault(hashable_entity, {})[
|
|
hashable_user
|
|
] = cache_record
|
|
logger.debug(f"Saved hashable_entity {hashable_entity} perms to cache")
|
|
|
|
def save_user(key: Union[str, int]):
|
|
nonlocal client, cache_record, user, hashable_user
|
|
if getattr(user, "id", None):
|
|
client._hikka_perms_cache.setdefault(key, {})[
|
|
user.id
|
|
] = cache_record
|
|
|
|
if getattr(user, "username", None):
|
|
client._hikka_perms_cache.setdefault(key, {})[
|
|
f"@{user.username}"
|
|
] = cache_record
|
|
client._hikka_perms_cache.setdefault(key, {})[
|
|
user.username
|
|
] = cache_record
|
|
|
|
if getattr(entity, "id", None):
|
|
logger.debug(f"Saved resolved_entity id {entity.id} perms to cache")
|
|
save_user(entity.id)
|
|
|
|
if getattr(entity, "username", None):
|
|
logger.debug(
|
|
f"Saved resolved_entity username @{entity.username} perms to cache"
|
|
)
|
|
save_user(f"@{entity.username}")
|
|
save_user(entity.username)
|
|
|
|
return copy.deepcopy(resolved_perms)
|
|
|
|
async def cleaner(client: TelegramClient):
|
|
while True:
|
|
for chat, chat_data in client._hikka_perms_cache.copy().items():
|
|
for user, user_data in chat_data.copy().items():
|
|
if user_data.expired():
|
|
del client._hikka_perms_cache[chat][user]
|
|
logger.debug(f"Cleaned outdated perms cache {chat=} {user=}")
|
|
|
|
await asyncio.sleep(3)
|
|
|
|
client.get_perms_cached = new
|
|
asyncio.ensure_future(cleaner(client))
|
|
logger.debug("Monkeypatched client with perms cacher")
|
|
|
|
|
|
def install_fullchannel_caching(client: TelegramClient):
|
|
client._hikka_fullchannel_cache = {}
|
|
|
|
async def get_fullchannel(
|
|
entity: EntityLike,
|
|
exp: Optional[int] = 300,
|
|
force: Optional[bool] = False,
|
|
) -> ChannelFull:
|
|
"""
|
|
Gets the FullChannelRequest and cache it
|
|
:param entity: Channel to fetch ChannelFull of
|
|
:param exp: Expiration time of the cache record and maximum time of already cached record
|
|
:param force: Whether to force refresh the cache (make API request)
|
|
:return: :obj:`ChannelFull`
|
|
"""
|
|
if not hashable(entity):
|
|
try:
|
|
hashable_entity = next(
|
|
getattr(entity, attr)
|
|
for attr in {"channel_id", "chat_id", "id"}
|
|
if getattr(entity, attr, None)
|
|
)
|
|
except StopIteration:
|
|
logger.debug(
|
|
f"Can't parse hashable from {entity=}, using legacy fullchannel"
|
|
" request"
|
|
)
|
|
return await client(GetFullChannelRequest(channel=entity))
|
|
else:
|
|
hashable_entity = entity
|
|
|
|
if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
|
|
hashable_entity = int(str(hashable_entity)[4:])
|
|
|
|
if (
|
|
not force
|
|
and client._hikka_fullchannel_cache.get(hashable_entity)
|
|
and not client._hikka_fullchannel_cache[hashable_entity].expired()
|
|
and client._hikka_fullchannel_cache[hashable_entity].ts + exp > time.time()
|
|
):
|
|
return client._hikka_fullchannel_cache[hashable_entity].full_channel
|
|
|
|
result = await client(GetFullChannelRequest(channel=entity))
|
|
client._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel(
|
|
hashable_entity,
|
|
result,
|
|
exp,
|
|
)
|
|
return result
|
|
|
|
async def cleaner(client: TelegramClient):
|
|
while True:
|
|
for channel_id, record in client._hikka_fullchannel_cache.copy().items():
|
|
if record.expired():
|
|
del client._hikka_fullchannel_cache[channel_id]
|
|
logger.debug(f"Cleaned outdated fullchannel cache {channel_id=}")
|
|
|
|
await asyncio.sleep(3)
|
|
|
|
client.get_fullchannel = get_fullchannel
|
|
asyncio.ensure_future(cleaner(client))
|
|
logger.debug("Monkeypatched client with fullchannel cacher")
|
|
|
|
|
|
def install_fulluser_caching(client: TelegramClient):
|
|
client._hikka_fulluser_cache = {}
|
|
|
|
async def get_fulluser(
|
|
entity: EntityLike,
|
|
exp: Optional[int] = 300,
|
|
force: Optional[bool] = False,
|
|
) -> ChannelFull:
|
|
"""
|
|
Gets the FullUserRequest and cache it
|
|
:param entity: User to fetch UserFull of
|
|
:param exp: Expiration time of the cache record and maximum time of already cached record
|
|
:param force: Whether to force refresh the cache (make API request)
|
|
:return: :obj:`UserFull`
|
|
"""
|
|
if not hashable(entity):
|
|
try:
|
|
hashable_entity = next(
|
|
getattr(entity, attr)
|
|
for attr in {"user_id", "chat_id", "id"}
|
|
if getattr(entity, attr, None)
|
|
)
|
|
except StopIteration:
|
|
logger.debug(
|
|
f"Can't parse hashable from {entity=}, using legacy fulluser"
|
|
" request"
|
|
)
|
|
return await client(GetFullUserRequest(entity))
|
|
else:
|
|
hashable_entity = entity
|
|
|
|
if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
|
|
hashable_entity = int(str(hashable_entity)[4:])
|
|
|
|
if (
|
|
not force
|
|
and client._hikka_fulluser_cache.get(hashable_entity)
|
|
and not client._hikka_fulluser_cache[hashable_entity].expired()
|
|
and client._hikka_fulluser_cache[hashable_entity].ts + exp > time.time()
|
|
):
|
|
return client._hikka_fulluser_cache[hashable_entity].full_user
|
|
|
|
result = await client(GetFullUserRequest(entity))
|
|
client._hikka_fulluser_cache[hashable_entity] = CacheRecordFullUser(
|
|
hashable_entity,
|
|
result,
|
|
exp,
|
|
)
|
|
return result
|
|
|
|
async def cleaner(client: TelegramClient):
|
|
while True:
|
|
for user_id, record in client._hikka_fulluser_cache.copy().items():
|
|
if record.expired():
|
|
del client._hikka_fulluser_cache[user_id]
|
|
logger.debug(f"Cleaned outdated fulluser cache {user_id=}")
|
|
|
|
await asyncio.sleep(3)
|
|
|
|
client.get_fulluser = get_fulluser
|
|
asyncio.ensure_future(cleaner(client))
|
|
logger.debug("Monkeypatched client with fulluser cacher")
|