aboutsummaryrefslogtreecommitdiff
path: root/proxy.py
diff options
context:
space:
mode:
Diffstat (limited to 'proxy.py')
-rwxr-xr-xproxy.py31
1 files changed, 20 insertions, 11 deletions
diff --git a/proxy.py b/proxy.py
index 854446b..9bdb76f 100755
--- a/proxy.py
+++ b/proxy.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import argparse
import logging
+import os
import pathlib
import ssl
import socketserver
@@ -8,7 +9,6 @@ import urllib.request
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
-domains_path = None
class Handler(socketserver.BaseRequestHandler):
def handle(self):
@@ -30,23 +30,32 @@ class Handler(socketserver.BaseRequestHandler):
sock.sendall(response.encode("UTF8"))
-def sni_callback(socket: ssl.SSLSocket, server_name, _context):
- context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
- domain_path = domains_path / server_name
- context.load_cert_chain(domain_path / "pubcert.pem", domain_path / "privkey.pem")
- socket.context = context
-
-
def main():
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(message)s")
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=1965)
- parser.add_argument("domains_path", type=pathlib.Path)
+
+ group = parser.add_mutually_exclusive_group()
+ group.add_argument("--certificates-from-path", type=pathlib.Path)
+ group.add_argument("--certificates-from-credential")
args = parser.parse_args()
- global domains_path
- domains_path = args.domains_path
+ if args.certificates_from_path:
+ def domain_to_path(server_name):
+ domain_path = args.certificates_from_path / server_name
+ return (domain_path / "pubcert.pem" , domain_path / "privkey.pem")
+
+ if args.certificates_from_credential:
+ def domain_to_path(server_name):
+ credentials_directory = pathlib.Path(os.environ["CREDENTIALS_DIRECTORY"])
+ return (credentials_directory / f"{args.certificates_from_credential}_{server_name}_pubcert.pem", credentials_directory / f"{args.certificates_from_credential}_{server_name}_privkey.pem")
+
+ def sni_callback(socket: ssl.SSLSocket, server_name, _context):
+ context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+ certfile, keyfile = domain_to_path(server_name)
+ context.load_cert_chain(certfile, keyfile)
+ socket.context = context
context.sni_callback = sni_callback