add test for remote ssh commands #115

Merged
clan-bot merged 1 commits from Mic92-mic92 into main 2023-08-10 09:05:48 +00:00
8 changed files with 356 additions and 4 deletions

View File

@ -16,6 +16,8 @@
, pytest
, pytest-cov
, pytest-subprocess
, openssh
, stdenv
, wheel
}:
let
@ -26,6 +28,8 @@ let
pytest-cov
pytest-subprocess
mypy
openssh
stdenv.cc
];
checkPython = python3.withPackages (_ps: dependencies ++ testDependencies);
@ -50,7 +54,7 @@ python3.pkgs.buildPythonPackage {
'';
clan-pytest = runCommand "clan-tests"
{
nativeBuildInputs = [ age zerotierone bubblewrap sops nix ];
nativeBuildInputs = [ age zerotierone bubblewrap sops nix openssh stdenv.cc ];
} ''
cp -r ${./.} ./src
chmod +w -R ./src

View File

@ -0,0 +1,60 @@
import os
import signal
import subprocess
from typing import IO, Any, Dict, Iterator, List, Union
import pytest
_FILE = Union[None, int, IO[Any]]
class Command:
def __init__(self) -> None:
self.processes: List[subprocess.Popen[str]] = []
def run(
self,
command: List[str],
extra_env: Dict[str, str] = {},
stdin: _FILE = None,
stdout: _FILE = None,
stderr: _FILE = None,
) -> subprocess.Popen[str]:
env = os.environ.copy()
env.update(extra_env)
# We start a new session here so that we can than more reliably kill all childs as well
p = subprocess.Popen(
command,
env=env,
start_new_session=True,
stdout=stdout,
stderr=stderr,
stdin=stdin,
text=True,
)
self.processes.append(p)
return p
def terminate(self) -> None:
# Stop in reverse order in case there are dependencies.
# We just kill all processes as quickly as possible because we don't
# care about corrupted state and want to make tests fasts.
for p in reversed(self.processes):
try:
os.killpg(os.getpgid(p.pid), signal.SIGKILL)
except OSError:
pass
@pytest.fixture
def command() -> Iterator[Command]:
"""
Starts a background command. The process is automatically terminated in the end.
>>> p = command.run(["some", "daemon"])
>>> print(p.pid)
"""
c = Command()
try:
yield c
finally:
c.terminate()

View File

@ -3,4 +3,12 @@ import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
pytest_plugins = ["temporary_dir", "clan_flake", "root", "age_keys"]
pytest_plugins = [
"temporary_dir",
"clan_flake",
"root",
"age_keys",
"sshd",
"command",
"ports",
]

View File

@ -0,0 +1,27 @@
#define _GNU_SOURCE
#include <string.h>
#include <sys/types.h>
#include <pwd.h>
#include <dlfcn.h>
#include <stdlib.h>
#include <stdio.h>
typedef struct passwd *(*getpwnam_type)(const char *name);
struct passwd *getpwnam(const char *name) {
struct passwd *pw;
getpwnam_type orig_getpwnam;
orig_getpwnam = (getpwnam_type)dlsym(RTLD_NEXT, "getpwnam");
pw = orig_getpwnam(name);
if (pw) {
const char *shell = getenv("LOGIN_SHELL");
if (!shell) {
fprintf(stderr, "no LOGIN_SHELL set\n");
exit(1);
}
fprintf(stderr, "SHELL:%s\n", shell);
pw->pw_shell = strdup(shell);
}
return pw;
}

View File

@ -0,0 +1,46 @@
#!/usr/bin/env python3
import socket
import pytest
NEXT_PORT = 10000
def check_port(port: int) -> bool:
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
with tcp, udp:
try:
tcp.bind(("127.0.0.1", port))
udp.bind(("127.0.0.1", port))
return True
except socket.error:
return False
def check_port_range(port_range: range) -> bool:
for port in port_range:
if not check_port(port):
return False
return True
class Ports:
def allocate(self, num: int) -> int:
"""
Allocates
"""
global NEXT_PORT
while NEXT_PORT + num <= 65535:
start = NEXT_PORT
NEXT_PORT += num
if not check_port_range(range(start, NEXT_PORT)):
continue
return start
raise Exception("cannot find enough free port")
@pytest.fixture
def ports() -> Ports:
return Ports()

View File

@ -6,7 +6,7 @@ TEST_ROOT = Path(__file__).parent.resolve()
PROJECT_ROOT = TEST_ROOT.parent
@pytest.fixture
@pytest.fixture(scope="session")
def project_root() -> Path:
"""
Root directory of the tests
@ -14,7 +14,7 @@ def project_root() -> Path:
return PROJECT_ROOT
@pytest.fixture
@pytest.fixture(scope="session")
def test_root() -> Path:
"""
Root directory of the tests

117
pkgs/clan-cli/tests/sshd.py Normal file
View File

@ -0,0 +1,117 @@
import os
import shutil
import subprocess
import time
from pathlib import Path
from sys import platform
from tempfile import TemporaryDirectory
from typing import Iterator, Optional
import pytest
from command import Command
from ports import Ports
class Sshd:
def __init__(self, port: int, proc: subprocess.Popen[str], key: str) -> None:
self.port = port
self.proc = proc
self.key = key
class SshdConfig:
def __init__(self, path: str, key: str, preload_lib: Optional[str]) -> None:
self.path = path
self.key = key
self.preload_lib = preload_lib
@pytest.fixture(scope="session")
def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]:
# FIXME, if any parent of `project_root` is world-writable than sshd will refuse it.
with TemporaryDirectory(dir=project_root) as _dir:
dir = Path(_dir)
host_key = dir / "host_ssh_host_ed25519_key"
subprocess.run(
[
"ssh-keygen",
"-t",
"ed25519",
"-f",
host_key,
"-N",
"",
],
check=True,
)
sshd_config = dir / "sshd_config"
sshd_config.write_text(
f"""
HostKey {host_key}
LogLevel DEBUG3
# In the nix build sandbox we don't get any meaningful PATH after login
SetEnv PATH={os.environ.get("PATH", "")}
MaxStartups 64:30:256
AuthorizedKeysFile {host_key}.pub
"""
)
lib_path = None
if platform == "linux":
# This enforces a login shell by overriding the login shell of `getpwnam(3)`
lib_path = str(dir / "libgetpwnam-preload.so")
subprocess.run(
[
os.environ.get("CC", "cc"),
"-shared",
"-o",
lib_path,
str(test_root / "getpwnam-preload.c"),
],
check=True,
)
yield SshdConfig(str(sshd_config), str(host_key), lib_path)
@pytest.fixture
def sshd(sshd_config: SshdConfig, command: Command, ports: Ports) -> Iterator[Sshd]:
port = ports.allocate(1)
sshd = shutil.which("sshd")
assert sshd is not None, "no sshd binary found"
env = {}
if sshd_config.preload_lib is not None:
bash = shutil.which("bash")
assert bash is not None
env = dict(LD_PRELOAD=str(sshd_config.preload_lib), LOGIN_SHELL=bash)
proc = command.run(
[sshd, "-f", sshd_config.path, "-D", "-p", str(port)], extra_env=env
)
while True:
if (
subprocess.run(
[
"ssh",
"-o",
"StrictHostKeyChecking=no",
"-o",
"UserKnownHostsFile=/dev/null",
"-i",
sshd_config.key,
"localhost",
"-p",
str(port),
"true",
]
).returncode
== 0
):
yield Sshd(port, proc, sshd_config.key)
return
else:
rc = proc.poll()
if rc is not None:
raise Exception(f"sshd processes was terminated with {rc}")
time.sleep(0.1)

