Heroku/heroku/modules/api_protection.py

249 lines
8.5 KiB
Python

# ©️ Dan Gazizullin, 2021-2023
# This file is a part of Hikka Userbot
# 🌐 https://github.com/hikariatama/Hikka
# You can redistribute it and/or modify it under the terms of the GNU AGPLv3
# 🔑 https://www.gnu.org/licenses/agpl-3.0.html
# ©️ Codrago, 2024-2025
# This file is a part of Heroku Userbot
# 🌐 https://github.com/coddrago/Heroku
# You can redistribute it and/or modify it under the terms of the GNU AGPLv3
# 🔑 https://www.gnu.org/licenses/agpl-3.0.html
import asyncio
import io
import json
import logging
import random
import time
import typing
from herokutl.tl import functions
from herokutl.tl.tlobject import TLRequest
from herokutl.tl.types import Message
from herokutl.utils import is_list_like
from .. import loader, utils
from ..inline.types import InlineCall
from ..web.debugger import WebDebugger
logger = logging.getLogger(__name__)

# Names of the request namespaces looked up on ``functions`` when the
# CONSTRUCTORS table below is built. Order is preserved from the original
# listing.
GROUPS = [
    "auth", "account", "users", "contacts",
    "messages", "updates", "photos", "upload",
    "help", "channels", "bots", "payments",
    "stickers", "phone", "langpack", "folders",
    "stats",
]
# Lookup table: request name in lower camel case with the "Request" suffix
# stripped (e.g. "joinChannel") -> CONSTRUCTOR_ID of the request class.
# Only classes that list TLRequest as a direct base and expose a
# CONSTRUCTOR_ID are included. Keys carry no group prefix, so if two groups
# define the same request name the group listed later in GROUPS wins.
CONSTRUCTORS = {
    f"{entity_name[0].lower()}{entity_name[1:]}".rsplit("Request", 1)[0]: cur_entity.CONSTRUCTOR_ID
    for namespace in (getattr(functions, group) for group in GROUPS)
    for entity_name in dir(namespace)
    if hasattr((cur_entity := getattr(namespace, entity_name)), "__bases__")
    and TLRequest in cur_entity.__bases__
    and hasattr(cur_entity, "CONSTRUCTOR_ID")
}
@loader.tds
class APIRatelimiterMod(loader.Module):
    """Helps userbot avoid spamming Telegram API"""

    strings = {"name": "APILimiter"}

    def __init__(self):
        # Sliding window of (request class name, perf_counter timestamp) pairs.
        self._ratelimiter: typing.List[tuple] = []
        # perf_counter() value until which request tracking is skipped.
        self._suspend_until = 0
        # True while a local floodwait is being reported and served, so that
        # concurrent calls do not trigger a second report.
        self._lock = False
        # Handle of the delayed installer. Kept so the task is strongly
        # referenced while pending and can be cancelled on unload.
        self._install_task: typing.Optional[asyncio.Future] = None
        self.config = loader.ModuleConfig(
            loader.ConfigValue(
                "time_sample",
                15,
                lambda: self.strings("_cfg_time_sample"),
                validator=loader.validators.Integer(minimum=1),
            ),
            loader.ConfigValue(
                "threshold",
                100,
                lambda: self.strings("_cfg_threshold"),
                validator=loader.validators.Integer(minimum=10),
            ),
            loader.ConfigValue(
                "local_floodwait",
                30,
                lambda: self.strings("_cfg_local_floodwait"),
                validator=loader.validators.Integer(minimum=10, maximum=3600),
            ),
            loader.ConfigValue(
                "forbidden_methods",
                ["joinChannel", "importChatInvite"],
                lambda: self.strings("_cfg_forbidden_methods"),
                validator=loader.validators.MultiChoice(
                    [
                        "sendReaction",
                        "joinChannel",
                        "importChatInvite",
                    ]
                ),
                on_change=self.on_forbidden_methods_update,
            ),
        )

    async def client_ready(self):
        """Schedule the delayed installation of the request wrapper."""
        self._install_task = asyncio.ensure_future(self._install_protection())

    async def on_forbidden_methods_update(self):
        """Push the configured forbidden methods to the client as constructor ids."""
        self._client.forbid_constructors(
            [CONSTRUCTORS[method] for method in self.config["forbidden_methods"]]
        )

    async def _install_protection(self):
        """Replace ``self._client._call`` with a rate-limiting wrapper.

        Waits 30 seconds first, then stores the previous ``_call`` on the
        client as ``_old_call_rewritten`` so that ``on_unload`` can restore it.

        Raises:
            loader.SelfUnload: if a wrapper is already installed on the client.
        """
        await asyncio.sleep(30)  # Restart lock

        # The marker is stored on the client itself (see the assignments at
        # the end of this method and the check in on_unload), so it must be
        # looked up there; looking on ``_call`` never finds it.
        if hasattr(self._client, "_old_call_rewritten"):
            raise loader.SelfUnload("Already installed")

        old_call = self._client._call

        async def new_call(
            sender: "MTProtoSender",  # type: ignore  # noqa: F821
            request: TLRequest,
            ordered: bool = False,
            flood_sleep_threshold: typing.Optional[int] = None,
        ):
            """Track the outgoing request(s), then delegate to the original call."""
            # Small random delay (0.01-0.05 s) before every request.
            await asyncio.sleep(random.randint(1, 5) / 100)

            req = (request,) if not is_list_like(request) else request
            for r in req:
                if (
                    time.perf_counter() > self._suspend_until
                    and not self.get("disable_protection", True)
                    and r.__module__.rsplit(".", maxsplit=1)[-1]
                    in {"messages", "account", "channels"}
                ):
                    now = time.perf_counter()
                    window = int(self.config["time_sample"])

                    self._ratelimiter.append((type(r).__name__, now))
                    # Drop entries that fell out of the sampling window.
                    self._ratelimiter = [
                        entry for entry in self._ratelimiter if now - entry[1] < window
                    ]

                    if (
                        len(self._ratelimiter) > int(self.config["threshold"])
                        and not self._lock
                    ):
                        self._lock = True
                        try:
                            report = io.BytesIO(
                                json.dumps(
                                    self._ratelimiter,
                                    indent=4,
                                ).encode()
                            )
                            report.name = "local_fw_report.json"

                            await self.inline.bot.send_document(
                                self.tg_id,
                                report,
                                caption=self.inline.sanitise_text(
                                    self.strings("warning").format(
                                        self.config["local_floodwait"],
                                        prefix=utils.escape_html(self.get_prefix()),
                                    )
                                ),
                            )

                            # It is intended to use time.sleep instead of asyncio.sleep
                            time.sleep(int(self.config["local_floodwait"]))
                        finally:
                            # Always release the lock; otherwise a failed
                            # report would disable the limiter permanently.
                            self._lock = False

            return await old_call(sender, request, ordered, flood_sleep_threshold)

        self._client._call = new_call
        self._client._old_call_rewritten = old_call
        self._client._call._heroku_overwritten = True
        logger.debug("Successfully installed ratelimiter")

    async def on_unload(self):
        """Cancel a pending installation and restore the original ``_call``."""
        # If the module is unloaded during the 30 s delay, the installer must
        # not patch the client afterwards.
        if self._install_task is not None and not self._install_task.done():
            self._install_task.cancel()

        if hasattr(self._client, "_old_call_rewritten"):
            self._client._call = self._client._old_call_rewritten
            delattr(self._client, "_old_call_rewritten")
            logger.debug("Successfully uninstalled ratelimiter")

    # Command: skip request tracking for the given number of seconds.
    # No docstring on purpose, so the function's runtime __doc__ stays unset.
    @loader.command()
    async def suspend_api_protect(self, message: Message):
        # isdecimal() matches exactly what int() can parse; isdigit() also
        # accepts characters such as superscript digits that int() rejects.
        if not (args := utils.get_args_raw(message)) or not args.isdecimal():
            await utils.answer(message, self.strings("args_invalid"))
            return

        self._suspend_until = time.perf_counter() + int(args)
        await utils.answer(message, self.strings("suspended_for").format(args))

    # Command: show a confirmation form; the "yes" button toggles the
    # protection via _finish.
    @loader.command()
    async def api_fw_protection(self, message: Message):
        await self.inline.form(
            message=message,
            text=self.strings("u_sure"),
            reply_markup=[
                {"text": self.strings("btn_no"), "action": "close"},
                {"text": self.strings("btn_yes"), "callback": self._finish},
            ],
        )

    @property
    def _debugger(self) -> WebDebugger:
        """Return ``web_debugger`` of the first handler on the root logger."""
        return logging.getLogger().handlers[0].web_debugger

    async def _show_pin(self, call: InlineCall):
        """Show the debugger PIN to the user as an alert."""
        # NOTE(review): the awaited result of call.answer() is passed to
        # self.inline.bot(...) and the return value of that call is
        # discarded - confirm this is the intended dispatch pattern.
        self.inline.bot(await call.answer(f"Werkzeug PIN: {self._debugger.pin}", show_alert=True))

    # Command: show the web debugger PIN button and URLs, or a notice when
    # no debugger is available.
    @loader.command()
    async def debugger(self, message: Message):
        if not self._debugger:
            await utils.answer(message, self.strings("debugger_disabled"))
            return

        await self.inline.form(
            message=message,
            text=self.strings("web_pin"),
            reply_markup=[
                [
                    {
                        "text": self.strings("web_pin_btn"),
                        "callback": self._show_pin,
                    }
                ],
                [
                    {"text": self.strings("proxied_url"), "url": self._debugger.url},
                    {
                        "text": self.strings("local_url"),
                        "url": f"http://127.0.0.1:{self._debugger.port}",
                    },
                ],
            ],
        )

    async def _finish(self, call: InlineCall):
        """Flip the stored ``disable_protection`` flag and report the new state."""
        # ``state`` is the previous "disabled" flag: if it was True the
        # protection has just been enabled, hence the "on" string.
        state = self.get("disable_protection", True)
        self.set("disable_protection", not state)
        await call.edit(self.strings("on" if state else "off"))