mirror of https://github.com/coddrago/Heroku
fix: proxy passing with fallbacks to cloudflare servers
commit
55820e0c0c
|
@ -15,7 +15,7 @@ RUN git clone https://github.com/coddrago/Heroku /Heroku
|
|||
# Создаем виртуальное окружение Python
|
||||
RUN python -m venv /venv
|
||||
# Устанавливаем зависимости проекта
|
||||
RUN /venv/bin/pip install --no-warn-script-location --no-cache-dir -r /Hikka/requirements.txt
|
||||
RUN /venv/bin/pip install --no-warn-script-location --no-cache-dir -r /Heroku/requirements.txt
|
||||
|
||||
# -------------------------------
|
||||
# Используем другой базовый образ для финального контейнера
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
import typing
|
||||
|
||||
class BaseTunnel:
|
||||
async def start(self):
|
||||
raise NotImplementedError("Subclasses must implement the 'start' method.")
|
||||
|
||||
async def stop(self):
|
||||
raise NotImplementedError("Subclasses must implement the 'stop' method.")
|
||||
|
||||
async def wait_for_url(self, timeout: float) -> typing.Optional[str]:
|
||||
raise NotImplementedError("Subclasses must implement the 'wait_for_url' method.")
|
|
@ -0,0 +1,62 @@
|
|||
import typing
|
||||
import logging
|
||||
import asyncio
|
||||
import contextvars
|
||||
import functools
|
||||
|
||||
from pycloudflared import try_cloudflare
|
||||
|
||||
from .base_tunnel import BaseTunnel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CloudflareTunnel(BaseTunnel):
|
||||
def __init__(
|
||||
self,
|
||||
port: int,
|
||||
verbose: bool = False,
|
||||
change_url_callback: typing.Callable[[str], None] = None,
|
||||
):
|
||||
self.port = port
|
||||
self.verbose = verbose
|
||||
self._change_url_callback = change_url_callback
|
||||
self._tunnel_url = None
|
||||
self._url_available = asyncio.Event()
|
||||
self._url_available.clear()
|
||||
|
||||
# to support python 3.8...
|
||||
async def to_thread(self, func, /, *args, **kwargs):
|
||||
loop = asyncio.get_running_loop()
|
||||
ctx = contextvars.copy_context()
|
||||
func_call = functools.partial(ctx.run, func, *args, **kwargs)
|
||||
return await loop.run_in_executor(None, func_call)
|
||||
|
||||
async def start(self):
|
||||
logger.debug(f"Attempting Cloudflare tunnel on port {self.port}...")
|
||||
|
||||
try:
|
||||
self._tunnel_url = (await self.to_thread(try_cloudflare, port=self.port, verbose=self.verbose)).tunnel
|
||||
logger.debug(f"Cloudflare tunnel established: {self._tunnel_url}")
|
||||
|
||||
if self._change_url_callback:
|
||||
self._change_url_callback(self._tunnel_url)
|
||||
|
||||
self._url_available.set()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to establish Cloudflare tunnel: {e}")
|
||||
raise
|
||||
|
||||
async def stop(self):
|
||||
logger.debug("Stopping Cloudflare tunnel...")
|
||||
try_cloudflare.terminate(self.port)
|
||||
|
||||
async def wait_for_url(self, timeout: float) -> typing.Optional[str]:
|
||||
try:
|
||||
await asyncio.wait_for(self._url_available.wait(), timeout)
|
||||
return self._tunnel_url
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timeout waiting for Cloudflare URL.")
|
||||
return None
|
|
@ -49,7 +49,7 @@ class Web(root.Web):
|
|||
self.ready = asyncio.Event()
|
||||
self.client_data = {}
|
||||
self.app = web.Application()
|
||||
self.proxypasser = proxypass.ProxyPasser()
|
||||
self.proxypasser = None
|
||||
aiohttp_jinja2.setup(
|
||||
self.app,
|
||||
filters={"getdoc": inspect.getdoc, "ascii": ascii},
|
||||
|
@ -81,10 +81,7 @@ class Web(root.Web):
|
|||
|
||||
if proxy_pass:
|
||||
with contextlib.suppress(Exception):
|
||||
url = await asyncio.wait_for(
|
||||
self.proxypasser.get_url(self.port),
|
||||
timeout=10,
|
||||
)
|
||||
url = await self.proxypasser.get_url(timeout=10)
|
||||
|
||||
if not url:
|
||||
ip = (
|
||||
|
@ -109,6 +106,7 @@ class Web(root.Web):
|
|||
await self.runner.setup()
|
||||
self.port = os.environ.get("PORT", port)
|
||||
site = web.TCPSite(self.runner, None, self.port)
|
||||
self.proxypasser = proxypass.ProxyPasser(port=self.port)
|
||||
await site.start()
|
||||
|
||||
await self.get_url(proxy_pass)
|
||||
|
|
|
@ -4,107 +4,59 @@
|
|||
# You can redistribute it and/or modify it under the terms of the GNU AGPLv3
|
||||
# 🔑 https://www.gnu.org/licenses/agpl-3.0.html
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
import typing
|
||||
from .ssh_tunnel import SSHTunnel
|
||||
from .cloudflare_tunnel import CloudflareTunnel
|
||||
|
||||
from .. import utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProxyPasser:
|
||||
def __init__(self, change_url_callback: callable = lambda _: None):
|
||||
self._tunnel_url = None
|
||||
self._sproc = None
|
||||
self._url_available = asyncio.Event()
|
||||
self._url_available.set()
|
||||
self._lock = asyncio.Lock()
|
||||
self._change_url_callback = change_url_callback
|
||||
|
||||
async def _read_stream(
|
||||
def __init__(
|
||||
self,
|
||||
callback: callable,
|
||||
stream: typing.BinaryIO,
|
||||
delay: int,
|
||||
) -> None:
|
||||
for getline in iter(stream.readline, ""):
|
||||
await asyncio.sleep(delay)
|
||||
data_chunk = await getline
|
||||
if await callback(data_chunk.decode("utf-8")):
|
||||
if not self._url_available.is_set():
|
||||
self._url_available.set()
|
||||
port: int,
|
||||
change_url_callback: typing.Callable[[str], None] = None,
|
||||
verbose: bool = False
|
||||
):
|
||||
self._tunnel_url = None
|
||||
self._port = port
|
||||
self._change_url_callback = change_url_callback
|
||||
self._verbose = verbose
|
||||
self._tunnels = [
|
||||
SSHTunnel(port=port, change_url_callback=self._on_url_change),
|
||||
CloudflareTunnel(port=port, verbose=verbose, change_url_callback=self._on_url_change)
|
||||
]
|
||||
|
||||
def kill(self):
|
||||
try:
|
||||
self._sproc.terminate()
|
||||
except Exception:
|
||||
logger.exception("Failed to kill proxy pass process")
|
||||
else:
|
||||
logger.debug("Proxy pass tunnel killed")
|
||||
|
||||
async def _process_stream(self, stdout_line: str) -> None:
|
||||
logger.debug(stdout_line)
|
||||
regex = r"tunneled.*?(https:\/\/.+)"
|
||||
def _on_url_change(self, url: str):
|
||||
self._tunnel_url = url
|
||||
if self._change_url_callback:
|
||||
self._change_url_callback(url)
|
||||
|
||||
def set_port(self, port: int):
|
||||
self.port = port
|
||||
|
||||
if re.search(regex, stdout_line):
|
||||
self._tunnel_url = re.search(regex, stdout_line)[1]
|
||||
self._change_url_callback(self._tunnel_url)
|
||||
logger.debug("Proxy pass tunneled: %s", self._tunnel_url)
|
||||
self._url_available.set()
|
||||
|
||||
async def get_url(self, port: int, no_retry: bool = False) -> typing.Optional[str]:
|
||||
async with self._lock:
|
||||
if self._tunnel_url:
|
||||
try:
|
||||
await asyncio.wait_for(self._sproc.wait(), timeout=0.05)
|
||||
except asyncio.TimeoutError:
|
||||
async def get_url(self, timeout: float = 25) -> typing.Optional[str]:
|
||||
|
||||
if "DOCKER" in os.environ:
|
||||
# We're in a Docker container, so we can't use ssh
|
||||
# Also, the concept of Docker is to keep
|
||||
# everything isolated, so we can't proxy-pass to
|
||||
# open web.
|
||||
return None
|
||||
|
||||
for tunnel in self._tunnels:
|
||||
try:
|
||||
await tunnel.start()
|
||||
self._tunnel_url = await tunnel.wait_for_url(timeout)
|
||||
if self._tunnel_url:
|
||||
return self._tunnel_url
|
||||
else:
|
||||
self.kill()
|
||||
logger.warning(f"{tunnel.__class__.__name__} failed to provide URL.")
|
||||
except Exception as e:
|
||||
logger.warning(f"{tunnel.__class__.__name__} failed: {e}")
|
||||
|
||||
if "DOCKER" in os.environ:
|
||||
# We're in a Docker container, so we can't use ssh
|
||||
# Also, the concept of Docker is to keep
|
||||
# everything isolated, so we can't proxy-pass to
|
||||
# open web.
|
||||
return None
|
||||
|
||||
logger.debug("Starting proxy pass shell for port %d", port)
|
||||
self._sproc = await asyncio.create_subprocess_shell(
|
||||
(
|
||||
"ssh -o StrictHostKeyChecking=no -R"
|
||||
f" 80:127.0.0.1:{port} nokey@localhost.run"
|
||||
),
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
utils.atexit(self.kill)
|
||||
|
||||
self._url_available = asyncio.Event()
|
||||
logger.debug("Starting proxy pass reader for port %d", port)
|
||||
asyncio.ensure_future(
|
||||
self._read_stream(
|
||||
self._process_stream,
|
||||
self._sproc.stdout,
|
||||
1,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._url_available.wait(), 15)
|
||||
except asyncio.TimeoutError:
|
||||
self.kill()
|
||||
self._tunnel_url = None
|
||||
if no_retry:
|
||||
return None
|
||||
|
||||
return await self.get_url(port, no_retry=True)
|
||||
|
||||
logger.debug("Proxy pass tunnel url to port %d: %s", port, self._tunnel_url)
|
||||
|
||||
return self._tunnel_url
|
||||
return None
|
|
@ -0,0 +1,127 @@
|
|||
import typing
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from .base_tunnel import BaseTunnel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SSHTunnel(BaseTunnel):
|
||||
def __init__(
|
||||
self,
|
||||
port: int,
|
||||
change_url_callback: typing.Callable[[str], None] = None,
|
||||
):
|
||||
#TODO: select ssh servers?
|
||||
self.ssh_commands = [
|
||||
(f"ssh -R 80:127.0.0.1:{port} serveo.net -T -n", r"https:\/\/(\S*serveo\.net\S*)"),
|
||||
(f"ssh -o StrictHostKeyChecking=no -R 80:127.0.0.1:{port} nokey@localhost.run", r"https:\/\/(\S*lhr\.life\S*)"),
|
||||
]
|
||||
self._change_url_callback = change_url_callback
|
||||
self._tunnel_url = None
|
||||
self._url_available = asyncio.Event()
|
||||
self._url_available.clear()
|
||||
self.process = None
|
||||
self.current_command_index = 0
|
||||
self._ssh_task = None
|
||||
self._all_commands_failed = False
|
||||
|
||||
async def start(self):
|
||||
self._ssh_task = asyncio.create_task(self._run_ssh_tunnel())
|
||||
|
||||
async def stop(self):
|
||||
if self._ssh_task:
|
||||
self._ssh_task.cancel()
|
||||
try:
|
||||
await self._ssh_task
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("SSH task was cancelled")
|
||||
|
||||
if self.process:
|
||||
logger.debug("Stopping SSH tunnel...")
|
||||
try:
|
||||
self.process.terminate()
|
||||
await asyncio.wait_for(self.process.wait(), timeout=5)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to terminate SSH process: {e}")
|
||||
finally:
|
||||
self.process = None
|
||||
|
||||
async def wait_for_url(self, timeout: float) -> typing.Optional[str]:
|
||||
if self._all_commands_failed:
|
||||
return None
|
||||
try:
|
||||
await asyncio.wait_for(self._url_available.wait(), timeout)
|
||||
return self._tunnel_url
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timeout waiting for tunnel URL.")
|
||||
return None
|
||||
|
||||
async def _run_ssh_tunnel(self):
|
||||
if not self.ssh_commands:
|
||||
logger.debug("SSH command list is empty")
|
||||
return
|
||||
try:
|
||||
while self.current_command_index < len(self.ssh_commands):
|
||||
ssh_command, regex_pattern = self.ssh_commands[self.current_command_index]
|
||||
logger.debug(f"Attempting SSH command: {ssh_command} with pattern: {regex_pattern}")
|
||||
try:
|
||||
command_list = ssh_command.split()
|
||||
self.process = await asyncio.create_subprocess_exec(
|
||||
*command_list,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
logger.debug(f"SSH tunnel started with PID: {self.process.pid}")
|
||||
asyncio.create_task(self._read_stream_and_process(self.process.stdout, regex_pattern))
|
||||
|
||||
await self.process.wait()
|
||||
|
||||
if self._tunnel_url is None:
|
||||
logger.warning("SSH tunnel disconnected without providing a URL.")
|
||||
else:
|
||||
logger.info("SSH tunnel disconnected, but URL was obtained. Exiting SSH Tunnel attempts.")
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to start SSH tunnel with command: {ssh_command}. Error: {e}"
|
||||
)
|
||||
|
||||
finally:
|
||||
if self.process:
|
||||
self.process = None
|
||||
if self._tunnel_url is None:
|
||||
logger.info("Reconnecting SSH tunnel after failure...")
|
||||
self.current_command_index += 1
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
logger.info("Exiting SSH Tunnel attempts after disconnect.")
|
||||
return
|
||||
self._all_commands_failed = True
|
||||
finally:
|
||||
if self._tunnel_url is None and self._all_commands_failed:
|
||||
logger.error("All SSH commands failed.")
|
||||
self._url_available.set()
|
||||
|
||||
async def _read_stream_and_process(self, stream, regex_pattern: str):
|
||||
try:
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if not line:
|
||||
break
|
||||
line_str = line.decode("utf-8").strip()
|
||||
await self._process_stream(line_str, regex_pattern)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error reading and processing stream: {e}")
|
||||
|
||||
async def _process_stream(self, stdout_line: str, regex_pattern: str):
|
||||
logger.debug(stdout_line)
|
||||
match = re.search(regex_pattern, stdout_line)
|
||||
if match:
|
||||
self._tunnel_url = match.group(0)
|
||||
if self._change_url_callback:
|
||||
self._change_url_callback(self._tunnel_url)
|
||||
self._url_available.set()
|
|
@ -1,3 +1,4 @@
|
|||
heroku-tl-new==3.2.5
|
||||
pycloudflared==0.2.0
|
||||
|
||||
# Python 3.9+
|
||||
|
|
Loading…
Reference in New Issue