# Heroku/private/listener.py
# (original listing metadata: 908 lines, ~30 KiB, Python)

import asyncio
import atexit
import contextlib
import copy
import dataclasses
import datetime
import functools
import io
import json
import logging
import os
import re
import signal
import struct
import sys
import time
from hashlib import sha256
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Dict, List, Optional, Union
from zlib import crc32
from hikkatl.crypto import AES, AuthKey
from hikkatl.errors import (
AuthKeyNotFound,
BadMessageError,
InvalidBufferError,
InvalidChecksumError,
SecurityError,
TypeNotFoundError,
)
from hikkatl.extensions.binaryreader import BinaryReader
from hikkatl.extensions.messagepacker import MessagePacker
from hikkatl.network.connection import ConnectionTcpFull
from hikkatl.network.mtprotosender import MTProtoSender
from hikkatl.network.mtprotostate import MTProtoState
from hikkatl.network.requeststate import RequestState
from hikkatl.sessions import SQLiteSession
from hikkatl.tl import TLRequest
from hikkatl.tl.core import GzipPacked, MessageContainer, TLMessage
from hikkatl.tl.functions import (
InitConnectionRequest,
InvokeAfterMsgRequest,
InvokeAfterMsgsRequest,
InvokeWithLayerRequest,
InvokeWithMessagesRangeRequest,
InvokeWithoutUpdatesRequest,
InvokeWithTakeoutRequest,
PingRequest,
)
from hikkatl.tl.functions.account import DeleteAccountRequest, UpdateProfileRequest
from hikkatl.tl.functions.auth import (
BindTempAuthKeyRequest,
CancelCodeRequest,
CheckRecoveryPasswordRequest,
ExportAuthorizationRequest,
ExportLoginTokenRequest,
ImportAuthorizationRequest,
LogOutRequest,
RecoverPasswordRequest,
RequestPasswordRecoveryRequest,
ResetAuthorizationsRequest,
ResetLoginEmailRequest,
SendCodeRequest,
)
from hikkatl.tl.functions.help import GetConfigRequest
from hikkatl.tl.functions.messages import (
ForwardMessagesRequest,
GetHistoryRequest,
SearchRequest,
)
from hikkatl.tl.types import (
InputPeerUser,
Message,
MsgsAck,
PeerUser,
UpdateNewMessage,
Updates,
UpdateShortMessage,
)
# Run from the directory containing this file so all relative paths
# (session files, ../config.json, unix sockets, hikka.log) resolve consistently.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# basicConfig installs a StreamHandler on the root logger and sets its level
# to DEBUG; that implicit handler is muted further below.
logging.basicConfig(level=logging.DEBUG)
# Console handler: INFO and above.
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
handler.setFormatter(
    logging.Formatter(
        fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        style="%",
    )
)
# File handler: full DEBUG log, rotated at 10 MiB, keeping one backup.
rotating_handler = RotatingFileHandler(
    filename="hikka.log",
    mode="a",
    maxBytes=10 * 1024 * 1024,
    backupCount=1,
    encoding="utf-8",
    delay=0,
)
rotating_handler.setLevel(logging.DEBUG)
rotating_handler.setFormatter(
    logging.Formatter(
        fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        style="%",
    )
)
# handlers[0] is the handler basicConfig added above; raise it to CRITICAL so
# effectively only the two explicitly configured handlers emit records.
logging.getLogger().handlers[0].setLevel(logging.CRITICAL)
logging.getLogger().addHandler(handler)
logging.getLogger().addHandler(rotating_handler)
class _OpaqueRequest(TLRequest):
    """A TLRequest that is already serialized.

    `_bytes()` returns the stored payload verbatim, which lets pre-encoded
    request bytes be nested inside another request (used by
    CustomMTProtoState to wrap data in InvokeAfterMsgRequest).
    """

    def __init__(self, data: bytes):
        # Pre-serialized TL bytes to emit unchanged.
        self.data = data

    def _bytes(self):
        return self.data
