Merge pull request 'clan-cli/ssh: rename Group -> HostGroup' (#130) from Mic92-mic92 into main
All checks were successful
build / test (push) Successful in 20s

This commit is contained in:
clan-bot 2023-08-11 14:11:29 +00:00
commit 6d7c78a175
11 changed files with 260 additions and 77 deletions

View File

@ -156,6 +156,7 @@ class Host:
host_key_check: HostKeyCheck = HostKeyCheck.STRICT, host_key_check: HostKeyCheck = HostKeyCheck.STRICT,
meta: Dict[str, Any] = {}, meta: Dict[str, Any] = {},
verbose_ssh: bool = False, verbose_ssh: bool = False,
ssh_options: dict[str, str] = {},
) -> None: ) -> None:
""" """
Creates a Host Creates a Host
@ -179,6 +180,7 @@ class Host:
self.host_key_check = host_key_check self.host_key_check = host_key_check
self.meta = meta self.meta = meta
self.verbose_ssh = verbose_ssh self.verbose_ssh = verbose_ssh
self.ssh_options = ssh_options
def _prefix_output( def _prefix_output(
self, self,
@ -451,6 +453,10 @@ class Host:
ssh_target = self.host ssh_target = self.host
ssh_opts = ["-A"] if self.forward_agent else [] ssh_opts = ["-A"] if self.forward_agent else []
for k, v in self.ssh_options.items():
ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"])
if self.port: if self.port:
ssh_opts.extend(["-p", str(self.port)]) ssh_opts.extend(["-p", str(self.port)])
if self.key: if self.key:
@ -534,7 +540,7 @@ def _worker(
results[idx] = HostResult(host, e) results[idx] = HostResult(host, e)
class Group: class HostGroup:
def __init__(self, hosts: List[Host]) -> None: def __init__(self, hosts: List[Host]) -> None:
self.hosts = hosts self.hosts = hosts
@ -745,9 +751,9 @@ class Group:
self._reraise_errors(results) self._reraise_errors(results)
return results return results
def filter(self, pred: Callable[[Host], bool]) -> "Group": def filter(self, pred: Callable[[Host], bool]) -> "HostGroup":
"""Return a new Group with the results filtered by the predicate""" """Return a new Group with the results filtered by the predicate"""
return Group(list(filter(pred, self.hosts))) return HostGroup(list(filter(pred, self.hosts)))
@overload @overload

View File

@ -0,0 +1,105 @@
import argparse
import json
import subprocess
from .ssh import Host, HostGroup, HostKeyCheck
def deploy_nixos(hosts: HostGroup) -> None:
"""
Deploy to all hosts in parallel
"""
flake_store_paths = {}
for h in hosts.hosts:
flake_uri = str(h.meta.get("flake_uri", ".#"))
if flake_uri not in flake_store_paths:
res = subprocess.run(
[
"nix",
"--extra-experimental-features",
"nix-command flakes",
"flake",
"metadata",
"--json",
flake_uri,
],
check=True,
text=True,
stdout=subprocess.PIPE,
)
data = json.loads(res.stdout)
flake_store_paths[flake_uri] = data["path"]
def deploy(h: Host) -> None:
target = f"{h.user or 'root'}@{h.host}"
flake_store_path = flake_store_paths[str(h.meta.get("flake_uri", ".#"))]
flake_path = str(h.meta.get("flake_path", "/etc/nixos"))
ssh_arg = f"-p {h.port}" if h.port else ""
if h.host_key_check != HostKeyCheck.STRICT:
ssh_arg += " -o StrictHostKeyChecking=no"
if h.host_key_check == HostKeyCheck.NONE:
ssh_arg += " -o UserKnownHostsFile=/dev/null"
ssh_arg += " -i " + h.key if h.key else ""
h.run_local(
f"rsync --checksum -vaF --delete -e 'ssh {ssh_arg}' {flake_store_path}/ {target}:{flake_path}"
)
flake_attr = h.meta.get("flake_attr", "")
if flake_attr:
flake_attr = "#" + flake_attr
target_host = h.meta.get("target_host")
if target_host:
target_user = h.meta.get("target_user")
if target_user:
target_host = f"{target_user}@{target_host}"
extra_args = h.meta.get("extra_args", [])
cmd = (
["nixos-rebuild", "switch"]
+ extra_args
+ [
"--fast",
"--option",
"keep-going",
"true",
"--option",
"accept-flake-config",
"true",
"--build-host",
"",
"--flake",
f"{flake_path}{flake_attr}",
]
)
if target_host:
cmd.extend(["--target-host", target_host])
ret = h.run(cmd, check=False)
# re-retry switch if the first time fails
if ret.returncode != 0:
ret = h.run(cmd)
hosts.run_function(deploy)
# FIXME: we want some kind of inventory here.
def update(args: argparse.Namespace) -> None:
deploy_nixos(
HostGroup(
[Host(args.host, user=args.user, meta=dict(flake_attr=args.flake_attr))]
)
)
def register_parser(parser: argparse.ArgumentParser) -> None:
parser.add_mutually_exclusive_group(required=True)
# TODO pass all args we don't parse into ssh_args, currently it fails if arg starts with -
parser.add_argument("--flake-uri", type=str, default=".#", desc="nix flake uri")
parser.add_argument(
"--flake-attr", type=str, description="nixos configuration in the flake"
)
parser.add_argument("--user", type=str, default="root")
parser.add_argument("host", type=str)
parser.set_defaults(func=update)

View File

@ -19,6 +19,7 @@
, stdenv , stdenv
, wheel , wheel
, zerotierone , zerotierone
, rsync
}: }:
let let
dependencies = [ argcomplete jsonschema ]; dependencies = [ argcomplete jsonschema ];
@ -63,12 +64,12 @@ python3.pkgs.buildPythonPackage {
''; '';
clan-pytest = runCommand "clan-tests" clan-pytest = runCommand "clan-tests"
{ {
nativeBuildInputs = [ age zerotierone bubblewrap sops nix openssh stdenv.cc ]; nativeBuildInputs = [ age zerotierone bubblewrap sops nix openssh rsync stdenv.cc ];
} '' } ''
cp -r ${source} ./src cp -r ${source} ./src
chmod +w -R ./src chmod +w -R ./src
cd ./src cd ./src
${checkPython}/bin/python -m pytest ./tests NIX_STATE_DIR=$TMPDIR/nix ${checkPython}/bin/python -m pytest -s ./tests
touch $out touch $out
''; '';
}; };

