import asyncio import atexit import contextlib import copy import dataclasses import datetime import functools import io import json import logging import os import re import signal import struct import sys import time from hashlib import sha256 from logging.handlers import RotatingFileHandler from pathlib import Path from typing import Dict, List, Optional, Union from zlib import crc32 from hikkatl.crypto import AES, AuthKey from hikkatl.errors import ( AuthKeyNotFound, BadMessageError, InvalidBufferError, InvalidChecksumError, SecurityError, TypeNotFoundError, ) from hikkatl.extensions.binaryreader import BinaryReader from hikkatl.extensions.messagepacker import MessagePacker from hikkatl.network.connection import ConnectionTcpFull from hikkatl.network.mtprotosender import MTProtoSender from hikkatl.network.mtprotostate import MTProtoState from hikkatl.network.requeststate import RequestState from hikkatl.sessions import SQLiteSession from hikkatl.tl import TLRequest from hikkatl.tl.core import GzipPacked, MessageContainer, TLMessage from hikkatl.tl.functions import ( InitConnectionRequest, InvokeAfterMsgRequest, InvokeAfterMsgsRequest, InvokeWithLayerRequest, InvokeWithMessagesRangeRequest, InvokeWithoutUpdatesRequest, InvokeWithTakeoutRequest, PingRequest, ) from hikkatl.tl.functions.account import DeleteAccountRequest, UpdateProfileRequest from hikkatl.tl.functions.auth import ( BindTempAuthKeyRequest, CancelCodeRequest, CheckRecoveryPasswordRequest, ExportAuthorizationRequest, ExportLoginTokenRequest, ImportAuthorizationRequest, LogOutRequest, RecoverPasswordRequest, RequestPasswordRecoveryRequest, ResetAuthorizationsRequest, ResetLoginEmailRequest, SendCodeRequest, ) from hikkatl.tl.functions.help import GetConfigRequest from hikkatl.tl.functions.messages import ( ForwardMessagesRequest, GetHistoryRequest, SearchRequest, ) from hikkatl.tl.types import ( InputPeerUser, Message, MsgsAck, PeerUser, UpdateNewMessage, Updates, UpdateShortMessage, ) os.chdir(os.path.dirname(os.path.abspath(__file__))) logging.basicConfig(level=logging.DEBUG) handler = logging.StreamHandler() handler.setLevel(logging.INFO) handler.setFormatter( logging.Formatter( fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S", style="%", ) ) rotating_handler = RotatingFileHandler( filename="hikka.log", mode="a", maxBytes=10 * 1024 * 1024, backupCount=1, encoding="utf-8", delay=0, ) rotating_handler.setLevel(logging.DEBUG) rotating_handler.setFormatter( logging.Formatter( fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S", style="%", ) ) logging.getLogger().handlers[0].setLevel(logging.CRITICAL) logging.getLogger().addHandler(handler) logging.getLogger().addHandler(rotating_handler) class _OpaqueRequest(TLRequest): def __init__(self, data: bytes): self.data = data def _bytes(self): return self.data class CustomMTProtoState(MTProtoState): def write_data_as_message( self, buffer, data, content_related, *, after_id=None, msg_id=None, ): msg_id = msg_id or self._get_new_msg_id() seq_no = self._get_seq_no(content_related) if after_id is None: body = GzipPacked.gzip_if_smaller(content_related, data) else: body = GzipPacked.gzip_if_smaller( content_related, bytes(InvokeAfterMsgRequest(after_id, _OpaqueRequest(data))), ) buffer.write(struct.pack(" 1: data = ( struct.pack(" Union[str, bool]: """ Parse and return key from config :param key: Key name in config :return: Value of config key or `False`, if it doesn't exist """ try: return json.loads(Path("../config.json").read_text()).get(key, False) except FileNotFoundError: return False start_ts = time.perf_counter() class CustomMTProtoSender(MTProtoSender): def __init__( self, auth_key, *, loggers, retries=5, delay=1, auto_reconnect=True, connect_timeout=None, auth_key_callback=None, updates_queue=None, auto_reconnect_callback=None, ): super().__init__( auth_key, loggers=loggers, retries=retries, delay=delay, auto_reconnect=auto_reconnect, connect_timeout=connect_timeout, auth_key_callback=auth_key_callback, updates_queue=updates_queue, auto_reconnect_callback=auto_reconnect_callback, ) self._state = CustomMTProtoState(self.auth_key, loggers=self._loggers) self._send_queue = CustomMessagePacker(self._state, loggers=self._loggers) def external_append(self, state): self._send_queue.append(state) def external_extend(self, states): self._send_queue.extend(states) async def _send_loop(self): while self._user_connected and not self._reconnecting: if self._pending_ack: ack = RequestState(MsgsAck(list(self._pending_ack))) self._send_queue.append(ack) self._last_acks.append(ack) self._pending_ack.clear() self._log.debug("Waiting for messages to send...") batch, data = await self._send_queue.get() if not data: continue logging.debug("Sending data %s", data) self._log.debug( "Encrypting %d message(s) in %d bytes for sending", len(batch), len(data), ) data = self._state.encrypt_message_data(data) for state in batch: if not isinstance(state, list): if isinstance(state.request, TLRequest): self._pending_state[state.msg_id] = state else: for s in state: if isinstance(s.request, TLRequest): self._pending_state[s.msg_id] = s try: await self._connection.send(data) except IOError as e: self._log.info("Connection closed while sending data") self._start_reconnect(e) return self._log.debug("Encrypted messages put in a queue to be sent") def partial_decrypt(self, body): if len(body) < 8: raise InvalidBufferError(body) key_id = struct.unpack(" List[int]: return list(self._clients.keys()) @property def clients(self) -> Dict[int, MTProtoSender]: return self._clients @property def sessions(self) -> List[SQLiteSession]: return self._safe_sessions def read_sessions(self): logging.debug("Reading sessions...") session_files = list( filter( lambda f: f.startswith("hikka-") and f.endswith(".session"), os.listdir("."), ) ) for session in session_files: Path("safe-" + session).write_bytes(Path(session).read_bytes()) self._sessions = [SQLiteSession(session) for session in session_files] self._safe_sessions = [ SQLiteSession("safe-" + session) for session in session_files ] for session in self._safe_sessions: logging.debug("Processing session %s...", session.filename) session.set_dc(0, "0.0.0.0", 11111) session.auth_key = AuthKey( data=( "Where are you at?\nWhere have you" " been?\n問いかけに答えはなく\nWhere are we headed?\nWhat did you" " mean?\n追いかけても 遅く 遠く\nA bird, a butterfly and my red" " scarf\nDon't make a mess of memories\nJust let me heal your" " scars\nThe wall, the owl, forgotten wharf\n時が止まることもなく" ).encode() + b"\x00" * 13 ) session.save() session.close() def rename(filename: str) -> str: session_id = re.findall(r"\d+", filename)[-1] return f"hikka-{session_id}.session" for session in self._safe_sessions: os.rename( os.path.abspath(session.filename), os.path.abspath(os.path.join("../", rename(session.filename))), ) async def init_clients(self): for session in self._sessions: class _Loggers(dict): def __missing__(self, key): if key.startswith("telethon."): key = key.split(".", maxsplit=1)[1] return logging.getLogger("hikkatl").getChild(key) def _auth_key_callback(auth_key): self.session.auth_key = auth_key self.session.save() _updates_queue = asyncio.Queue() client = CustomMTProtoSender( session.auth_key, loggers=_Loggers(), retries=5, delay=1, auto_reconnect=True, connect_timeout=10, auth_key_callback=_auth_key_callback, updates_queue=_updates_queue, auto_reconnect_callback=None, ) await client.connect( ConnectionTcpFull( session.server_address, session.port, session.dc_id, loggers=_Loggers(), ) ) client.id = int(re.findall(r"\d+", session.filename)[-1]) logging.debug("Client %s connected", client.id) self._clients[client.id] = client @dataclasses.dataclass class Socket: reader: asyncio.StreamReader writer: asyncio.StreamWriter client_id: int class TCP: def __init__(self, session_storage: SessionStorage): self._sockets = {} self._socket_files = [] self._session_storage = session_storage self.gc(init=True) for client_id in self._session_storage.client_ids: filename = os.path.abspath( os.path.join("../", f"hikka-{client_id}-proxy.sock") ) asyncio.ensure_future( asyncio.start_unix_server( functools.partial( self._process_conn, client_id=client_id, filename=filename, ), filename, ) ) def _process_conn(self, reader, writer, client_id, filename): self._session_storage.clients[client_id].set_socket(writer) self._socket_files.append(filename) logging.info("Socket %s connected", filename) socket = Socket(reader, writer, client_id) self._sockets[client_id] = socket asyncio.ensure_future(read_loop(socket)) @staticmethod async def recv(sock: Socket): return await ClientFullPacketCodec.read_packet(None, sock.reader) @staticmethod async def send(sock: Socket, data: bytes): sock.writer.write(ClientFullPacketCodec.encode_packet(None, data)) await sock.writer.drain() def _find_real_request(self, request: TLRequest) -> TLRequest: if isinstance( request, ( InvokeWithLayerRequest, InvokeAfterMsgRequest, InvokeAfterMsgsRequest, InvokeWithMessagesRangeRequest, InvokeWithTakeoutRequest, InvokeWithoutUpdatesRequest, ), ): return self._find_real_request(request.query) return request def _malicious(self, request: TLRequest) -> bool: request = self._find_real_request(request) if ( isinstance( request, ( DeleteAccountRequest, BindTempAuthKeyRequest, CancelCodeRequest, CheckRecoveryPasswordRequest, ExportAuthorizationRequest, ExportLoginTokenRequest, ImportAuthorizationRequest, LogOutRequest, RecoverPasswordRequest, RequestPasswordRecoveryRequest, ResetAuthorizationsRequest, ResetLoginEmailRequest, SendCodeRequest, ), ) or ( isinstance(request, UpdateProfileRequest) and "savedmessages" in (request.first_name + request.last_name).replace(" ", "").lower() ) or ( isinstance(request, GetHistoryRequest) and isinstance(request.peer, InputPeerUser) and request.peer.user_id == 777000 ) or ( isinstance(request, ForwardMessagesRequest) and isinstance(request.from_peer, InputPeerUser) and request.from_peer.user_id == 777000 ) or ( isinstance(request, SearchRequest) and isinstance(request.peer, InputPeerUser) and request.peer.user_id == 777000 ) ): return True async def read(self, conn: Socket): data = await self.recv(conn) logging.debug("Got data from client %s", data) if data: msg_id = struct.unpack("