- Add fullchannel caching (`client.get_fullchannel`)
pull/1/head
hikariatama 2022-08-09 21:28:02 +00:00
parent f9857b609d
commit a57f15307b
3 changed files with 109 additions and 17 deletions

View File

@ -15,6 +15,7 @@
- Add `utils.find_caller`
- Add possible cause of error in logs (module and method)
- Add `client.get_perms_cached` to cache native `client.get_permissions`
- Add `client.get_fullchannel` with cache
- Add `exp` and `force` params to `client.get_perms_cached` and `client.get_entity`
- Add `exp` cached values check in `client.get_perms_cached` and `client.get_entity`
- Change errors format in web to more human-readable

View File

@ -56,7 +56,7 @@ from . import database, loader, utils, heroku
from .dispatcher import CommandDispatcher
from .translations import Translator
from .version import __version__
from .tl_cache import install_entity_caching, install_perms_caching
from .tl_cache import install_entity_caching, install_perms_caching, install_fullchannel_caching
try:
from .web import core
@ -525,6 +525,7 @@ class Hikka:
install_entity_caching(client)
install_perms_caching(client)
install_fullchannel_caching(client)
self.clients += [client]
except sqlite3.OperationalError:

View File

@ -11,8 +11,11 @@ import time
import asyncio
import logging
from typing import Optional, Union
from xml.dom.minidom import Entity
from telethon.hints import EntityLike
from telethon import TelegramClient
from telethon.tl.functions.channels import GetFullChannelRequest
from telethon.tl.types import ChannelFull
logger = logging.getLogger(__name__)
@ -43,7 +46,7 @@ class CacheRecord:
return self._exp < time.time()
def __eq__(self, record: "CacheRecord"):
return hash(record._hashable_entity) == hash(self._hashable_entity)
return hash(record) == hash(self)
def __hash__(self):
return hash(self._hashable_entity)
@ -72,19 +75,45 @@ class CacheRecordPerms:
def expired(self):
return self._exp < time.time()
def __eq__(self, record: "CacheRecord"):
return hash((record._hashable_entity, record._hashable_user)) == hash(
(self._hashable_entity, self._hashable_user)
)
def __eq__(self, record: "CacheRecordPerms"):
return hash(record) == hash(self)
def __hash__(self):
return hash((self._hashable_entity, self._hashable_user))
def __str__(self):
return f"CacheRecord of {self.perms}"
return f"CacheRecordPerms of {self.perms}"
def __repr__(self):
return f"CacheRecord(perms={type(self.perms).__name__}(...), exp={self._exp})"
return (
f"CacheRecordPerms(perms={type(self.perms).__name__}(...), exp={self._exp})"
)
class CacheRecordFullChannel:
def __init__(self, channel_id: int, full_channel: ChannelFull, exp: int):
self.channel_id = channel_id
self.full_channel = full_channel
self._exp = round(time.time() + exp)
self.ts = time.time()
def expired(self):
return self._exp < time.time()
def __eq__(self, record: "CacheRecordFullChannel"):
return hash(record) == hash(self)
def __hash__(self):
return hash((self._hashable_entity, self._hashable_user))
def __str__(self):
return f"CacheRecordFullChannel of {self.channel_id}"
def __repr__(self):
return (
f"CacheRecordFullChannel(channel_id={self.channel_id}(...),"
f" exp={self._exp})"
)
def install_entity_caching(client: TelegramClient):
@ -93,15 +122,14 @@ def install_entity_caching(client: TelegramClient):
old = client.get_entity
async def new(
entity: EntityLike, exp: Optional[int] = 5 * 60, force: Optional[bool] = False
entity: EntityLike,
exp: Optional[int] = 5 * 60,
force: Optional[bool] = False,
):
# Will be used to determine, which client caused logging messages
# parsed via inspect.stack()
_hikka_client_id_logging_tag = copy.copy(client.tg_id) # skipcq
if force:
return await old(entity)
if not hashable(entity):
try:
hashable_entity = next(
@ -121,7 +149,8 @@ def install_entity_caching(client: TelegramClient):
hashable_entity = int(str(hashable_entity)[4:])
if (
hashable_entity
not force
and hashable_entity
and hashable_entity in client._hikka_entity_cache
and (
not exp
@ -187,9 +216,6 @@ def install_perms_caching(client: TelegramClient):
# parsed via inspect.stack()
_hikka_client_id_logging_tag = copy.copy(client.tg_id) # skipcq
if force:
return await old(entity, user)
entity = await client.get_entity(entity)
user = await client.get_entity(user) if user else None
@ -226,7 +252,8 @@ def install_perms_caching(client: TelegramClient):
hashable_user = int(str(hashable_user)[4:])
if (
hashable_entity
not force
and hashable_entity
and hashable_user
and hashable_user in client._hikka_perms_cache.get(hashable_entity, {})
and (
@ -295,3 +322,66 @@ def install_perms_caching(client: TelegramClient):
client.get_perms_cached = new
asyncio.ensure_future(cleaner(client))
logger.debug("Monkeypatched client with perms cacher")
def install_fullchannel_caching(client: TelegramClient):
client._hikka_fullchannel_cache = {}
async def get_fullchannel(
entity: EntityLike,
exp: Optional[int] = 300,
force: Optional[bool] = False,
) -> ChannelFull:
"""
Gets the FullChannelRequest and cache it
:param channel_id: Channel to fetch ChannelFull of
:param exp: Expiration time of the cache record and maximum time of already cached record
:param force: Whether to force refresh the cache (make API request)
:return: :obj:`FullChannel`
"""
if not hashable(entity):
try:
hashable_entity = next(
getattr(entity, attr)
for attr in {"channel_id", "chat_id", "id"}
if getattr(entity, attr, None)
)
except StopIteration:
logger.debug(
f"Can't parse hashable from {entity=}, using legacy fullchannel request"
)
return await client(GetFullChannelRequest(channel=entity))
else:
hashable_entity = entity
if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
hashable_entity = int(str(hashable_entity)[4:])
if (
not force
and client._hikka_fullchannel_cache.get(hashable_entity)
and not client._hikka_fullchannel_cache[hashable_entity].expired()
and client._hikka_fullchannel_cache[hashable_entity].ts + exp > time.time()
):
return client._hikka_fullchannel_cache[hashable_entity].full_channel
result = await client(GetFullChannelRequest(channel=entity))
client._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel(
hashable_entity,
result,
exp,
)
return result
async def cleaner(client: TelegramClient):
while True:
for channel_id, record in client._hikka_fullchannel_cache.copy().items():
if record.expired():
del client._hikka_fullchannel_cache[channel_id]
logger.debug(f"Cleaned outdated fullchannel cache {channel_id=}")
await asyncio.sleep(3)
client.get_fullchannel = get_fullchannel
asyncio.ensure_future(cleaner(client))
logger.debug("Monkeypatched client with fullchannel cacher")