231 lines
7.0 KiB
Python
231 lines
7.0 KiB
Python
import argparse
|
|
import contextlib
|
|
import ipaddress
|
|
import json
|
|
import os
|
|
import signal
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import urllib.request
|
|
from collections.abc import Iterator
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from tempfile import TemporaryDirectory
|
|
from typing import Any
|
|
|
|
|
|
class ClanError(Exception):
    """Failure while starting or using the temporary ZeroTier controller."""
|
|
|
|
|
|
def try_bind_port(port: int) -> bool:
    """Return True if *port* can be bound on 127.0.0.1 for both TCP and UDP.

    Both probe sockets are closed before returning, so a True result only means
    the port was free at the moment of the check.
    """
    # Creating both sockets in the ``with`` header guarantees the TCP socket is
    # closed even if creating the UDP socket fails.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp, socket.socket(
        socket.AF_INET, socket.SOCK_DGRAM
    ) as udp:
        try:
            tcp.bind(("127.0.0.1", port))
            udp.bind(("127.0.0.1", port))
        except OSError:
            # Typically EADDRINUSE: something already holds the port.
            return False
        return True
|
|
|
|
|
|
def try_connect_port(port: int) -> bool:
    """Return True if a TCP connection to 127.0.0.1:*port* succeeds.

    Used as a readiness probe: the connection is closed immediately.
    """
    # ``with`` closes the socket even if connect_ex raises (it only turns
    # errors from the C-level connect() into return codes).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        return sock.connect_ex(("127.0.0.1", port)) == 0
|
|
|
|
|
|
def find_free_port() -> int | None:
|
|
"""Find an unused localhost port from 1024-65535 and return it."""
|
|
with contextlib.closing(socket.socket(type=socket.SOCK_STREAM)) as sock:
|
|
sock.bind(("127.0.0.1", 0))
|
|
return sock.getsockname()[1]
|
|
|
|
|
|
class Identity:
    """A ZeroTier identity loaded from ``identity.public`` and ``identity.secret``.

    Both files are read in the constructor, so the object keeps their contents
    even after the directory they came from is removed.
    """

    def __init__(self, path: Path) -> None:
        # Raw file contents, stored verbatim. node_id() parses the first
        # ":"-separated field of the public part.
        self.public = (path / "identity.public").read_text()
        self.private = (path / "identity.secret").read_text()

    def node_id(self) -> str:
        """Return the node id: the text before the first ':' of the public identity.

        Raises:
            ValueError: if that field is not exactly 10 characters long.
        """
        nid = self.public.split(":")[0]
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if len(nid) != 10:
            raise ValueError(
                f"node_id must be 10 characters long, got {len(nid)}: {nid}"
            )
        return nid
|
|
|
|
|
|
class ZerotierController:
|
|
def __init__(self, port: int, home: Path) -> None:
|
|
self.port = port
|
|
self.home = home
|
|
self.authtoken = (home / "authtoken.secret").read_text()
|
|
self.identity = Identity(home)
|
|
|
|
def _http_request(
|
|
self,
|
|
path: str,
|
|
method: str = "GET",
|
|
headers: dict[str, str] = {},
|
|
data: dict[str, Any] | None = None,
|
|
) -> dict[str, Any]:
|
|
body = None
|
|
headers = headers.copy()
|
|
if data is not None:
|
|
body = json.dumps(data).encode("ascii")
|
|
headers["Content-Type"] = "application/json"
|
|
headers["X-ZT1-AUTH"] = self.authtoken
|
|
url = f"http://127.0.0.1:{self.port}{path}"
|
|
req = urllib.request.Request(url, headers=headers, method=method, data=body)
|
|
resp = urllib.request.urlopen(req)
|
|
return json.load(resp)
|
|
|
|
def status(self) -> dict[str, Any]:
|
|
return self._http_request("/status")
|
|
|
|
def create_network(self, data: dict[str, Any] = {}) -> dict[str, Any]:
|
|
return self._http_request(
|
|
f"/controller/network/{self.identity.node_id()}______",
|
|
method="POST",
|
|
data=data,
|
|
)
|
|
|
|
def get_network(self, network_id: str) -> dict[str, Any]:
|
|
return self._http_request(f"/controller/network/{network_id}")
|
|
|
|
|
|
@contextmanager
def zerotier_controller() -> Iterator[ZerotierController]:
    """Run a throw-away ``zerotier-one`` under ``fakeroot`` and yield a client for it.

    The daemon gets a fresh home directory inside a temporary directory and
    listens on a free localhost port. On exit its whole process group is sent
    SIGKILL and the temporary directory is removed.

    Raises:
        ClanError: if no free port is found, or if ``zerotier-one`` exits
            before it accepts connections.
    """
    # This check could be racy but it's unlikely in practice
    controller_port = find_free_port()
    if controller_port is None:
        raise ClanError("cannot find a free port for zerotier controller")

    with TemporaryDirectory() as d:
        tempdir = Path(d)
        home = tempdir / "zerotier-one"
        home.mkdir()
        cmd = [
            "fakeroot",
            "--",
            "zerotier-one",
            f"-p{controller_port}",
            str(home),
        ]
        # start_new_session=True runs setsid() in the child (the thread-safe
        # replacement for preexec_fn=os.setsid). The child becomes the leader of
        # a new process group whose id equals its pid, so the group kill below
        # reaches it and any descendants that stay in that group.
        with subprocess.Popen(cmd, start_new_session=True) as p:
            process_group = p.pid
            try:
                print(
                    f"wait for controller to be started on 127.0.0.1:{controller_port}...",
                )
                while not try_connect_port(controller_port):
                    status = p.poll()
                    if status is not None:
                        raise ClanError(
                            f"zerotier-one has been terminated unexpected with {status}"
                        )
                    time.sleep(0.1)
                print()

                yield ZerotierController(controller_port, home)
            finally:
                # If the daemon already died (poll() reaped it), the group may be
                # empty. killpg would then raise ProcessLookupError and mask the
                # ClanError that callers retry on.
                with contextlib.suppress(ProcessLookupError):
                    os.killpg(process_group, signal.SIGKILL)
