# Mirror of https://github.com/coddrago/Heroku (Python, 364 lines, 12 KiB)

# ©️ Dan Gazizullin, 2021-2023
# This file is a part of Hikka Userbot
# 🌐 https://github.com/hikariatama/Hikka
# You can redistribute it and/or modify it under the terms of the GNU AGPLv3
# 🔑 https://www.gnu.org/licenses/agpl-3.0.html

# ©️ Codrago, 2024-2025
# This file is a part of Heroku Userbot
# 🌐 https://github.com/coddrago/Heroku
# You can redistribute it and/or modify it under the terms of the GNU AGPLv3
# 🔑 https://www.gnu.org/licenses/agpl-3.0.html

import asyncio
import collections
import collections.abc
import copy
import json
import logging
import os
import re
import time
import typing

from herokutl.errors.rpcerrorlist import ChannelsTooMuchError
from herokutl.tl.types import Message, User

from . import main, utils
from .pointers import (
    BaseSerializingMiddlewareDict,
    BaseSerializingMiddlewareList,
    NamedTupleMiddlewareDict,
    NamedTupleMiddlewareList,
    PointerDict,
    PointerList,
)
from .tl_cache import CustomTelegramClient
from .types import JSONSerializable

# Public API of this module: the database class plus the pointer and
# middleware helpers imported from ``.pointers``.
__all__ = [
    "Database",
    "PointerList",
    "PointerDict",
    "NamedTupleMiddlewareDict",
    "NamedTupleMiddlewareList",
    "BaseSerializingMiddlewareDict",
    "BaseSerializingMiddlewareList",
]

logger = logging.getLogger(name=__name__)


class NoAssetsChannel(Exception):
    """Asset storage is unavailable.

    Raised when an asset is read or stored while no assets channel is present.
    """


