# ÂŠī¸ Dan Gazizullin, 2021-2023 # This file is a part of Hikka Userbot # 🌐 https://github.com/hikariatama/Hikka # You can redistribute it and/or modify it under the terms of the GNU AGPLv3 # 🔑 https://www.gnu.org/licenses/agpl-3.0.html import asyncio import io import json import logging import random import time import typing from hikkatl.tl import functions from hikkatl.tl.tlobject import TLRequest from hikkatl.tl.types import Message from hikkatl.utils import is_list_like from .. import loader, utils from ..inline.types import InlineCall from ..web.debugger import WebDebugger logger = logging.getLogger(__name__) GROUPS = [ "auth", "account", "users", "contacts", "messages", "updates", "photos", "upload", "help", "channels", "bots", "payments", "stickers", "phone", "langpack", "folders", "stats", ] CONSTRUCTORS = { (lambda x: x[0].lower() + x[1:])( method.__class__.__name__.rsplit("Request", 1)[0] ): method.CONSTRUCTOR_ID for method in utils.array_sum( [ [ method for method in dir(getattr(functions, group)) if isinstance(method, TLRequest) ] for group in GROUPS ] ) } @loader.tds class APIRatelimiterMod(loader.Module): """Helps userbot avoid spamming Telegram API""" strings = {"name": "APILimiter"} def __init__(self): self._ratelimiter: typing.List[tuple] = [] self._suspend_until = 0 self._lock = False self.config = loader.ModuleConfig( loader.ConfigValue( "time_sample", 15, lambda: self.strings("_cfg_time_sample"), validator=loader.validators.Integer(minimum=1), ), loader.ConfigValue( "threshold", 100, lambda: self.strings("_cfg_threshold"), validator=loader.validators.Integer(minimum=10), ), loader.ConfigValue( "local_floodwait", 30, lambda: self.strings("_cfg_local_floodwait"), validator=loader.validators.Integer(minimum=10, maximum=3600), ), loader.ConfigValue( "forbidden_methods", ["joinChannel", "importChatInvite"], lambda: self.strings("_cfg_forbidden_methods"), validator=loader.validators.MultiChoice( [ "sendReaction", "joinChannel", "importChatInvite", ] ), on_change=lambda: self._client.forbid_constructors( map( lambda x: CONSTRUCTORS[x], self.config["forbidden_constructors"], ) ), ), ) async def client_ready(self): asyncio.ensure_future(self._install_protection()) async def _install_protection(self): await asyncio.sleep(30) # Restart lock if hasattr(self._client._call, "_old_call_rewritten"): raise loader.SelfUnload("Already installed") old_call = self._client._call async def new_call( sender: "MTProtoSender", # type: ignore # noqa: F821 request: TLRequest, ordered: bool = False, flood_sleep_threshold: int = None, ): await asyncio.sleep(random.randint(1, 5) / 100) req = (request,) if not is_list_like(request) else request for r in req: if ( time.perf_counter() > self._suspend_until and not self.get( "disable_protection", True, ) and ( r.__module__.rsplit(".", maxsplit=1)[1] in {"messages", "account", "channels"} ) ): request_name = type(r).__name__ self._ratelimiter += [(request_name, time.perf_counter())] self._ratelimiter = list( filter( lambda x: time.perf_counter() - x[1] < int(self.config["time_sample"]), self._ratelimiter, ) ) if ( len(self._ratelimiter) > int(self.config["threshold"]) and not self._lock ): self._lock = True report = io.BytesIO( json.dumps( self._ratelimiter, indent=4, ).encode() ) report.name = "local_fw_report.json" await self.inline.bot.send_document( self.tg_id, report, caption=self.inline.sanitise_text( self.strings("warning").format( self.config["local_floodwait"], prefix=utils.escape_html(self.get_prefix()), ) ), ) # It is intented to use time.sleep instead of asyncio.sleep time.sleep(int(self.config["local_floodwait"])) self._lock = False return await old_call(sender, request, ordered, flood_sleep_threshold) self._client._call = new_call self._client._old_call_rewritten = old_call self._client._call._hikka_overwritten = True logger.debug("Successfully installed ratelimiter") async def on_unload(self): if hasattr(self._client, "_old_call_rewritten"): self._client._call = self._client._old_call_rewritten delattr(self._client, "_old_call_rewritten") logger.debug("Successfully uninstalled ratelimiter") @loader.command() async def suspend_api_protect(self, message: Message): if not (args := utils.get_args_raw(message)) or not args.isdigit(): await utils.answer(message, self.strings("args_invalid")) return self._suspend_until = time.perf_counter() + int(args) await utils.answer(message, self.strings("suspended_for").format(args)) @loader.command() async def api_fw_protection(self, message: Message): await self.inline.form( message=message, text=self.strings("u_sure"), reply_markup=[ {"text": self.strings("btn_no"), "action": "close"}, {"text": self.strings("btn_yes"), "callback": self._finish}, ], ) @property def _debugger(self) -> WebDebugger: return logging.getLogger().handlers[0].web_debugger async def _show_pin(self, call: InlineCall): await call.answer(f"Werkzeug PIN: {self._debugger.pin}", show_alert=True) @loader.command() async def debugger(self, message: Message): if not self._debugger: await utils.answer(message, self.strings("debugger_disabled")) return await self.inline.form( message=message, text=self.strings("web_pin"), reply_markup=[ [ { "text": self.strings("web_pin_btn"), "callback": self._show_pin, } ], [ {"text": self.strings("proxied_url"), "url": self._debugger.url}, { "text": self.strings("local_url"), "url": f"http://127.0.0.1:{self._debugger.port}", }, ], ], ) async def _finish(self, call: InlineCall): state = self.get("disable_protection", True) self.set("disable_protection", not state) await call.edit(self.strings("on" if state else "off"))