View File

@ -20,6 +20,7 @@
zbar zbar
tor tor
age age
rsync
sops; sops;
# Override license so that we can build zerotierone without # Override license so that we can build zerotierone without
# having to re-import nixpkgs. # having to re-import nixpkgs.

View File

@ -10,6 +10,15 @@ def clan_flake(temporary_dir: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator
flake = temporary_dir / "clan-flake" flake = temporary_dir / "clan-flake"
flake.mkdir() flake.mkdir()
(flake / ".clan-flake").touch() (flake / ".clan-flake").touch()
(flake / "flake.nix").write_text(
"""
{
description = "A flake for testing clan";
inputs = {};
outputs = { self }: {};
}
"""
)
monkeypatch.chdir(flake) monkeypatch.chdir(flake)
with mock_env(HOME=str(temporary_dir)): with mock_env(HOME=str(temporary_dir)):
yield flake yield flake

View File

@ -11,4 +11,5 @@ pytest_plugins = [
"sshd", "sshd",
"command", "command",
"ports", "ports",
"host_group",
] ]

View File

@ -0,0 +1,23 @@
import os
import pwd
import pytest
from sshd import Sshd
from clan_cli.ssh import Host, HostGroup, HostKeyCheck
@pytest.fixture
def host_group(sshd: Sshd) -> HostGroup:
login = pwd.getpwuid(os.getuid()).pw_name
return HostGroup(
[
Host(
"127.0.0.1",
port=sshd.port,
user=login,
key=sshd.key,
host_key_check=HostKeyCheck.NONE,
)
]
)

View File

