# ÂŠī¸ Dan Gazizullin, 2021-2023 # This file is a part of Hikka Userbot # 🌐 https://github.com/hikariatama/Hikka # You can redistribute it and/or modify it under the terms of the GNU AGPLv3 # 🔑 https://www.gnu.org/licenses/agpl-3.0.html # ÂŠī¸ Codrago, 2024-2025 # This file is a part of Heroku Userbot # 🌐 https://github.com/coddrago/Heroku # You can redistribute it and/or modify it under the terms of the GNU AGPLv3 # 🔑 https://www.gnu.org/licenses/agpl-3.0.html import asyncio import collections import json import logging import os import re import time import typing from herokutl.errors.rpcerrorlist import ChannelsTooMuchError from herokutl.tl.types import Message, User from . import main, utils from .pointers import ( BaseSerializingMiddlewareDict, BaseSerializingMiddlewareList, NamedTupleMiddlewareDict, NamedTupleMiddlewareList, PointerDict, PointerList, ) from .tl_cache import CustomTelegramClient from .types import JSONSerializable __all__ = [ "Database", "PointerList", "PointerDict", "NamedTupleMiddlewareDict", "NamedTupleMiddlewareList", "BaseSerializingMiddlewareDict", "BaseSerializingMiddlewareList", ] logger = logging.getLogger(__name__) class NoAssetsChannel(Exception): """Raised when trying to read/store asset with no asset channel present""" class Database(dict): def __init__(self, client: CustomTelegramClient): super().__init__() self._client: CustomTelegramClient = client self._next_revision_call: int = 0 self._revisions: typing.List[dict] = [] self._assets: int = None self._me: User = None self._redis: redis.Redis = None self._saving_task: asyncio.Future = None def __repr__(self): return object.__repr__(self) def _redis_save_sync(self): with self._redis.pipeline() as pipe: pipe.set( str(self._client.tg_id), json.dumps(self, ensure_ascii=True), ) pipe.execute() async def remote_force_save(self) -> bool: """Force save database to remote endpoint without waiting""" if not self._redis: return False await utils.run_sync(self._redis_save_sync) logger.debug("Published db to Redis") return True async def _redis_save(self) -> bool: """Save database to redis""" if not self._redis: return False await asyncio.sleep(5) await utils.run_sync(self._redis_save_sync) logger.debug("Published db to Redis") self._saving_task = None return True async def redis_init(self) -> bool: """Init redis database""" if REDIS_URI := ( os.environ.get("REDIS_URL") or main.get_config_key("redis_uri") ): self._redis = redis.Redis.from_url(REDIS_URI) else: return False async def init(self): """Asynchronous initialization unit""" if os.environ.get("REDIS_URL") or main.get_config_key("redis_uri"): await self.redis_init() self._db_file = main.BASE_PATH / f"config-{self._client.tg_id}.json" self.read() try: self._assets, _ = await utils.asset_channel( self._client, "heroku-assets", "🌆 Your Heroku assets will be stored here", archive=True, avatar="https://raw.githubusercontent.com/coddrago/Heroku/dev-test/assets/heroku-assets.png" ) except ChannelsTooMuchError: self._assets = None logger.error( "Can't find and/or create assets folder\n" "This may cause several consequences, such as:\n" "- Non working assets feature (e.g. notes)\n" "- This error will occur every restart\n\n" "You can solve this by leaving some channels/groups" ) def read(self): """Read database and stores it in self""" if self._redis: try: self.update( **json.loads( self._redis.get( str(self._client.tg_id), ).decode(), ) ) except Exception: logger.exception("Error reading redis database") return try: db = self._db_file.read_text() if re.search(r'"(hikka\.)(\S+\":)', db): logging.warning("Converting db after update") db = re.sub(r'(hikka\.)(\S+\":)', lambda m: 'heroku.' + m.group(2), db) self.update(**json.loads(db)) except json.decoder.JSONDecodeError: logger.warning("Database read failed! Creating new one...") except FileNotFoundError: logger.debug("Database file not found, creating new one...") def process_db_autofix(self, db: dict) -> bool: if not utils.is_serializable(db): return False for key, value in db.copy().items(): if not isinstance(key, (str, int)): logger.warning( "DbAutoFix: Dropped key %s, because it is not string or int", key, ) continue if not isinstance(value, dict): # If value is not a dict (module values), drop it, # otherwise it may cause problems del db[key] logger.warning( "DbAutoFix: Dropped key %s, because it is non-dict, but %s", key, type(value), ) continue for subkey in value: if not isinstance(subkey, (str, int)): del db[key][subkey] logger.warning( ( "DbAutoFix: Dropped subkey %s of db key %s, because it is" " not string or int" ), subkey, key, ) continue return True def save(self) -> bool: """Save database""" if not self.process_db_autofix(self): try: rev = self._revisions.pop() while not self.process_db_autofix(rev): rev = self._revisions.pop() except IndexError: raise RuntimeError( "Can't find revision to restore broken database from " "database is most likely broken and will lead to problems, " "so its save is forbidden." ) self.clear() self.update(**rev) raise RuntimeError( "Rewriting database to the last revision because new one destructed it" ) if self._next_revision_call < time.time(): self._revisions += [dict(self)] self._next_revision_call = time.time() + 3 while len(self._revisions) > 15: self._revisions.pop() if self._redis: if not self._saving_task: self._saving_task = asyncio.ensure_future(self._redis_save()) return True try: self._db_file.write_text(json.dumps(self, indent=4)) except Exception: logger.exception("Database save failed!") return False return True async def store_asset(self, message: Message) -> int: """ Save assets returns asset_id as integer """ if not self._assets: raise NoAssetsChannel("Tried to save asset to non-existing asset channel") return ( (await self._client.send_message(self._assets, message)).id if isinstance(message, Message) else ( await self._client.send_message( self._assets, file=message, force_document=True, ) ).id ) async def fetch_asset(self, asset_id: int) -> typing.Optional[Message]: """Fetch previously saved asset by its asset_id""" if not self._assets: raise NoAssetsChannel( "Tried to fetch asset from non-existing asset channel" ) asset = await self._client.get_messages(self._assets, ids=[asset_id]) return asset[0] if asset else None def get( self, owner: str, key: str, default: typing.Optional[JSONSerializable] = None, ) -> JSONSerializable: """Get database key""" try: return self[owner][key] except KeyError: return default def set(self, owner: str, key: str, value: JSONSerializable) -> bool: """Set database key""" if not utils.is_serializable(owner): raise RuntimeError( "Attempted to write object to " f"{owner=} ({type(owner)=}) of database. It is not " "JSON-serializable key which will cause errors" ) if not utils.is_serializable(key): raise RuntimeError( "Attempted to write object to " f"{key=} ({type(key)=}) of database. It is not " "JSON-serializable key which will cause errors" ) if not utils.is_serializable(value): raise RuntimeError( "Attempted to write object of " f"{key=} ({type(value)=}) to database. It is not " "JSON-serializable value which will cause errors" ) super().setdefault(owner, {})[key] = value return self.save() def pointer( self, owner: str, key: str, default: typing.Optional[JSONSerializable] = None, item_type: typing.Optional[typing.Any] = None, ) -> typing.Union[JSONSerializable, PointerList, PointerDict]: """Get a pointer to database key""" value = self.get(owner, key, default) mapping = { list: PointerList, dict: PointerDict, collections.abc.Hashable: lambda v: v, } pointer_constructor = next( (pointer for type_, pointer in mapping.items() if isinstance(value, type_)), None, ) if (current_value := self.get(owner, key, None)) and type( current_value ) is not type(default): raise ValueError( f"Can't switch the type of pointer in database (current: {type(current_value)}, requested: {type(default)})" ) if pointer_constructor is None: raise ValueError( f"Pointer for type {type(value).__name__} is not implemented" ) if item_type is not None: if isinstance(value, list): for item in self.get(owner, key, default): if not isinstance(item, dict): raise ValueError( "Item type can only be specified for dedicated keys and" " can't be mixed with other ones" ) return NamedTupleMiddlewareList( pointer_constructor(self, owner, key, default), item_type, ) if isinstance(value, dict): for item in self.get(owner, key, default).values(): if not isinstance(item, dict): raise ValueError( "Item type can only be specified for dedicated keys and" " can't be mixed with other ones" ) return NamedTupleMiddlewareDict( pointer_constructor(self, owner, key, default), item_type, ) return pointer_constructor(self, owner, key, default)