class CustomMTProtoState(MTProtoState):
    """MTProtoState whose message writer accepts a caller-supplied msg_id,
    so messages relayed from the proxied client can keep the ids the client
    originally assigned to them."""

    def write_data_as_message(
        self,
        buffer,
        data,
        content_related,
        *,
        after_id=None,
        msg_id=None,
    ):
        """Serialize *data* as one TL message into *buffer*; return its msg_id."""
        # Preserve the original call order: a fresh msg_id (when none was
        # supplied) must be drawn before the sequence number, since both
        # calls mutate internal counters.
        if not msg_id:
            msg_id = self._get_new_msg_id()
        seq_no = self._get_seq_no(content_related)
        if after_id is None:
            payload = GzipPacked.gzip_if_smaller(content_related, data)
        else:
            # Wrap the opaque bytes so the server only invokes them after
            # the message identified by `after_id`.
            wrapped = bytes(InvokeAfterMsgRequest(after_id, _OpaqueRequest(data)))
            payload = GzipPacked.gzip_if_smaller(content_related, wrapped)
        buffer.write(struct.pack("<qii", msg_id, seq_no, len(payload)))
        buffer.write(payload)
        return msg_id
class CustomMessagePacker(MessagePacker):
    """MessagePacker variant that delegates serialization to
    CustomMTProtoState.write_data_as_message, so queued RequestStates may
    carry a pre-assigned msg_id (relayed client messages keep their ids)."""

    async def get(self):
        """Pop queued requests and pack them into one outgoing payload.

        Returns a ``(batch, data)`` tuple, or ``(None, None)`` when nothing
        could be packed.
        """
        # Block until at least one request has been queued.
        if not self._deque:
            self._ready.clear()
            await self._ready.wait()
        buffer = io.BytesIO()
        batch = []
        size = 0
        # Pack as many queued messages as fit within the container limits.
        while self._deque and len(batch) <= MessageContainer.MAXIMUM_LENGTH:
            state = self._deque.popleft()
            size += len(state.data) + TLMessage.SIZE_OVERHEAD
            if size <= MessageContainer.MAXIMUM_SIZE:
                # Reuses state.msg_id when the client already assigned one.
                state.msg_id = self._state.write_data_as_message(
                    buffer,
                    state.data,
                    isinstance(state.request, TLRequest),
                    after_id=state.after.msg_id if state.after else None,
                    msg_id=state.msg_id,
                )
                batch.append(state)
                self._log.debug(
                    "Assigned msg_id = %d to %s (%x)",
                    state.msg_id,
                    state.request.__class__.__name__,
                    id(state.request),
                )
                continue
            if batch:
                # Message didn't fit alongside the current batch: requeue it
                # at the front and flush what we have.
                self._deque.appendleft(state)
                break
            # A single message alone already exceeds the size limit: fail it.
            self._log.warning(
                "Message payload for %s is too long (%d) and cannot be sent",
                state.request.__class__.__name__,
                len(state.data),
            )
            state.future.set_exception(ValueError("Request payload is too big"))
            size = 0
            continue
        if not batch:
            return None, None
        if len(batch) > 1:
            # Wrap multiple messages in a msg_container, itself written as a
            # non-content-related message.
            data = (
                struct.pack("<Ii", MessageContainer.CONSTRUCTOR_ID, len(batch))
                + buffer.getvalue()
            )
            buffer = io.BytesIO()
            container_id = self._state.write_data_as_message(
                buffer, data, content_related=False
            )
            for s in batch:
                s.container_id = container_id
        data = buffer.getvalue()
        return batch, data
class ClientFullPacketCodec:
    """Client-side "full" TCP transport codec.

    Frame layout: ``<int32 length><int32 seq=0><payload><crc32>`` where the
    length covers the whole frame and the CRC covers everything before it.
    The sequence number is always written as 0 here.
    """

    tag = None

    def encode_packet(self, data):
        # Total length = 8-byte header + payload + 4-byte trailing CRC.
        header = struct.pack("<ii", len(data) + 12, 0)
        framed = header + data
        return framed + struct.pack("<I", crc32(framed))

    async def read_packet(self, reader):
        head = await reader.readexactly(8)
        length, seq = struct.unpack("<ii", head)
        if length < 0 and seq < 0:
            # Both fields negative: a 4-byte transport-level error code follows.
            raise InvalidBufferError(await reader.readexactly(4))
        rest = await reader.readexactly(length - 8)
        payload, received_crc = rest[:-4], struct.unpack("<I", rest[-4:])[0]
        expected_crc = crc32(head + payload)
        if received_crc != expected_crc:
            raise InvalidChecksumError(received_crc, expected_crc)
        return payload