@ -5,11 +5,13 @@ import time
from pathlib import Path from pathlib import Path
from sys import platform from sys import platform
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Iterator, Optional from typing import TYPE_CHECKING, Iterator
import pytest import pytest
from command import Command
from ports import Ports if TYPE_CHECKING:
from command import Command
from ports import Ports
class Sshd: class Sshd:
@ -20,8 +22,11 @@ class Sshd:
class SshdConfig: class SshdConfig:
def __init__(self, path: str, key: str, preload_lib: Optional[str]) -> None: def __init__(
self, path: Path, login_shell: Path, key: str, preload_lib: Path
) -> None:
self.path = path self.path = path
self.login_shell = login_shell
self.key = key self.key = key
self.preload_lib = preload_lib self.preload_lib = preload_lib
@ -51,42 +56,65 @@ def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]:
HostKey {host_key} HostKey {host_key}
LogLevel DEBUG3 LogLevel DEBUG3
# In the nix build sandbox we don't get any meaningful PATH after login # In the nix build sandbox we don't get any meaningful PATH after login
SetEnv PATH={os.environ.get("PATH", "")}
MaxStartups 64:30:256 MaxStartups 64:30:256
AuthorizedKeysFile {host_key}.pub AuthorizedKeysFile {host_key}.pub
AcceptEnv REALPATH
""" """
) )
login_shell = dir / "shell"
bash = shutil.which("bash")
path = os.environ["PATH"]
assert bash is not None
login_shell.write_text(
f"""#!{bash}
if [[ -f /etc/profile ]]; then
source /etc/profile
fi
if [[ -n "$REALPATH" ]]; then
export PATH="$REALPATH:${path}"
else
export PATH="${path}"
fi
exec {bash} -l "${{@}}"
"""
)
login_shell.chmod(0o755)
lib_path = None lib_path = None
if platform == "linux": assert (
# This enforces a login shell by overriding the login shell of `getpwnam(3)` platform == "linux"
lib_path = str(dir / "libgetpwnam-preload.so") ), "we do not support the ld_preload trick on non-linux just now"
subprocess.run(
[
os.environ.get("CC", "cc"),
"-shared",
"-o",
lib_path,
str(test_root / "getpwnam-preload.c"),
],
check=True,
)
yield SshdConfig(str(sshd_config), str(host_key), lib_path) # This enforces a login shell by overriding the login shell of `getpwnam(3)`
lib_path = dir / "libgetpwnam-preload.so"
subprocess.run(
[
os.environ.get("CC", "cc"),
"-shared",
"-o",
lib_path,
str(test_root / "getpwnam-preload.c"),
],
check=True,
)
yield SshdConfig(sshd_config, login_shell, str(host_key), lib_path)
@pytest.fixture @pytest.fixture
def sshd(sshd_config: SshdConfig, command: Command, ports: Ports) -> Iterator[Sshd]: def sshd(sshd_config: SshdConfig, command: "Command", ports: "Ports") -> Iterator[Sshd]:
port = ports.allocate(1) port = ports.allocate(1)
sshd = shutil.which("sshd") sshd = shutil.which("sshd")
assert sshd is not None, "no sshd binary found" assert sshd is not None, "no sshd binary found"
env = {} env = {}
if sshd_config.preload_lib is not None: env = dict(
bash = shutil.which("bash") LD_PRELOAD=str(sshd_config.preload_lib),
assert bash is not None LOGIN_SHELL=str(sshd_config.login_shell),
env = dict(LD_PRELOAD=str(sshd_config.preload_lib), LOGIN_SHELL=bash) )
proc = command.run( proc = command.run(
[sshd, "-f", sshd_config.path, "-D", "-p", str(port)], extra_env=env [sshd, "-f", str(sshd_config.path), "-D", "-p", str(port)], extra_env=env
) )
while True: while True:

View File

@ -1,6 +1,6 @@
import subprocess import subprocess
from clan_cli.ssh import Group, Host, run from clan_cli.ssh import Host, HostGroup, run
def test_run() -> None: def test_run() -> None:
@ -20,7 +20,7 @@ def test_run_failure() -> None:
assert False, "Command should have raised an error" assert False, "Command should have raised an error"
hosts = Group([Host("some_host")]) hosts = HostGroup([Host("some_host")])
def test_run_environment() -> None: def test_run_environment() -> None:

View File

