mirror of https://github.com/coddrago/Heroku
249 lines
8.5 KiB
Python
249 lines
8.5 KiB
Python
# ©️ Dan Gazizullin, 2021-2023
|
|
# This file is a part of Hikka Userbot
|
|
# 🌐 https://github.com/hikariatama/Hikka
|
|
# You can redistribute it and/or modify it under the terms of the GNU AGPLv3
|
|
# 🔑 https://www.gnu.org/licenses/agpl-3.0.html
|
|
|
|
# ©️ Codrago, 2024-2025
|
|
# This file is a part of Heroku Userbot
|
|
# 🌐 https://github.com/coddrago/Heroku
|
|
# You can redistribute it and/or modify it under the terms of the GNU AGPLv3
|
|
# 🔑 https://www.gnu.org/licenses/agpl-3.0.html
|
|
|
|
import asyncio
|
|
import io
|
|
import json
|
|
import logging
|
|
import random
|
|
import time
|
|
import typing
|
|
|
|
from herokutl.tl import functions
|
|
from herokutl.tl.tlobject import TLRequest
|
|
from herokutl.tl.types import Message
|
|
from herokutl.utils import is_list_like
|
|
|
|
from .. import loader, utils
|
|
from ..inline.types import InlineCall
|
|
from ..web.debugger import WebDebugger
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Names of the ``functions`` submodules that are scanned for request classes
# when the CONSTRUCTORS lookup table is built below.
GROUPS = [
    "auth", "account", "users", "contacts",
    "messages", "updates", "photos", "upload",
    "help", "channels", "bots", "payments",
    "stickers", "phone", "langpack", "folders",
    "stats"
]
|
|
|
|
|
|
def _collect_constructors() -> dict:
    """Build the ``method name -> CONSTRUCTOR_ID`` lookup table.

    Walks every submodule of ``functions`` listed in ``GROUPS`` and keeps the
    attributes that are classes directly derived from ``TLRequest`` and that
    expose a ``CONSTRUCTOR_ID``. The key is the class name with its first
    letter lower-cased and the trailing ``Request`` part split off
    (e.g. ``JoinChannelRequest`` -> ``joinChannel``).
    """
    table = {}
    for group_name in GROUPS:
        namespace = getattr(functions, group_name)
        for attr_name in dir(namespace):
            candidate = getattr(namespace, attr_name)
            # Only classes have ``__bases__``; skip functions, constants, etc.
            if not hasattr(candidate, "__bases__"):
                continue
            # Direct subclasses of TLRequest only.
            if TLRequest not in candidate.__bases__:
                continue
            if not hasattr(candidate, "CONSTRUCTOR_ID"):
                continue

            camel_name = attr_name[0].lower() + attr_name[1:]
            method_name = camel_name.rsplit("Request", 1)[0]
            table[method_name] = candidate.CONSTRUCTOR_ID

    return table