def get_config_key(key: str) -> Union[str, bool]:
    """Look up *key* in ../config.json.

    :param key: Key name in the config file.
    :return: The key's value, or ``False`` when the file or key is missing.
    """
    try:
        config = json.loads(Path("../config.json").read_text())
    except FileNotFoundError:
        return False
    return config.get(key, False)
# Monotonic timestamp captured at import time. Not referenced anywhere in
# this file — presumably intended for startup-duration measurement; TODO
# confirm whether anything else reads it.
start_ts = time.perf_counter()
class CustomMTProtoSender(MTProtoSender):
    """MTProtoSender that proxies traffic for a local client.

    Differences from the stock sender:

    * swaps in ``CustomMTProtoState`` / ``CustomMessagePacker`` so relayed
      messages keep the msg_ids the client assigned;
    * exposes ``external_append`` / ``external_extend`` so the TCP front-end
      can inject the client's requests into the send queue;
    * after normal processing, partially decrypts each received packet,
      censors texts coming from peer 777000 (Telegram's service-notification
      account, i.e. login codes) and forwards the raw packet to the client's
      unix socket.
    """

    def __init__(
        self,
        auth_key,
        *,
        loggers,
        retries=5,
        delay=1,
        auto_reconnect=True,
        connect_timeout=None,
        auth_key_callback=None,
        updates_queue=None,
        auto_reconnect_callback=None,
    ):
        super().__init__(
            auth_key,
            loggers=loggers,
            retries=retries,
            delay=delay,
            auto_reconnect=auto_reconnect,
            connect_timeout=connect_timeout,
            auth_key_callback=auth_key_callback,
            updates_queue=updates_queue,
            auto_reconnect_callback=auto_reconnect_callback,
        )
        # Replace the state/packer built by the base __init__ with the
        # customized versions defined above.
        self._state = CustomMTProtoState(self.auth_key, loggers=self._loggers)
        self._send_queue = CustomMessagePacker(self._state, loggers=self._loggers)

    def external_append(self, state):
        # Inject a single RequestState coming from the proxied client.
        self._send_queue.append(state)

    def external_extend(self, states):
        # Inject a batch of RequestStates coming from the proxied client.
        self._send_queue.extend(states)

    async def _send_loop(self):
        """Outgoing loop: pack, encrypt and send queued requests."""
        while self._user_connected and not self._reconnecting:
            if self._pending_ack:
                ack = RequestState(MsgsAck(list(self._pending_ack)))
                self._send_queue.append(ack)
                self._last_acks.append(ack)
                self._pending_ack.clear()
            self._log.debug("Waiting for messages to send...")
            batch, data = await self._send_queue.get()
            if not data:
                continue
            logging.debug("Sending data %s", data)
            self._log.debug(
                "Encrypting %d message(s) in %d bytes for sending",
                len(batch),
                len(data),
            )
            data = self._state.encrypt_message_data(data)
            # Track every TLRequest so its future can be resolved when the
            # matching response arrives; batch entries may be nested lists.
            for state in batch:
                if not isinstance(state, list):
                    if isinstance(state.request, TLRequest):
                        self._pending_state[state.msg_id] = state
                else:
                    for s in state:
                        if isinstance(s.request, TLRequest):
                            self._pending_state[s.msg_id] = s
            try:
                await self._connection.send(data)
            except IOError as e:
                self._log.info("Connection closed while sending data")
                self._start_reconnect(e)
                return
            self._log.debug("Encrypted messages put in a queue to be sent")

    def partial_decrypt(self, body):
        """Decrypt a raw server packet without parsing it into TL objects.

        Verifies the auth key id and msg_key, then returns the decrypted
        plaintext with the first 16 bytes (salt + session_id) stripped —
        i.e. starting at the inner msg_id. Unlike the full decrypt path, no
        msg_id/sequence validation is performed here.
        """
        if len(body) < 8:
            raise InvalidBufferError(body)
        key_id = struct.unpack("<Q", body[:8])[0]
        if key_id != self._state.auth_key.key_id:
            raise SecurityError("Server replied with an invalid auth key")
        msg_key = body[8:24]
        # False = server-to-client direction for the KDF.
        aes_key, aes_iv = self._state._calc_key(
            self._state.auth_key.key, msg_key, False
        )
        body = AES.decrypt_ige(body[24:], aes_key, aes_iv)
        our_key = sha256(self._state.auth_key.key[96 : 96 + 32] + body)
        if msg_key != our_key.digest()[8:24]:
            raise SecurityError("Received msg_key doesn't match with expected one")
        return body[16:]

    async def _handle_recv(self, body: bytes):
        """Process one received packet.

        Returns ``True`` when fully handled, ``False`` when this packet
        should be skipped, and ``-1`` when the receive loop must stop
        (a reconnect/disconnect was initiated).
        """
        try:
            message = self._state.decrypt_message_data(body)
            if message is None:
                return False
        except TypeNotFoundError as e:
            self._log.info(
                "Type %08x not found, remaining data %r",
                e.invalid_constructor_id,
                e.remaining,
            )
            return False
        except SecurityError as e:
            self._log.warning(
                "Security error while unpacking a received message: %s", e
            )
            return False
        except BufferError as e:
            if isinstance(e, InvalidBufferError) and e.code == 404:
                self._log.info(
                    "Server does not know about the current auth key; the session may"
                    " need to be recreated"
                )
                await self._disconnect(error=AuthKeyNotFound())
            else:
                self._log.warning("Invalid buffer %s", e)
                self._start_reconnect(e)
            return -1
        except Exception as e:
            self._log.exception("Unhandled error while decrypting data")
            self._start_reconnect(e)
            return -1
        # Normal handling first (acks, RPC results, updates queue).
        try:
            await self._process_message(message)
        except Exception:
            self._log.exception("Unhandled error while processing msgs")
        logging.debug("Got message from Telegram %s", message)
        # Censor-and-forward path: inspect the decoded message for texts from
        # user 777000 and forward the (possibly redacted) raw plaintext to
        # the proxied client's unix socket.
        try:
            msg = self.partial_decrypt(body)
            to_censor = ""
            if isinstance(message, TLMessage):
                # Case 1: Updates containing a new message from 777000.
                # NOTE: `:=` inside a comprehension binds `malicious` in this
                # method's scope, so it is usable after the condition.
                if isinstance(message.obj, Updates) and (
                    malicious := next(
                        (
                            update
                            for update in message.obj.updates
                            if isinstance(update, UpdateNewMessage)
                            and isinstance(update.message, Message)
                            and isinstance(update.message.peer_id, PeerUser)
                            and update.message.peer_id.user_id == 777000
                        ),
                        None,
                    )
                ):
                    to_censor = malicious.message.message
                # Case 2: a short message directly from 777000.
                elif (
                    isinstance(message.obj, UpdateShortMessage)
                    and message.obj.user_id == 777000
                ):
                    to_censor = message.obj.message
                # Case 3: a container holding an Updates with such a message.
                elif isinstance(message.obj, MessageContainer) and (
                    any((
                        isinstance(bigmsg.obj, Updates)
                        and (
                            malicious := next(
                                (
                                    update
                                    for update in bigmsg.obj.updates
                                    if isinstance(update, UpdateNewMessage)
                                    and isinstance(update.message, Message)
                                    and isinstance(update.message.peer_id, PeerUser)
                                    and update.message.peer_id.user_id == 777000
                                ),
                                None,
                            )
                        )
                        for bigmsg in message.obj.messages
                        if isinstance(bigmsg, TLMessage)
                    ))
                ):
                    to_censor = malicious.message.message
                # Case 4: a container holding a short message from 777000.
                elif isinstance(message.obj, MessageContainer) and (
                    malicious := next(
                        (
                            bigmsg
                            for bigmsg in message.obj.messages
                            if isinstance(bigmsg, TLMessage)
                            and isinstance(bigmsg.obj, UpdateShortMessage)
                            and bigmsg.obj.user_id == 777000
                        ),
                        None,
                    )
                ):
                    to_censor = malicious.message.message
            if to_censor:
                to_censor = to_censor.encode()
                original_msg = ""
                # b"\xa1\xcfr0" is the little-endian GzipPacked constructor
                # id (0x3072cfa1): unpack so the replacement sees plaintext.
                if msg[16:].startswith(b"\xa1\xcfr0"):
                    with BinaryReader(msg[16:]) as reader:
                        obj = reader.tgread_object()
                        assert isinstance(obj, GzipPacked)
                        original_msg = copy.copy(msg)
                        msg = obj.data
                logging.info("Censoring message %s in %s", to_censor, msg)
                msg = msg.replace(to_censor, (b"*" * len(to_censor)))
                if original_msg:
                    # Restore the original 16-byte msg_id/seq/len prefix.
                    msg = original_msg[:16] + msg
            if hasattr(self, "_socket"):
                logging.debug(
                    "Got data from socket, forwarding, %s",
                    ClientFullPacketCodec.encode_packet(None, msg),
                )
                self._socket.write(ClientFullPacketCodec.encode_packet(None, msg))
                await self._socket.drain()
            else:
                logging.debug("Got data with no socket")
        except Exception:
            logging.exception("Unhandled error while processing msgs")
        return True

    def set_socket(self, socket: asyncio.StreamWriter):
        # Writer of the unix-socket connection to forward packets into.
        self._socket = socket

    async def _recv_loop(self):
        """Incoming loop: read packets and dispatch to _handle_recv."""
        while self._user_connected and not self._reconnecting:
            self._log.debug("Receiving items from the network...")
            try:
                body = await self._connection.recv()
            except IOError as e:
                self._log.info("Connection closed while receiving data")
                self._start_reconnect(e)
                return
            except InvalidBufferError as e:
                if e.code == 429:
                    self._log.warning(
                        "Server indicated flood error at transport level: %s", e
                    )
                    await self._disconnect(error=e)
                else:
                    self._log.exception("Server sent invalid buffer")
                    self._start_reconnect(e)
                return
            except Exception as e:
                self._log.exception("Unhandled error while receiving data")
                self._start_reconnect(e)
                return
            res = await self._handle_recv(body)
            if res is False:
                continue
            elif res == -1:
                # Reconnect/disconnect in progress; stop this loop.
                return

    async def _handle_bad_notification(self, message):
        """Handle bad_msg_notification: correct local state and resend,
        or fail the affected requests with BadMessageError."""
        bad_msg = message.obj
        states = self._pop_states(bad_msg.bad_msg_id)
        self._log.debug("Handling bad msg %s", bad_msg)
        if bad_msg.error_code in (16, 17):
            # msg_id too low/high: local clock is off; resync the offset
            # from the server's msg_id.
            to = self._state.update_time_offset(correct_msg_id=message.msg_id)
            self._log.info("System clock is wrong, set time offset to %ds", to)
        elif bad_msg.error_code == 32:
            # msg_seqno too low.
            self._state._sequence += 1
        elif bad_msg.error_code == 33:
            # msg_seqno too high.
            self._state._sequence -= 1
        else:
            for state in states:
                state.future.set_exception(
                    BadMessageError(state.request, bad_msg.error_code)
                )
            return
        self._send_queue.extend(states)
        self._log.debug("%d messages will be resent due to bad msg", len(states))