|
|
|
|
|
|
@dataclass
class NetworkController:
    """A newly created ZeroTier network together with the controller identity that created it."""

    # Network id, taken from the "nwid" field of the controller's create-network reply.
    networkid: str
    # Identity of the temporary controller that created the network.
    identity: Identity
|
|
|
|
|
|
# TODO: allow merging more network configuration here
def create_network_controller(*, attempts: int = 10) -> NetworkController:
    """Create a new ZeroTier network on a temporary controller.

    Starts a throw-away controller via ``zerotier_controller``, asks it to
    create a network, and returns the network id with the controller's
    identity. Startup is retried on ``ClanError``; the port chosen by
    ``find_free_port`` can be taken before the daemon binds it.

    Args:
        attempts: maximum number of tries before giving up.

    Raises:
        ClanError: if every attempt failed. The last failure is chained as
            ``__cause__``.
    """
    last_error: ClanError | None = None
    for _ in range(attempts):
        try:
            with zerotier_controller() as controller:
                network = controller.create_network()
                return NetworkController(network["nwid"], controller.identity)
        except ClanError as err:  # probably failed to allocate port, so retry
            last_error = err
            print(f"failed to create network, retrying...: {err}", file=sys.stderr)
    raise ClanError(
        f"failed to create network after {attempts} attempts"
    ) from last_error
|
|
|
|
|
|
def create_identity() -> Identity:
    """Generate a fresh ZeroTier identity with ``zerotier-idtool``.

    The key files are written to a temporary directory that is deleted on
    return. ``Identity`` reads both files in its constructor, so the returned
    object still holds their contents.

    Raises:
        subprocess.CalledProcessError: if ``zerotier-idtool`` exits non-zero.
    """
    with TemporaryDirectory() as d:
        tmpdir = Path(d)
        private = tmpdir / "identity.secret"
        public = tmpdir / "identity.public"
        # check=True makes a failing tool raise here, with its exit status,
        # instead of later as a confusing FileNotFoundError from Identity().
        subprocess.run(["zerotier-idtool", "generate", private, public], check=True)
        return Identity(tmpdir)
|
|
|
|
|
|
def compute_zerotier_ip(network_id: str, identity: Identity) -> ipaddress.IPv6Address:
    """Compute a member's unique-local (``fd00::/8``) IPv6 address in a ZeroTier network.

    The 16 address bytes are laid out as:
    ``0xfd`` | 8-byte network id | ``0x99 0x93`` | low 5 bytes (40 bits) of the node id.

    Args:
        network_id: the network id as 16 hexadecimal digits.
        identity: member identity; only ``identity.node_id()`` is used.

    Raises:
        ValueError: if ``network_id`` is not 16 hex digits, or the node id is not hex.
    """
    hex_digits = set("0123456789abcdefABCDEF")
    # Explicit check (not ``assert``) so it survives ``python -O``. Requiring
    # plain hex digits also rejects forms int(..., 16) would accept, such as
    # "0x..." prefixes, signs, underscores and surrounding whitespace.
    if len(network_id) != 16 or not set(network_id) <= hex_digits:
        raise ValueError(f"network_id must be 16 hex digits, got {network_id!r}")
    nwid = int(network_id, 16)
    node_id = int(identity.node_id(), 16)
    # Masking to 40 bits reproduces exactly the low five bytes of the node id.
    node_bytes = (node_id & ((1 << 40) - 1)).to_bytes(5, "big")
    addr = b"\xfd" + nwid.to_bytes(8, "big") + b"\x99\x93" + node_bytes
    return ipaddress.IPv6Address(addr)
|
|
|
|
|
|
def main() -> None:
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"--mode", choices=["network", "identity"], required=True, type=str
|
|
)
|
|
parser.add_argument("--ip", type=Path, required=True)
|
|
parser.add_argument("--identity-secret", type=Path, required=True)
|
|
parser.add_argument("--network-id", type=str, required=False)
|
|
args = parser.parse_args()
|
|
|
|
match args.mode:
|
|
case "network":
|
|
if args.network_id is None:
|
|
raise ValueError("network_id parameter is required")
|
|
controller = create_network_controller()
|
|
identity = controller.identity
|
|
network_id = controller.networkid
|
|
Path(args.network_id).write_text(network_id)
|
|
case "identity":
|
|
identity = create_identity()
|
|
network_id = args.network_id
|
|
case _:
|
|
raise ValueError(f"unknown mode {args.mode}")
|
|
ip = compute_zerotier_ip(network_id, identity)
|
|
|
|
args.identity_secret.write_text(identity.private)
|
|
args.ip.write_text(ip.compressed)
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|