@ -1,89 +1,63 @@
import os
import pwd
import subprocess import subprocess
from sshd import Sshd from clan_cli.ssh import Host, HostGroup
from clan_cli.ssh import Group, Host, HostKeyCheck
def deploy_group(sshd: Sshd) -> Group: def test_run(host_group: HostGroup) -> None:
login = pwd.getpwuid(os.getuid()).pw_name proc = host_group.run("echo hello", stdout=subprocess.PIPE)
return Group(
[
Host(
"127.0.0.1",
port=sshd.port,
user=login,
key=sshd.key,
host_key_check=HostKeyCheck.NONE,
)
]
)
def test_run(sshd: Sshd) -> None:
g = deploy_group(sshd)
proc = g.run("echo hello", stdout=subprocess.PIPE)
assert proc[0].result.stdout == "hello\n" assert proc[0].result.stdout == "hello\n"
def test_run_environment(sshd: Sshd) -> None: def test_run_environment(host_group: HostGroup) -> None:
g = deploy_group(sshd) p1 = host_group.run(
p1 = g.run("echo $env_var", stdout=subprocess.PIPE, extra_env=dict(env_var="true")) "echo $env_var", stdout=subprocess.PIPE, extra_env=dict(env_var="true")
)
assert p1[0].result.stdout == "true\n" assert p1[0].result.stdout == "true\n"
p2 = g.run(["env"], stdout=subprocess.PIPE, extra_env=dict(env_var="true")) p2 = host_group.run(["env"], stdout=subprocess.PIPE, extra_env=dict(env_var="true"))
assert "env_var=true" in p2[0].result.stdout assert "env_var=true" in p2[0].result.stdout
def test_run_no_shell(sshd: Sshd) -> None: def test_run_no_shell(host_group: HostGroup) -> None:
g = deploy_group(sshd) proc = host_group.run(["echo", "$hello"], stdout=subprocess.PIPE)
proc = g.run(["echo", "$hello"], stdout=subprocess.PIPE)
assert proc[0].result.stdout == "$hello\n" assert proc[0].result.stdout == "$hello\n"
def test_run_function(sshd: Sshd) -> None: def test_run_function(host_group: HostGroup) -> None:
def some_func(h: Host) -> bool: def some_func(h: Host) -> bool:
p = h.run("echo hello", stdout=subprocess.PIPE) p = h.run("echo hello", stdout=subprocess.PIPE)
return p.stdout == "hello\n" return p.stdout == "hello\n"
g = deploy_group(sshd) res = host_group.run_function(some_func)
res = g.run_function(some_func)
assert res[0].result assert res[0].result
def test_timeout(sshd: Sshd) -> None: def test_timeout(host_group: HostGroup) -> None:
g = deploy_group(sshd)
try: try:
g.run_local("sleep 10", timeout=0.01) host_group.run_local("sleep 10", timeout=0.01)
except Exception: except Exception:
pass pass
else: else:
assert False, "should have raised TimeoutExpired" assert False, "should have raised TimeoutExpired"
def test_run_exception(sshd: Sshd) -> None: def test_run_exception(host_group: HostGroup) -> None:
g = deploy_group(sshd) r = host_group.run("exit 1", check=False)
r = g.run("exit 1", check=False)
assert r[0].result.returncode == 1 assert r[0].result.returncode == 1
try: try:
g.run("exit 1") host_group.run("exit 1")
except Exception: except Exception:
pass pass
else: else:
assert False, "should have raised Exception" assert False, "should have raised Exception"
def test_run_function_exception(sshd: Sshd) -> None: def test_run_function_exception(host_group: HostGroup) -> None:
def some_func(h: Host) -> subprocess.CompletedProcess[str]: def some_func(h: Host) -> subprocess.CompletedProcess[str]:
return h.run_local("exit 1") return h.run_local("exit 1")
g = deploy_group(sshd)
try: try:
g.run_function(some_func) host_group.run_function(some_func)
except Exception: except Exception:
pass pass
else: else:

View File

@ -0,0 +1,35 @@
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from environment import mock_env
from host_group import HostGroup
from clan_cli.update import deploy_nixos
def test_update(clan_flake: Path, host_group: HostGroup) -> None:
assert len(host_group.hosts) == 1
host = host_group.hosts[0]
with TemporaryDirectory() as tmpdir:
host.meta["flake_uri"] = clan_flake
host.meta["flake_path"] = str(Path(tmpdir) / "rsync-target")
host.ssh_options["SendEnv"] = "REALPATH"
bin = Path(tmpdir).joinpath("bin")
bin.mkdir()
nixos_rebuild = bin.joinpath("nixos-rebuild")
bash = shutil.which("bash")
assert bash is not None
nixos_rebuild.write_text(
f"""#!{bash}
exit 0
"""
)
nixos_rebuild.chmod(0o755)
path = f"{tmpdir}/bin:{os.environ['PATH']}"
nix_state_dir = Path(tmpdir).joinpath("nix")
nix_state_dir.mkdir()
with mock_env(REALPATH=path):
deploy_nixos(host_group)