class SessionStorage:
    """Loads hikka session files, produces decoy "safe" copies for the
    sandboxed client, and owns one MTProto sender per real session."""

    def __init__(self):
        self._sessions: List[SQLiteSession] = []
        self._safe_sessions: List[SQLiteSession] = []
        self._clients: Dict[int, MTProtoSender] = {}

    async def pop_client(self, client_id: int):
        """Disconnect and forget the sender for *client_id*."""
        await self._clients[client_id].disconnect()
        self._clients.pop(client_id)

    @property
    def client_ids(self) -> List[int]:
        return list(self._clients.keys())

    @property
    def clients(self) -> Dict[int, MTProtoSender]:
        return self._clients

    @property
    def sessions(self) -> List[SQLiteSession]:
        return self._safe_sessions

    def read_sessions(self):
        """Copy each hikka-*.session, poison the copy's DC/auth key with a
        recognizable dummy value, and move the copies to the parent dir
        (where the sandboxed client will pick them up)."""
        logging.debug("Reading sessions...")
        session_files = list(
            filter(
                lambda f: f.startswith("hikka-") and f.endswith(".session"),
                os.listdir("."),
            )
        )
        for session in session_files:
            Path("safe-" + session).write_bytes(Path(session).read_bytes())
        self._sessions = [SQLiteSession(session) for session in session_files]
        self._safe_sessions = [
            SQLiteSession("safe-" + session) for session in session_files
        ]
        for session in self._safe_sessions:
            logging.debug("Processing session %s...", session.filename)
            # Point the decoy at a bogus DC and a fixed 256-byte marker key
            # (checked against in main() via its "Where are you at?" prefix).
            session.set_dc(0, "0.0.0.0", 11111)
            session.auth_key = AuthKey(
                data=(
                    "Where are you at?\nWhere have you"
                    " been?\n問いかけに答えはなく\nWhere are we headed?\nWhat did you"
                    " mean?\n追いかけても 遅く 遠く\nA bird, a butterfly and my red"
                    " scarf\nDon't make a mess of memories\nJust let me heal your"
                    " scars\nThe wall, the owl, forgotten wharf\n時が止まることもなく"
                ).encode()
                + b"\x00" * 13
            )
            session.save()
            session.close()

        def rename(filename: str) -> str:
            # "safe-hikka-123.session" -> "hikka-123.session" (keeps the id).
            session_id = re.findall(r"\d+", filename)[-1]
            return f"hikka-{session_id}.session"

        for session in self._safe_sessions:
            os.rename(
                os.path.abspath(session.filename),
                os.path.abspath(os.path.join("../", rename(session.filename))),
            )

    async def init_clients(self):
        """Connect one CustomMTProtoSender per real session."""
        for session in self._sessions:

            class _Loggers(dict):
                # Route telethon.* logger names under the "hikkatl" logger.
                def __missing__(self, key):
                    if key.startswith("telethon."):
                        key = key.split(".", maxsplit=1)[1]
                    return logging.getLogger("hikkatl").getChild(key)

            def _auth_key_callback(auth_key, session=session):
                # Bind the loop variable as a default argument: the previous
                # code referenced the nonexistent ``self.session`` (raising
                # AttributeError whenever the server issued a new auth key)
                # and would also have late-bound the loop variable.
                session.auth_key = auth_key
                session.save()

            _updates_queue = asyncio.Queue()
            client = CustomMTProtoSender(
                session.auth_key,
                loggers=_Loggers(),
                retries=5,
                delay=1,
                auto_reconnect=True,
                connect_timeout=10,
                auth_key_callback=_auth_key_callback,
                updates_queue=_updates_queue,
                auto_reconnect_callback=None,
            )
            await client.connect(
                ConnectionTcpFull(
                    session.server_address,
                    session.port,
                    session.dc_id,
                    loggers=_Loggers(),
                )
            )
            # Client id is the numeric suffix of the session filename.
            client.id = int(re.findall(r"\d+", session.filename)[-1])
            logging.debug("Client %s connected", client.id)
            self._clients[client.id] = client
