mirror of https://github.com/coddrago/Heroku
parent
f9857b609d
commit
a57f15307b
|
@ -15,6 +15,7 @@
|
|||
- Add `utils.find_caller`
|
||||
- Add possible cause of error in logs (module and method)
|
||||
- Add `client.get_perms_cached` to cache native `client.get_permissions`
|
||||
- Add `client.get_fullchannel` with cache
|
||||
- Add `exp` and `force` params to `client.get_perms_cached` and `client.get_entity`
|
||||
- Add `exp` cached values check in `client.get_perms_cached` and `client.get_entity`
|
||||
- Change errors format in web to more human-readable
|
||||
|
|
|
@ -56,7 +56,7 @@ from . import database, loader, utils, heroku
|
|||
from .dispatcher import CommandDispatcher
|
||||
from .translations import Translator
|
||||
from .version import __version__
|
||||
from .tl_cache import install_entity_caching, install_perms_caching
|
||||
from .tl_cache import install_entity_caching, install_perms_caching, install_fullchannel_caching
|
||||
|
||||
try:
|
||||
from .web import core
|
||||
|
@ -525,6 +525,7 @@ class Hikka:
|
|||
|
||||
install_entity_caching(client)
|
||||
install_perms_caching(client)
|
||||
install_fullchannel_caching(client)
|
||||
|
||||
self.clients += [client]
|
||||
except sqlite3.OperationalError:
|
||||
|
|
|
@ -11,8 +11,11 @@ import time
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
from xml.dom.minidom import Entity
|
||||
from telethon.hints import EntityLike
|
||||
from telethon import TelegramClient
|
||||
from telethon.tl.functions.channels import GetFullChannelRequest
|
||||
from telethon.tl.types import ChannelFull
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -43,7 +46,7 @@ class CacheRecord:
|
|||
return self._exp < time.time()
|
||||
|
||||
def __eq__(self, record: "CacheRecord"):
|
||||
return hash(record._hashable_entity) == hash(self._hashable_entity)
|
||||
return hash(record) == hash(self)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self._hashable_entity)
|
||||
|
@ -72,19 +75,45 @@ class CacheRecordPerms:
|
|||
def expired(self):
|
||||
return self._exp < time.time()
|
||||
|
||||
def __eq__(self, record: "CacheRecord"):
|
||||
return hash((record._hashable_entity, record._hashable_user)) == hash(
|
||||
(self._hashable_entity, self._hashable_user)
|
||||
)
|
||||
def __eq__(self, record: "CacheRecordPerms"):
|
||||
return hash(record) == hash(self)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self._hashable_entity, self._hashable_user))
|
||||
|
||||
def __str__(self):
|
||||
return f"CacheRecord of {self.perms}"
|
||||
return f"CacheRecordPerms of {self.perms}"
|
||||
|
||||
def __repr__(self):
|
||||
return f"CacheRecord(perms={type(self.perms).__name__}(...), exp={self._exp})"
|
||||
return (
|
||||
f"CacheRecordPerms(perms={type(self.perms).__name__}(...), exp={self._exp})"
|
||||
)
|
||||
|
||||
|
||||
class CacheRecordFullChannel:
|
||||
def __init__(self, channel_id: int, full_channel: ChannelFull, exp: int):
|
||||
self.channel_id = channel_id
|
||||
self.full_channel = full_channel
|
||||
self._exp = round(time.time() + exp)
|
||||
self.ts = time.time()
|
||||
|
||||
def expired(self):
|
||||
return self._exp < time.time()
|
||||
|
||||
def __eq__(self, record: "CacheRecordFullChannel"):
|
||||
return hash(record) == hash(self)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self._hashable_entity, self._hashable_user))
|
||||
|
||||
def __str__(self):
|
||||
return f"CacheRecordFullChannel of {self.channel_id}"
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"CacheRecordFullChannel(channel_id={self.channel_id}(...),"
|
||||
f" exp={self._exp})"
|
||||
)
|
||||
|
||||
|
||||
def install_entity_caching(client: TelegramClient):
|
||||
|
@ -93,15 +122,14 @@ def install_entity_caching(client: TelegramClient):
|
|||
old = client.get_entity
|
||||
|
||||
async def new(
|
||||
entity: EntityLike, exp: Optional[int] = 5 * 60, force: Optional[bool] = False
|
||||
entity: EntityLike,
|
||||
exp: Optional[int] = 5 * 60,
|
||||
force: Optional[bool] = False,
|
||||
):
|
||||
# Will be used to determine, which client caused logging messages
|
||||
# parsed via inspect.stack()
|
||||
_hikka_client_id_logging_tag = copy.copy(client.tg_id) # skipcq
|
||||
|
||||
if force:
|
||||
return await old(entity)
|
||||
|
||||
if not hashable(entity):
|
||||
try:
|
||||
hashable_entity = next(
|
||||
|
@ -121,7 +149,8 @@ def install_entity_caching(client: TelegramClient):
|
|||
hashable_entity = int(str(hashable_entity)[4:])
|
||||
|
||||
if (
|
||||
hashable_entity
|
||||
not force
|
||||
and hashable_entity
|
||||
and hashable_entity in client._hikka_entity_cache
|
||||
and (
|
||||
not exp
|
||||
|
@ -187,9 +216,6 @@ def install_perms_caching(client: TelegramClient):
|
|||
# parsed via inspect.stack()
|
||||
_hikka_client_id_logging_tag = copy.copy(client.tg_id) # skipcq
|
||||
|
||||
if force:
|
||||
return await old(entity, user)
|
||||
|
||||
entity = await client.get_entity(entity)
|
||||
user = await client.get_entity(user) if user else None
|
||||
|
||||
|
@ -226,7 +252,8 @@ def install_perms_caching(client: TelegramClient):
|
|||
hashable_user = int(str(hashable_user)[4:])
|
||||
|
||||
if (
|
||||
hashable_entity
|
||||
not force
|
||||
and hashable_entity
|
||||
and hashable_user
|
||||
and hashable_user in client._hikka_perms_cache.get(hashable_entity, {})
|
||||
and (
|
||||
|
@ -295,3 +322,66 @@ def install_perms_caching(client: TelegramClient):
|
|||
client.get_perms_cached = new
|
||||
asyncio.ensure_future(cleaner(client))
|
||||
logger.debug("Monkeypatched client with perms cacher")
|
||||
|
||||
|
||||
def install_fullchannel_caching(client: TelegramClient):
|
||||
client._hikka_fullchannel_cache = {}
|
||||
|
||||
async def get_fullchannel(
|
||||
entity: EntityLike,
|
||||
exp: Optional[int] = 300,
|
||||
force: Optional[bool] = False,
|
||||
) -> ChannelFull:
|
||||
"""
|
||||
Gets the FullChannelRequest and cache it
|
||||
:param channel_id: Channel to fetch ChannelFull of
|
||||
:param exp: Expiration time of the cache record and maximum time of already cached record
|
||||
:param force: Whether to force refresh the cache (make API request)
|
||||
:return: :obj:`FullChannel`
|
||||
"""
|
||||
if not hashable(entity):
|
||||
try:
|
||||
hashable_entity = next(
|
||||
getattr(entity, attr)
|
||||
for attr in {"channel_id", "chat_id", "id"}
|
||||
if getattr(entity, attr, None)
|
||||
)
|
||||
except StopIteration:
|
||||
logger.debug(
|
||||
f"Can't parse hashable from {entity=}, using legacy fullchannel request"
|
||||
)
|
||||
return await client(GetFullChannelRequest(channel=entity))
|
||||
else:
|
||||
hashable_entity = entity
|
||||
|
||||
if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
|
||||
hashable_entity = int(str(hashable_entity)[4:])
|
||||
|
||||
if (
|
||||
not force
|
||||
and client._hikka_fullchannel_cache.get(hashable_entity)
|
||||
and not client._hikka_fullchannel_cache[hashable_entity].expired()
|
||||
and client._hikka_fullchannel_cache[hashable_entity].ts + exp > time.time()
|
||||
):
|
||||
return client._hikka_fullchannel_cache[hashable_entity].full_channel
|
||||
|
||||
result = await client(GetFullChannelRequest(channel=entity))
|
||||
client._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel(
|
||||
hashable_entity,
|
||||
result,
|
||||
exp,
|
||||
)
|
||||
return result
|
||||
|
||||
async def cleaner(client: TelegramClient):
|
||||
while True:
|
||||
for channel_id, record in client._hikka_fullchannel_cache.copy().items():
|
||||
if record.expired():
|
||||
del client._hikka_fullchannel_cache[channel_id]
|
||||
logger.debug(f"Cleaned outdated fullchannel cache {channel_id=}")
|
||||
|
||||
await asyncio.sleep(3)
|
||||
|
||||
client.get_fullchannel = get_fullchannel
|
||||
asyncio.ensure_future(cleaner(client))
|
||||
logger.debug("Monkeypatched client with fullchannel cacher")
|
||||
|
|
Loading…
Reference in New Issue