CONSTRUCTORS = _collect_constructors()
|
|
|
|
|
|
@loader.tds
class APIRatelimiterMod(loader.Module):
    """Helps userbot avoid spamming Telegram API"""

    strings = {"name": "APILimiter"}

    def __init__(self):
        # Sliding window of (request class name, perf_counter timestamp) pairs.
        self._ratelimiter: typing.List[tuple] = []
        # perf_counter() value until which request tracking is paused.
        self._suspend_until = 0
        # True while a local floodwait is being reported and served, so that
        # concurrent requests do not trigger a second one.
        self._lock = False
        self.config = loader.ModuleConfig(
            # Length of the sliding window, in seconds.
            loader.ConfigValue(
                "time_sample",
                15,
                lambda: self.strings("_cfg_time_sample"),
                validator=loader.validators.Integer(minimum=1),
            ),
            # Number of tracked requests inside the window that triggers
            # the local floodwait.
            loader.ConfigValue(
                "threshold",
                100,
                lambda: self.strings("_cfg_threshold"),
                validator=loader.validators.Integer(minimum=10),
            ),
            # Duration of the local floodwait, in seconds.
            loader.ConfigValue(
                "local_floodwait",
                30,
                lambda: self.strings("_cfg_local_floodwait"),
                validator=loader.validators.Integer(minimum=10, maximum=3600),
            ),
            # Method names whose constructor IDs are handed to
            # ``client.forbid_constructors`` whenever this value changes.
            loader.ConfigValue(
                "forbidden_methods",
                ["joinChannel", "importChatInvite"],
                lambda: self.strings("_cfg_forbidden_methods"),
                validator=loader.validators.MultiChoice(
                    [
                        "sendReaction",
                        "joinChannel",
                        "importChatInvite",
                    ]
                ),
                on_change=self.on_forbidden_methods_update,
            ),
        )

    async def client_ready(self):
        # Installation starts with a 30 second sleep, so it is scheduled in
        # the background instead of being awaited here.
        asyncio.ensure_future(self._install_protection())

    async def on_forbidden_methods_update(self):
        # Config ``on_change`` hook: translate the configured method names
        # into constructor IDs and pass them to the client.
        self._client.forbid_constructors(
            [CONSTRUCTORS[method] for method in self.config["forbidden_methods"]]
        )

    def _is_tracked_request(self, request: TLRequest) -> bool:
        """Tell whether ``request`` must be counted by the ratelimiter.

        A request is counted only when the protection is not suspended, the
        ``disable_protection`` flag is unset (it defaults to ``True``, i.e.
        protection is off until toggled) and the request class lives in a
        module named ``messages``, ``account`` or ``channels``.
        """
        return (
            time.perf_counter() > self._suspend_until
            and not self.get("disable_protection", True)
            # rpartition never raises, even for a module name without dots.
            and request.__module__.rpartition(".")[2]
            in {"messages", "account", "channels"}
        )

    def _record_request(self, request: TLRequest) -> None:
        """Add ``request`` to the sliding window and drop expired entries."""
        now = time.perf_counter()
        window = int(self.config["time_sample"])
        self._ratelimiter.append((type(request).__name__, now))
        self._ratelimiter = [
            entry for entry in self._ratelimiter if now - entry[1] < window
        ]

    async def _serve_local_floodwait(self) -> None:
        """Send the request report to the owner and block for the floodwait.

        The lock is always released afterwards, and a failure to deliver the
        report is only logged, so that the floodwait itself still happens.
        """
        self._lock = True
        try:
            report = io.BytesIO(
                json.dumps(
                    self._ratelimiter,
                    indent=4,
                ).encode()
            )
            report.name = "local_fw_report.json"

            try:
                await self.inline.bot.send_document(
                    self.tg_id,
                    report,
                    caption=self.inline.sanitise_text(
                        self.strings("warning").format(
                            self.config["local_floodwait"],
                            prefix=utils.escape_html(self.get_prefix()),
                        )
                    ),
                )
            except Exception:
                logger.exception("Failed to deliver local floodwait report")

            # It is intended to use time.sleep instead of asyncio.sleep:
            # the blocking call stalls the whole event loop, not just this
            # coroutine.
            time.sleep(int(self.config["local_floodwait"]))
        finally:
            self._lock = False

    async def _install_protection(self):
        """Wrap ``client._call`` with the rate-limiting proxy.

        Raises:
            loader.SelfUnload: if a wrapper is already installed.
        """
        await asyncio.sleep(30)  # Restart lock
        # ``_old_call_rewritten`` is stored on the client and
        # ``_heroku_overwritten`` on the wrapper itself (see the end of this
        # method), so these are the markers that have to be checked.
        if hasattr(self._client, "_old_call_rewritten") or getattr(
            self._client._call, "_heroku_overwritten", False
        ):
            raise loader.SelfUnload("Already installed")

        old_call = self._client._call

        async def new_call(
            sender: "MTProtoSender",  # type: ignore  # noqa: F821
            request: TLRequest,
            ordered: bool = False,
            flood_sleep_threshold: typing.Optional[int] = None,
        ):
            # Random 10-50 ms delay in front of every request.
            await asyncio.sleep(random.randint(1, 5) / 100)
            req = (request,) if not is_list_like(request) else request
            for r in req:
                if not self._is_tracked_request(r):
                    continue

                self._record_request(r)

                if (
                    len(self._ratelimiter) > int(self.config["threshold"])
                    and not self._lock
                ):
                    await self._serve_local_floodwait()

            return await old_call(sender, request, ordered, flood_sleep_threshold)

        self._client._call = new_call
        self._client._old_call_rewritten = old_call
        self._client._call._heroku_overwritten = True
        logger.debug("Successfully installed ratelimiter")

    async def on_unload(self):
        # Restore the original ``_call`` if the wrapper was installed.
        if hasattr(self._client, "_old_call_rewritten"):
            self._client._call = self._client._old_call_rewritten
            delattr(self._client, "_old_call_rewritten")
            logger.debug("Successfully uninstalled ratelimiter")

    # Pauses request tracking for the number of seconds given as argument.
    @loader.command()
    async def suspend_api_protect(self, message: Message):
        # isdecimal() (unlike isdigit()) only accepts characters that int()
        # is able to parse, so the conversion below cannot raise.
        if not (args := utils.get_args_raw(message)) or not args.isdecimal():
            await utils.answer(message, self.strings("args_invalid"))
            return

        self._suspend_until = time.perf_counter() + int(args)
        await utils.answer(message, self.strings("suspended_for").format(args))

    # Asks for confirmation before toggling the protection (see ``_finish``).
    @loader.command()
    async def api_fw_protection(self, message: Message):
        await self.inline.form(
            message=message,
            text=self.strings("u_sure"),
            reply_markup=[
                {"text": self.strings("btn_no"), "action": "close"},
                {"text": self.strings("btn_yes"), "callback": self._finish},
            ],
        )

    @property
    def _debugger(self) -> typing.Optional[WebDebugger]:
        """Web debugger attached to the first root logging handler, if any.

        Returns ``None`` when the root logger has no handlers or the first
        handler does not carry a ``web_debugger`` attribute.
        """
        handlers = logging.getLogger().handlers
        if not handlers:
            return None

        return getattr(handlers[0], "web_debugger", None)

    async def _show_pin(self, call: InlineCall):
        # Inline button callback: shows the debugger PIN as an alert.
        # NOTE(review): the result of ``self.inline.bot(...)`` is not awaited
        # here; confirm against ``InlineCall.answer`` whether the outer call
        # is needed at all or has to be awaited.
        self.inline.bot(await call.answer(f"Werkzeug PIN: {self._debugger.pin}", show_alert=True))

    # Shows an inline form with the web debugger PIN button and its URLs.
    @loader.command()
    async def debugger(self, message: Message):
        web_debugger = self._debugger
        if not web_debugger:
            await utils.answer(message, self.strings("debugger_disabled"))
            return

        await self.inline.form(
            message=message,
            text=self.strings("web_pin"),
            reply_markup=[
                [
                    {
                        "text": self.strings("web_pin_btn"),
                        "callback": self._show_pin,
                    }
                ],
                [
                    {"text": self.strings("proxied_url"), "url": web_debugger.url},
                    {
                        "text": self.strings("local_url"),
                        "url": f"http://127.0.0.1:{web_debugger.port}",
                    },
                ],
            ],
        )

    async def _finish(self, call: InlineCall):
        # ``state`` is the previous value of the "disabled" flag: when the
        # protection was disabled it has just been enabled, hence "on".
        state = self.get("disable_protection", True)
        self.set("disable_protection", not state)
        await call.edit(self.strings("on" if state else "off"))