# Heroku/heroku/secure/customtl.py
#
# 90 lines
# 2.8 KiB
# Python

import asyncio
import logging
import struct
import time
from herokutl.errors import InvalidBufferError, SecurityError
from herokutl.extensions import BinaryReader
from herokutl.network.connection import ConnectionTcpFull as ConnectionTcpFullOrig
from herokutl.network.mtprotostate import MTProtoState as MTProtoStateOrig
from herokutl.tl.core import TLMessage
from herokutl.tl.types import BadMsgNotification, BadServerSalt
# Acceptance window, in seconds, for the timestamp carried in the upper
# 32 bits of an incoming msg_id, measured against the offset-adjusted
# local clock: messages more than MSG_TOO_NEW_DELTA ahead of it or more
# than MSG_TOO_OLD_DELTA behind it are ignored by decrypt_message_data.
MSG_TOO_NEW_DELTA = 30
MSG_TOO_OLD_DELTA = 300
class MTProtoState(MTProtoStateOrig):
    """MTProto state variant that exchanges message data in plaintext.

    Both overrides bypass the cipher step: outgoing data is handed back
    untouched and incoming bodies are parsed directly. The msg_id sanity
    checks (parity, duplicates, timestamp window) are still applied to
    incoming messages.

    NOTE(review): presumably intended for a trusted local transport such as
    the Unix-socket connection defined in this module - confirm with callers.
    """

    def encrypt_message_data(self, data):
        """Return ``data`` unchanged; no encryption is performed."""
        logging.debug("Skipping encryption...")
        return data

    def decrypt_message_data(self, body):
        """Parse an unencrypted incoming ``body`` into a ``TLMessage``.

        The body is read as: msg_id (long), sequence number (int), one
        further int that is discarded, then the serialized TL object.

        Returns:
            The parsed ``TLMessage``, or ``None`` when the message is
            ignored (already seen, too old, or too far in the future).

        Raises:
            InvalidBufferError: if ``body`` is shorter than 8 bytes.
            SecurityError: if the received msg_id is even.
        """
        # Sample the clock before parsing so the age check below reflects
        # the moment the body was handed to us.
        now = time.time() + self.time_offset

        if len(body) < 8:
            raise InvalidBufferError(body)

        logging.debug("Got raw data: %s", body)
        stream = BinaryReader(body)

        msg_id = stream.read_long()
        if msg_id % 2 == 0:
            raise SecurityError("Server sent an even msg_id")

        already_seen = (
            msg_id <= self._highest_remote_id
            and msg_id in self._recent_remote_ids
        )
        if already_seen:
            self._log.warning(
                "Server resent the older message %d, ignoring", msg_id
            )
            self._count_ignored()
            return None

        seq_no = stream.read_int()
        # Skipped field; presumably the message length - TODO confirm
        # against the wire format.
        stream.read_int()
        payload = stream.tgread_object()

        # These two notification types are exempt from the timestamp window.
        time_check_exempt = (
            BadServerSalt.CONSTRUCTOR_ID,
            BadMsgNotification.CONSTRUCTOR_ID,
        )
        if payload.CONSTRUCTOR_ID not in time_check_exempt:
            # The upper 32 bits of msg_id are compared to the local clock.
            age = now - (msg_id >> 32)

            if age > MSG_TOO_OLD_DELTA:
                self._log.warning(
                    "Server sent a very old message with ID %d, ignoring", msg_id
                )
                self._count_ignored()
                return None

            if age < -MSG_TOO_NEW_DELTA:
                self._log.warning(
                    "Server sent a very new message with ID %d, ignoring", msg_id
                )
                self._count_ignored()
                return None

        # Accepted: remember the id for duplicate detection and reset the
        # ignored-message counter.
        self._recent_remote_ids.append(msg_id)
        self._highest_remote_id = msg_id
        self._ignore_count = 0

        return TLMessage(msg_id, seq_no, payload)
class ConnectionTcpFull(ConnectionTcpFullOrig):
    """Connection that opens its streams over a Unix domain socket.

    ``set_unix_socket`` must be called before connecting, since ``_connect``
    reads the path it stores.
    """

    def set_unix_socket(self, unix_socket_path):
        """Remember the filesystem path of the Unix socket to connect to."""
        self._unix_socket_path = unix_socket_path

    async def _connect(self, timeout=None, ssl=None):
        """Open the Unix-socket streams and initialise the packet codec.

        ``timeout`` bounds the connection attempt. ``ssl`` is accepted for
        signature compatibility but ignored: the socket is always opened
        with ``ssl=None``.
        """
        opening = asyncio.open_unix_connection(
            path=self._unix_socket_path,
            ssl=None,
        )
        streams = await asyncio.wait_for(opening, timeout=timeout)
        self._reader, self._writer = streams

        self._codec = self.packet_codec(self)
        self._init_conn()
        await self._writer.drain()