@dataclasses.dataclass
class Socket:
    # One accepted unix-socket connection paired with the client it proxies.
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    client_id: int  # numeric suffix of the matching hikka-<id>.session
class TCP:
    """Unix-socket front-end of the proxy.

    Starts one unix server per known client, reads MTProto frames from the
    sandboxed client, filters out dangerous requests (account takeover /
    777000 snooping) and forwards the rest into the matching sender's queue.
    """

    def __init__(self, session_storage: SessionStorage):
        self._sockets = {}
        self._socket_files = []
        self._session_storage = session_storage
        # Remove stale socket files from a previous run before binding.
        self.gc(init=True)
        for client_id in self._session_storage.client_ids:
            filename = os.path.abspath(
                os.path.join("../", f"hikka-{client_id}-proxy.sock")
            )
            asyncio.ensure_future(
                asyncio.start_unix_server(
                    functools.partial(
                        self._process_conn,
                        client_id=client_id,
                        filename=filename,
                    ),
                    filename,
                )
            )

    def _process_conn(self, reader, writer, client_id, filename):
        """Connection callback: wire the socket to its sender and start
        the per-connection read loop."""
        self._session_storage.clients[client_id].set_socket(writer)
        self._socket_files.append(filename)
        logging.info("Socket %s connected", filename)
        socket = Socket(reader, writer, client_id)
        self._sockets[client_id] = socket
        asyncio.ensure_future(read_loop(socket))

    @staticmethod
    async def recv(sock: Socket):
        """Read one full-transport frame from the client."""
        return await ClientFullPacketCodec.read_packet(None, sock.reader)

    @staticmethod
    async def send(sock: Socket, data: bytes):
        """Write one full-transport frame to the client."""
        sock.writer.write(ClientFullPacketCodec.encode_packet(None, data))
        await sock.writer.drain()

    def _find_real_request(self, request: TLRequest) -> TLRequest:
        """Unwrap Invoke* wrappers down to the innermost request."""
        if isinstance(
            request,
            (
                InvokeWithLayerRequest,
                InvokeAfterMsgRequest,
                InvokeAfterMsgsRequest,
                InvokeWithMessagesRangeRequest,
                InvokeWithTakeoutRequest,
                InvokeWithoutUpdatesRequest,
            ),
        ):
            return self._find_real_request(request.query)
        return request

    def _malicious(self, request: TLRequest) -> bool:
        """Whether *request* must not be relayed: account-destructive auth
        calls, or attempts to read/forward Telegram service messages
        (peer 777000, the login-code sender)."""
        request = self._find_real_request(request)
        if (
            isinstance(
                request,
                (
                    DeleteAccountRequest,
                    BindTempAuthKeyRequest,
                    CancelCodeRequest,
                    CheckRecoveryPasswordRequest,
                    ExportAuthorizationRequest,
                    ExportLoginTokenRequest,
                    ImportAuthorizationRequest,
                    LogOutRequest,
                    RecoverPasswordRequest,
                    RequestPasswordRecoveryRequest,
                    ResetAuthorizationsRequest,
                    ResetLoginEmailRequest,
                    SendCodeRequest,
                ),
            )
            or (
                isinstance(request, UpdateProfileRequest)
                and "savedmessages"
                in (request.first_name + request.last_name).replace(" ", "").lower()
            )
            or (
                isinstance(request, GetHistoryRequest)
                and isinstance(request.peer, InputPeerUser)
                and request.peer.user_id == 777000
            )
            or (
                isinstance(request, ForwardMessagesRequest)
                and isinstance(request.from_peer, InputPeerUser)
                and request.from_peer.user_id == 777000
            )
            or (
                isinstance(request, SearchRequest)
                and isinstance(request.peer, InputPeerUser)
                and request.peer.user_id == 777000
            )
        ):
            return True
        # Fix: previously the function fell off the end and returned an
        # implicit None; return an explicit bool as the annotation promises.
        return False

    async def read(self, conn: Socket):
        """Read one frame from the client, sanitize it, and enqueue it on
        the corresponding Telegram sender."""
        data = await self.recv(conn)
        logging.debug("Got data from client %s", data)
        if data:
            # Frame layout: <int64 msg_id><...><payload at offset 16>.
            msg_id = struct.unpack("<q", data[:8])[0]
            with BinaryReader(data[16:]) as reader:
                tgobject = reader.tgread_object()
            logging.debug("Got object %s", tgobject)
            # Acks are handled by the sender itself; drop them here.
            if isinstance(tgobject, MsgsAck):
                return
            while isinstance(tgobject, GzipPacked):
                with BinaryReader(tgobject.data) as reader:
                    tgobject = reader.tgread_object()
            logging.debug("Modified object %s", tgobject)
            # The real connection was already initialized by this proxy;
            # answer the client's init with a plain GetConfig instead.
            if isinstance(tgobject, InvokeWithLayerRequest) and isinstance(
                tgobject.query, InitConnectionRequest
            ):
                tgobject = GetConfigRequest()
            if isinstance(tgobject, MessageContainer):
                states = []
                for message in tgobject.messages:
                    state = RequestState(message.obj)
                    if self._malicious(message.obj):
                        logging.critical(
                            "Suspicious request detected, substituting with ping"
                        )
                        state = RequestState(PingRequest(ping_id=123456789))
                    # Keep the client's msg_id so it can match the response.
                    state.msg_id = message.msg_id
                    states.append(state)
                self._session_storage.clients[conn.client_id].external_extend(states)
            else:
                state = RequestState(tgobject)
                if self._malicious(tgobject):
                    logging.critical(
                        "Suspicious request detected, substituting with ping"
                    )
                    state = RequestState(PingRequest(ping_id=123456789))
                state.msg_id = msg_id
                self._session_storage.clients[conn.client_id].external_append(state)

    def gc(self, init: bool, pop_client: Optional[int] = None):
        """Best-effort cleanup of sockets, socket files and (when not at
        startup) session files, for one client or for all of them."""
        for client_id in (
            [pop_client] if pop_client else self._session_storage.client_ids
        ):
            with contextlib.suppress(Exception):
                # Fix: Socket is a plain dataclass with no close() method, so
                # the old ``.close()`` call always raised AttributeError and
                # was silently suppressed — connections were never closed.
                # Close the underlying StreamWriter instead.
                self._sockets[client_id].writer.close()
            with contextlib.suppress(Exception):
                os.remove(
                    os.path.abspath(
                        os.path.join("../", f"hikka-{client_id}-proxy.sock")
                    )
                )
            if not init:
                with contextlib.suppress(Exception):
                    os.remove(
                        os.path.abspath(
                            os.path.join("../", f"hikka-{client_id}.session")
                        )
                    )
            if not init:
                with contextlib.suppress(Exception):
                    os.remove(
                        os.path.abspath(
                            os.path.join("../", f"hikka-{client_id}.session-journal")
                        )
                    )