View File

@ -0,0 +1,90 @@
import os
import pwd
import subprocess
from sshd import Sshd
from clan_cli.ssh import Group, Host, HostKeyCheck
def deploy_group(sshd: Sshd) -> Group:
login = pwd.getpwuid(os.getuid()).pw_name
return Group(
[
Host(
"127.0.0.1",
port=sshd.port,
user=login,
key=sshd.key,
host_key_check=HostKeyCheck.NONE,
)
]
)
def test_run(sshd: Sshd) -> None:
g = deploy_group(sshd)
proc = g.run("echo hello", stdout=subprocess.PIPE)
assert proc[0].result.stdout == "hello\n"
def test_run_environment(sshd: Sshd) -> None:
g = deploy_group(sshd)
p1 = g.run("echo $env_var", stdout=subprocess.PIPE, extra_env=dict(env_var="true"))
assert p1[0].result.stdout == "true\n"
p2 = g.run(["env"], stdout=subprocess.PIPE, extra_env=dict(env_var="true"))
assert "env_var=true" in p2[0].result.stdout
def test_run_no_shell(sshd: Sshd) -> None:
g = deploy_group(sshd)
proc = g.run(["echo", "$hello"], stdout=subprocess.PIPE)
assert proc[0].result.stdout == "$hello\n"
def test_run_function(sshd: Sshd) -> None:
def some_func(h: Host) -> bool:
p = h.run("echo hello", stdout=subprocess.PIPE)
return p.stdout == "hello\n"
g = deploy_group(sshd)
res = g.run_function(some_func)
assert res[0].result
def test_timeout(sshd: Sshd) -> None:
g = deploy_group(sshd)
try:
g.run_local("sleep 10", timeout=0.01)
except Exception:
pass
else:
assert False, "should have raised TimeoutExpired"
def test_run_exception(sshd: Sshd) -> None:
g = deploy_group(sshd)
r = g.run("exit 1", check=False)
assert r[0].result.returncode == 1
try:
g.run("exit 1")
except Exception:
pass
else:
assert False, "should have raised Exception"
def test_run_function_exception(sshd: Sshd) -> None:
def some_func(h: Host) -> subprocess.CompletedProcess[str]:
return h.run_local("exit 1")
g = deploy_group(sshd)
try:
g.run_function(some_func)
except Exception:
pass
else:
assert False, "should have raised Exception"