diff --git a/CHANGELOG.md b/CHANGELOG.md index a4bd3fc..8ed7e09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ - Add `utils.find_caller` - Add possible cause of error in logs (module and method) - Add `client.get_perms_cached` to cache native `client.get_permissions` +- Add `client.get_fullchannel` with cache - Add `exp` and `force` params to `client.get_perms_cached` and `client.get_entity` - Add `exp` cached values check in `client.get_perms_cached` and `client.get_entity` - Change errors format in web to more human-readable diff --git a/hikka/main.py b/hikka/main.py index 6a9e62f..3ee19e7 100755 --- a/hikka/main.py +++ b/hikka/main.py @@ -56,7 +56,7 @@ from . import database, loader, utils, heroku from .dispatcher import CommandDispatcher from .translations import Translator from .version import __version__ -from .tl_cache import install_entity_caching, install_perms_caching +from .tl_cache import install_entity_caching, install_perms_caching, install_fullchannel_caching try: from .web import core @@ -525,6 +525,7 @@ class Hikka: install_entity_caching(client) install_perms_caching(client) + install_fullchannel_caching(client) self.clients += [client] except sqlite3.OperationalError: diff --git a/hikka/tl_cache.py b/hikka/tl_cache.py index f39f1e0..e5e3acb 100644 --- a/hikka/tl_cache.py +++ b/hikka/tl_cache.py @@ -11,8 +11,11 @@ import time import asyncio import logging from typing import Optional, Union +from xml.dom.minidom import Entity from telethon.hints import EntityLike from telethon import TelegramClient +from telethon.tl.functions.channels import GetFullChannelRequest +from telethon.tl.types import ChannelFull logger = logging.getLogger(__name__) @@ -43,7 +46,7 @@ class CacheRecord: return self._exp < time.time() def __eq__(self, record: "CacheRecord"): - return hash(record._hashable_entity) == hash(self._hashable_entity) + return hash(record) == hash(self) def __hash__(self): return hash(self._hashable_entity) @@ -72,19 +75,45 @@ class CacheRecordPerms: def expired(self): return self._exp < time.time() - def __eq__(self, record: "CacheRecord"): - return hash((record._hashable_entity, record._hashable_user)) == hash( - (self._hashable_entity, self._hashable_user) - ) + def __eq__(self, record: "CacheRecordPerms"): + return hash(record) == hash(self) def __hash__(self): return hash((self._hashable_entity, self._hashable_user)) def __str__(self): - return f"CacheRecord of {self.perms}" + return f"CacheRecordPerms of {self.perms}" def __repr__(self): - return f"CacheRecord(perms={type(self.perms).__name__}(...), exp={self._exp})" + return ( + f"CacheRecordPerms(perms={type(self.perms).__name__}(...), exp={self._exp})" + ) + + +class CacheRecordFullChannel: + def __init__(self, channel_id: int, full_channel: ChannelFull, exp: int): + self.channel_id = channel_id + self.full_channel = full_channel + self._exp = round(time.time() + exp) + self.ts = time.time() + + def expired(self): + return self._exp < time.time() + + def __eq__(self, record: "CacheRecordFullChannel"): + return hash(record) == hash(self) + + def __hash__(self): + return hash((self._hashable_entity, self._hashable_user)) + + def __str__(self): + return f"CacheRecordFullChannel of {self.channel_id}" + + def __repr__(self): + return ( + f"CacheRecordFullChannel(channel_id={self.channel_id}(...)," + f" exp={self._exp})" + ) def install_entity_caching(client: TelegramClient): @@ -93,15 +122,14 @@ def install_entity_caching(client: TelegramClient): old = client.get_entity async def new( - entity: EntityLike, exp: Optional[int] = 5 * 60, force: Optional[bool] = False + entity: EntityLike, + exp: Optional[int] = 5 * 60, + force: Optional[bool] = False, ): # Will be used to determine, which client caused logging messages # parsed via inspect.stack() _hikka_client_id_logging_tag = copy.copy(client.tg_id) # skipcq - if force: - return await old(entity) - if not hashable(entity): try: hashable_entity = next( @@ -121,7 +149,8 @@ def install_entity_caching(client: TelegramClient): hashable_entity = int(str(hashable_entity)[4:]) if ( - hashable_entity + not force + and hashable_entity and hashable_entity in client._hikka_entity_cache and ( not exp @@ -187,9 +216,6 @@ def install_perms_caching(client: TelegramClient): # parsed via inspect.stack() _hikka_client_id_logging_tag = copy.copy(client.tg_id) # skipcq - if force: - return await old(entity, user) - entity = await client.get_entity(entity) user = await client.get_entity(user) if user else None @@ -226,7 +252,8 @@ def install_perms_caching(client: TelegramClient): hashable_user = int(str(hashable_user)[4:]) if ( - hashable_entity + not force + and hashable_entity and hashable_user and hashable_user in client._hikka_perms_cache.get(hashable_entity, {}) and ( @@ -295,3 +322,66 @@ def install_perms_caching(client: TelegramClient): client.get_perms_cached = new asyncio.ensure_future(cleaner(client)) logger.debug("Monkeypatched client with perms cacher") + + +def install_fullchannel_caching(client: TelegramClient): + client._hikka_fullchannel_cache = {} + + async def get_fullchannel( + entity: EntityLike, + exp: Optional[int] = 300, + force: Optional[bool] = False, + ) -> ChannelFull: + """ + Gets the FullChannelRequest and cache it + :param channel_id: Channel to fetch ChannelFull of + :param exp: Expiration time of the cache record and maximum time of already cached record + :param force: Whether to force refresh the cache (make API request) + :return: :obj:`FullChannel` + """ + if not hashable(entity): + try: + hashable_entity = next( + getattr(entity, attr) + for attr in {"channel_id", "chat_id", "id"} + if getattr(entity, attr, None) + ) + except StopIteration: + logger.debug( + f"Can't parse hashable from {entity=}, using legacy fullchannel request" + ) + return await client(GetFullChannelRequest(channel=entity)) + else: + hashable_entity = entity + + if str(hashable_entity).isdigit() and int(hashable_entity) < 0: + hashable_entity = int(str(hashable_entity)[4:]) + + if ( + not force + and client._hikka_fullchannel_cache.get(hashable_entity) + and not client._hikka_fullchannel_cache[hashable_entity].expired() + and client._hikka_fullchannel_cache[hashable_entity].ts + exp > time.time() + ): + return client._hikka_fullchannel_cache[hashable_entity].full_channel + + result = await client(GetFullChannelRequest(channel=entity)) + client._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel( + hashable_entity, + result, + exp, + ) + return result + + async def cleaner(client: TelegramClient): + while True: + for channel_id, record in client._hikka_fullchannel_cache.copy().items(): + if record.expired(): + del client._hikka_fullchannel_cache[channel_id] + logger.debug(f"Cleaned outdated fullchannel cache {channel_id=}") + + await asyncio.sleep(3) + + client.get_fullchannel = get_fullchannel + asyncio.ensure_future(cleaner(client)) + logger.debug("Monkeypatched client with fullchannel cacher")