# Module-level singletons shared between read_loop(), main() and the SIGINT
# handler; populated inside main().
tcp, session_storage, shell = None, None, None
async def read_loop(sock: Socket):
    """Forever pump frames from one client socket into ``tcp.read``.

    On client disconnect, tears down the matching sender, kills the sandbox
    subprocess and terminates the whole process (exit code 1) so the
    supervisor restarts it. Other errors are logged and the loop continues.
    """
    global tcp, session_storage, shell
    while True:
        try:
            await tcp.read(sock)
        except (asyncio.IncompleteReadError, ConnectionResetError):
            logging.info("Client disconnected, restarting...")
            await session_storage.pop_client(sock.client_id)
            if shell:
                shell.kill()
                logging.info("Waiting for sandbox to exit...")
                await shell.wait()
                logging.info("Sandbox exited")
            # sys.exit instead of the site-provided exit() builtin: same
            # SystemExit semantics, but always available (e.g. under -S).
            sys.exit(1)
        except Exception as e:
            logging.exception(e)
async def main():
    """Bootstrap: adopt real session files from the parent directory,
    connect the Telegram senders, start the unix-socket proxy and launch
    the sandboxed client."""
    global tcp, session_storage, shell
    for name in os.listdir("../"):
        if name.startswith("hikka-") and name.endswith(".session"):
            path = os.path.abspath(os.path.join("../", name))
            session = SQLiteSession(path)
            # Decoy sessions are marked by the fixed auth-key prefix written
            # in SessionStorage.read_sessions; keep real ones, drop decoys.
            # A missing auth key counts as a real (not-yet-poisoned) session.
            if session.auth_key is None or not session.auth_key.key.startswith(
                b"Where are you at?"
            ):
                session.save()
                session.close()
                # session.filename is an absolute path here, so the old
                # os.path.join("../", session.filename) collapsed to the
                # source path itself and the rename was a no-op. Join on the
                # basename to actually move the file into this directory.
                base = os.path.basename(session.filename)
                os.rename(
                    os.path.abspath(os.path.join("../", base)),
                    os.path.abspath(os.path.join("./", base)),
                )
            else:
                session.close()
                os.remove(
                    os.path.abspath(
                        os.path.join("../", os.path.basename(session.filename))
                    )
                )
    session_storage = SessionStorage()
    session_storage.read_sessions()
    await session_storage.init_clients()
    tcp = TCP(session_storage)
    logging.info("Startup delay...")
    await asyncio.sleep(3)
    logging.info("Starting client...")
    # create_subprocess_shell always runs through a shell; the redundant
    # shell=True kwarg the old code forwarded has been dropped.
    shell = await asyncio.create_subprocess_shell(
        "cd ../ && ./_start_sandbox.sh",
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    # Keep the process alive; all real work happens in background tasks.
    while True:
        await asyncio.sleep(3600)
async def integrity_checker():
    # Placeholder watchdog: currently just idles. It is never scheduled
    # anywhere in this file — TODO confirm whether a caller exists elsewhere.
    while True:
        await asyncio.sleep(5)
def shutdown_handler(sig, frame):
    """SIGINT handler: stop the sandbox, clean up sockets/sessions, exit 0.

    ``frame`` is unused but required by the ``signal.signal`` callback
    signature.
    """
    print("Bye")
    if shell is not None:
        # The sandbox may already be gone; ignore a missing process.
        with contextlib.suppress(ProcessLookupError):
            os.kill(shell.pid, signal.SIGINT)
    if tcp is not None:
        tcp.gc(init=False)
    sys.exit(0)
if __name__ == "__main__":
    # Install the SIGINT handler before the loop starts so Ctrl-C triggers
    # the cleanup path in shutdown_handler.
    signal.signal(signal.SIGINT, shutdown_handler)
    # asyncio.run() replaces the deprecated
    # get_event_loop()/run_until_complete() pair (deprecated since 3.10 for
    # this no-running-loop case) and also closes the loop on exit.
    asyncio.run(main())