class Database(dict):
    """Per-account key/value storage persisted as JSON.

    Data is kept as ``{owner: {key: value}}``. It is saved either to
    ``config-<tg_id>.json`` under ``main.BASE_PATH`` or, when configured, to
    Redis under the key ``str(client.tg_id)``. The class also manages a
    Telegram channel used to store message/file assets.
    """

    # Maximum number of snapshots kept by ``save`` for rolling back a broken db.
    _REVISIONS_LIMIT = 15
    # Minimum delay (seconds) between two snapshots taken by ``save``.
    _REVISION_INTERVAL = 3
    # Delay (seconds) before a scheduled Redis publish, so that a burst of
    # ``save`` calls results in a single write.
    _REDIS_SAVE_DELAY = 5
    # Matches JSON object keys with the legacy ``hikka.`` prefix, e.g.
    # ``"hikka.main":``; group 1 is the remainder of the key name.
    _LEGACY_KEY_RE = re.compile(r'"hikka\.([^"]+)":')

    def __init__(self, client: CustomTelegramClient):
        super().__init__()
        self._client: CustomTelegramClient = client
        # Earliest ``time.time()`` at which ``save`` may take a new snapshot.
        self._next_revision_call: float = 0
        # Deep-copied snapshots of the db, oldest first, used for rollback.
        self._revisions: typing.List[dict] = []
        # Assets channel returned by ``utils.asset_channel``; None if unavailable.
        self._assets: int = None
        self._me: User = None
        # Redis client; stays None unless ``redis_init`` configured one.
        self._redis: "redis.Redis" = None
        # Pending delayed Redis publish scheduled by ``save``, if any.
        self._saving_task: asyncio.Future = None

    def __repr__(self):
        # Use the plain object repr: dict's repr would dump the whole database.
        return object.__repr__(self)

    def _redis_save_sync(self):
        """Synchronously write the whole db to Redis as one JSON value."""
        with self._redis.pipeline() as pipe:
            pipe.set(
                str(self._client.tg_id),
                json.dumps(self, ensure_ascii=True),
            )
            pipe.execute()

    async def remote_force_save(self) -> bool:
        """Force save database to remote endpoint without waiting.

        :return: ``False`` if Redis is not configured, ``True`` once published
        """
        if not self._redis:
            return False

        await utils.run_sync(self._redis_save_sync)
        logger.debug("Published db to Redis")
        return True

    async def _redis_save(self) -> bool:
        """Save database to redis after a short delay.

        Scheduled by ``save``. The delay lets several consecutive saves be
        published with a single write. ``_saving_task`` is always cleared on
        exit, so a failed publish does not block all future ones.

        :return: ``True`` if published, ``False`` if Redis is not configured
            or the publish failed (the failure is logged)
        """
        try:
            if not self._redis:
                return False

            await asyncio.sleep(self._REDIS_SAVE_DELAY)
            await utils.run_sync(self._redis_save_sync)
        except Exception:
            # Nobody awaits this task, so report the failure here instead of
            # leaving an unretrieved task exception behind.
            logger.exception("Failed to publish db to Redis")
            return False
        finally:
            self._saving_task = None

        logger.debug("Published db to Redis")
        return True

    async def redis_init(self) -> bool:
        """Init redis database.

        The connection URI comes from the ``REDIS_URL`` environment variable,
        or else from the ``redis_uri`` config key.

        :return: ``True`` if a Redis client was configured, ``False`` if no URI
            is set or the ``redis`` package is not installed
        """
        redis_uri = os.environ.get("REDIS_URL") or main.get_config_key("redis_uri")
        if not redis_uri:
            return False

        try:
            # Optional dependency: only required when Redis is configured.
            import redis
        except ImportError:
            logger.error(
                "Redis URI is set, but the redis package is not installed. "
                "Falling back to the local database file"
            )
            return False

        self._redis = redis.Redis.from_url(redis_uri)
        return True

    async def init(self):
        """Asynchronous initialization unit.

        Configures Redis if a URI is set, loads the stored data, then finds or
        creates the ``heroku-assets`` channel. If Telegram rejects this with
        ``ChannelsTooMuchError``, assets are disabled (``_assets`` is None)
        and an error is logged instead of raising.
        """
        await self.redis_init()

        self._db_file = main.BASE_PATH / f"config-{self._client.tg_id}.json"
        self.read()

        try:
            self._assets, _ = await utils.asset_channel(
                self._client,
                "heroku-assets",
                "🌆 Your Heroku assets will be stored here",
                archive=True,
                avatar="https://raw.githubusercontent.com/coddrago/Heroku/dev-test/assets/heroku-assets.png"
            )
        except ChannelsTooMuchError:
            self._assets = None
            logger.error(
                "Can't find and/or create assets folder\n"
                "This may cause several consequences, such as:\n"
                "- Non working assets feature (e.g. notes)\n"
                "- This error will occur every restart\n\n"
                "You can solve this by leaving some channels/groups"
            )

    def read(self):
        """Read database and stores it in self.

        Loads from Redis when a client is configured, otherwise from the local
        JSON file. In the file, legacy ``"hikka.<name>":`` keys are renamed to
        ``"heroku.<name>":`` before parsing. Read errors are logged and leave
        the database unchanged rather than raising.
        """
        if self._redis:
            try:
                raw = self._redis.get(str(self._client.tg_id))
                if raw is None:
                    # Nothing stored for this account yet: start empty.
                    logger.debug("Redis database is empty, creating new one...")
                else:
                    self.update(json.loads(raw.decode()))
            except Exception:
                logger.exception("Error reading redis database")
            return

        try:
            db = self._db_file.read_text()
        except FileNotFoundError:
            logger.debug("Database file not found, creating new one...")
            return

        # Only rewrite whole JSON keys that start with "hikka.", never text
        # that merely contains "hikka." somewhere inside a key or value.
        db, renamed = self._LEGACY_KEY_RE.subn(r'"heroku.\1":', db)
        if renamed:
            logger.warning("Converting db after update")

        try:
            self.update(json.loads(db))
        except json.decoder.JSONDecodeError:
            logger.warning("Database read failed! Creating new one...")

    def process_db_autofix(self, db: dict) -> bool:
        """Validate *db* in place and drop entries that break its layout.

        Top-level keys must be ``str``/``int`` and map to dicts, and subkeys
        must be ``str``/``int``. Offending entries are deleted and logged.

        :param db: mapping to check and fix in place
        :return: ``False`` if *db* is not JSON-serializable (left untouched),
            ``True`` otherwise
        """
        if not utils.is_serializable(db):
            return False

        # Iterate over snapshots so that entries can be deleted while looping.
        for key, value in list(db.items()):
            if not isinstance(key, (str, int)):
                del db[key]
                logger.warning(
                    "DbAutoFix: Dropped key %s, because it is not string or int",
                    key,
                )
                continue

            if not isinstance(value, dict):
                # If value is not a dict (module values), drop it,
                # otherwise it may cause problems
                del db[key]
                logger.warning(
                    "DbAutoFix: Dropped key %s, because it is non-dict, but %s",
                    key,
                    type(value),
                )
                continue

            for subkey in list(value):
                if not isinstance(subkey, (str, int)):
                    del value[subkey]
                    logger.warning(
                        (
                            "DbAutoFix: Dropped subkey %s of db key %s, because it is"
                            " not string or int"
                        ),
                        subkey,
                        key,
                    )

        return True

    def save(self) -> bool:
        """Save database.

        Validates the db with ``process_db_autofix`` first. If the db is not
        JSON-serializable, the newest valid revision is restored and a
        ``RuntimeError`` is raised. Otherwise a snapshot is recorded (at most
        once per ``_REVISION_INTERVAL`` seconds) and the db is persisted:
        either by scheduling a delayed Redis publish, or by synchronously
        writing the JSON file.

        :return: ``True`` on success (or once a Redis publish is scheduled),
            ``False`` if writing the file failed
        :raises RuntimeError: if the db was broken, whether or not a revision
            could be restored
        """
        if not self.process_db_autofix(self):
            try:
                rev = self._revisions.pop()
                while not self.process_db_autofix(rev):
                    rev = self._revisions.pop()
            except IndexError:
                raise RuntimeError(
                    "Can't find revision to restore broken database from "
                    "database is most likely broken and will lead to problems, "
                    "so its save is forbidden."
                ) from None

            self.clear()
            # Positional update: ``**rev`` would fail on int keys.
            self.update(rev)

            raise RuntimeError(
                "Rewriting database to the last revision because new one destructed it"
            )

        if self._next_revision_call < time.time():
            # Deep copy: ``dict(self)`` alone would share the nested per-owner
            # dicts, so later in-place edits (or corruption) would leak into
            # every stored revision and make rollback useless.
            self._revisions.append(copy.deepcopy(dict(self)))
            self._next_revision_call = time.time() + self._REVISION_INTERVAL

        # Revisions are appended newest-last, so trim from the front to keep
        # the most recent ones.
        excess = len(self._revisions) - self._REVISIONS_LIMIT
        if excess > 0:
            del self._revisions[:excess]

        if self._redis:
            # Coalesce bursts of saves into a single delayed publish.
            if not self._saving_task:
                self._saving_task = asyncio.ensure_future(self._redis_save())
            return True

        try:
            self._db_file.write_text(json.dumps(self, indent=4))
        except Exception:
            logger.exception("Database save failed!")
            return False

        return True

    async def store_asset(self, message: Message) -> int:
        """
        Save assets
        returns asset_id as integer

        :param message: a ``Message`` to re-send into the assets channel, or
            any other object passed to ``send_message`` as ``file=`` and sent
            as a document (presumably a path or file-like; confirm at callers)
        :return: ID of the message created in the assets channel
        :raises NoAssetsChannel: if there is no assets channel
        """
        if not self._assets:
            raise NoAssetsChannel("Tried to save asset to non-existing asset channel")

        if isinstance(message, Message):
            sent = await self._client.send_message(self._assets, message)
        else:
            sent = await self._client.send_message(
                self._assets,
                file=message,
                force_document=True,
            )

        return sent.id

    async def fetch_asset(self, asset_id: int) -> typing.Optional[Message]:
        """Fetch previously saved asset by its asset_id.

        :return: the message with ID *asset_id* from the assets channel, or
            None if nothing was returned
        :raises NoAssetsChannel: if there is no assets channel
        """
        if not self._assets:
            raise NoAssetsChannel(
                "Tried to fetch asset from non-existing asset channel"
            )

        asset = await self._client.get_messages(self._assets, ids=[asset_id])
        return asset[0] if asset else None

    def get(
        self,
        owner: str,
        key: str,
        default: typing.Optional[JSONSerializable] = None,
    ) -> JSONSerializable:
        """Get database key.

        Note: this shadows ``dict.get`` with the two-level lookup
        ``self[owner][key]``.

        :return: the stored value, or *default* if *owner* or *key* is missing
        """
        try:
            return self[owner][key]
        except KeyError:
            return default

    def set(self, owner: str, key: str, value: JSONSerializable) -> bool:
        """Set database key and persist the database.

        :raises RuntimeError: if *owner*, *key* or *value* is not
            JSON-serializable
        :return: the result of ``save``
        """
        if not utils.is_serializable(owner):
            raise RuntimeError(
                "Attempted to write object to "
                f"{owner=} ({type(owner)=}) of database. It is not "
                "JSON-serializable key which will cause errors"
            )

        if not utils.is_serializable(key):
            raise RuntimeError(
                "Attempted to write object to "
                f"{key=} ({type(key)=}) of database. It is not "
                "JSON-serializable key which will cause errors"
            )

        if not utils.is_serializable(value):
            raise RuntimeError(
                "Attempted to write object of "
                f"{key=} ({type(value)=}) to database. It is not "
                "JSON-serializable value which will cause errors"
            )

        super().setdefault(owner, {})[key] = value
        return self.save()

    @staticmethod
    def _ensure_dict_items(items: typing.Iterable[typing.Any]):
        """Raise ``ValueError`` unless every element of *items* is a dict."""
        if not all(isinstance(item, dict) for item in items):
            raise ValueError(
                "Item type can only be specified for dedicated keys and"
                " can't be mixed with other ones"
            )

    def pointer(
        self,
        owner: str,
        key: str,
        default: typing.Optional[JSONSerializable] = None,
        item_type: typing.Optional[typing.Any] = None,
    ) -> typing.Union[JSONSerializable, PointerList, PointerDict]:
        """Get a pointer to database key.

        Lists and dicts are wrapped in ``PointerList`` / ``PointerDict``. When
        *item_type* is given, that pointer is further wrapped in
        ``NamedTupleMiddlewareList`` / ``NamedTupleMiddlewareDict``, which
        requires every stored item to be a dict. Other hashable values (str,
        int, None, ...) are returned as-is.

        :raises ValueError: if the stored value's type differs from the type
            of *default*, if no pointer exists for the value's type, or if
            *item_type* is used with non-dict items
        """
        value = self.get(owner, key, default)
        mapping = {
            list: PointerList,
            dict: PointerDict,
            # Scalars cannot be pointed at, so the value itself is returned.
            # The lambda ignores the (db, owner, key, default) arguments that
            # the pointer classes take, so every entry is called the same way.
            collections.abc.Hashable: lambda *_: value,
        }

        pointer_constructor = next(
            (pointer for type_, pointer in mapping.items() if isinstance(value, type_)),
            None,
        )

        if (current_value := self.get(owner, key, None)) and type(
            current_value
        ) is not type(default):
            raise ValueError(
                f"Can't switch the type of pointer in database (current: {type(current_value)}, requested: {type(default)})"
            )

        if pointer_constructor is None:
            raise ValueError(
                f"Pointer for type {type(value).__name__} is not implemented"
            )

        if item_type is not None:
            if isinstance(value, list):
                self._ensure_dict_items(value)
                return NamedTupleMiddlewareList(
                    pointer_constructor(self, owner, key, default),
                    item_type,
                )

            if isinstance(value, dict):
                self._ensure_dict_items(value.values())
                return NamedTupleMiddlewareDict(
                    pointer_constructor(self, owner, key, default),
                    item_type,
                )

        return pointer_constructor(self, owner, key, default)