Merge pull request 'refactor facts command to regenerate facts for all machines' (#1223) from parallelism into main
All checks were successful
checks / checks (push) Successful in 36s
checks / checks-impure (push) Successful in 1m53s

This commit is contained in:
clan-bot 2024-04-15 20:35:22 +00:00
commit f8b881c41e
12 changed files with 203 additions and 140 deletions

View File

@ -50,7 +50,7 @@ in
clanInternals = lib.mkOption {
type = lib.types.submodule {
options = {
all-machines-json = lib.mkOption { type = lib.types.attrsOf lib.types.str; };
all-machines-json = lib.mkOption { type = lib.types.attrsOf lib.types.unspecified; };
machines = lib.mkOption { type = lib.types.attrsOf (lib.types.attrsOf lib.types.unspecified); };
machinesFunc = lib.mkOption { type = lib.types.attrsOf (lib.types.attrsOf lib.types.unspecified); };
};

View File

@ -27,14 +27,14 @@ def check_secrets(machine: Machine, service: None | str = None) -> bool:
secret_name = secret_fact["name"]
if not secret_facts_store.exists(service, secret_name):
log.info(
f"Secret fact '{secret_fact}' for service {service} is missing."
f"Secret fact '{secret_fact}' for service '{service}' in machine {machine.name} is missing."
)
missing_secret_facts.append((service, secret_name))
for public_fact in machine.facts_data[service]["public"]:
if not public_facts_store.exists(service, public_fact):
log.info(
f"Public fact '{public_fact}' for service {service} is missing."
f"Public fact '{public_fact}' for service '{service}' in machine {machine.name} is missing."
)
missing_public_facts.append((service, public_fact))

View File

@ -11,6 +11,7 @@ from clan_cli.cmd import run
from ..errors import ClanError
from ..git import commit_files
from ..machines.inventory import get_all_machines, get_selected_machines
from ..machines.machines import Machine
from ..nix import nix_shell
from .check import check_secrets
@ -127,53 +128,76 @@ def generate_service_facts(
return True
def generate_facts(
machine: Machine,
prompt: None | Callable[[str], str] = None,
def prompt_func(text: str) -> str:
    """Show *text* as a prompt, then collect and return multiline user input."""
    print("{}: ".format(text))
    return read_multiline_input()
def _generate_facts_for_machine(
machine: Machine, tmpdir: Path, prompt: Callable[[str], str] = prompt_func
) -> bool:
local_temp = tmpdir / machine.name
local_temp.mkdir()
secret_facts_module = importlib.import_module(machine.secret_facts_module)
secret_facts_store = secret_facts_module.SecretStore(machine=machine)
public_facts_module = importlib.import_module(machine.public_facts_module)
public_facts_store = public_facts_module.FactStore(machine=machine)
if prompt is None:
machine_updated = False
for service in machine.facts_data:
machine_updated |= generate_service_facts(
machine=machine,
service=service,
secret_facts_store=secret_facts_store,
public_facts_store=public_facts_store,
tmpdir=local_temp,
prompt=prompt,
)
if machine_updated:
# flush caches to make sure the new secrets are available in evaluation
machine.flush_caches()
return machine_updated
def prompt_func(text: str) -> str:
print(f"{text}: ")
return read_multiline_input()
prompt = prompt_func
def generate_facts(
machines: list[Machine], prompt: Callable[[str], str] = prompt_func
) -> bool:
was_regenerated = False
with TemporaryDirectory() as tmp:
tmpdir = Path(tmp)
for service in machine.facts_data:
was_regenerated |= generate_service_facts(
machine=machine,
service=service,
secret_facts_store=secret_facts_store,
public_facts_store=public_facts_store,
tmpdir=tmpdir,
prompt=prompt,
)
if was_regenerated:
# flush caches to make sure the new secrets are available in evaluation
machine.flush_caches()
else:
for machine in machines:
errors = 0
try:
was_regenerated |= _generate_facts_for_machine(machine, tmpdir, prompt)
except Exception as exc:
log.error(f"Failed to generate facts for {machine.name}: {exc}")
errors += 1
if errors > 0:
raise ClanError(
f"Failed to generate facts for {errors} hosts. Check the logs above"
)
if not was_regenerated:
print("All secrets and facts are already up to date")
return was_regenerated
def generate_command(args: argparse.Namespace) -> None:
machine = Machine(name=args.machine, flake=args.flake)
generate_facts(machine)
if len(args.machines) == 0:
machines = get_all_machines(args.flake)
else:
machines = get_selected_machines(args.flake, args.machines)
generate_facts(machines)
def register_generate_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"machine",
help="The machine to generate facts for",
"machines",
type=str,
help="machine to generate facts for. if empty, generate facts for all machines",
nargs="*",
default=[],
)
parser.set_defaults(func=generate_command)

View File

@ -5,6 +5,7 @@ from clan_cli.errors import ClanError
from clan_cli.nix import nix_shell
from .cmd import Log, run
from .locked_open import locked_open
def commit_file(
@ -55,38 +56,45 @@ def _commit_file_to_git(
:param commit_message: The commit message.
:raises ClanError: If the file is not in the git repository.
"""
for file_path in file_paths:
with locked_open(repo_dir / ".git" / "clan.lock", "w+"):
for file_path in file_paths:
cmd = nix_shell(
["nixpkgs#git"],
["git", "-C", str(repo_dir), "add", str(file_path)],
)
# add the file to the git index
run(
cmd,
log=Log.BOTH,
error_msg=f"Failed to add {file_path} file to git index",
)
# check if there is a diff
cmd = nix_shell(
["nixpkgs#git"],
["git", "-C", str(repo_dir), "add", str(file_path)],
["git", "-C", str(repo_dir), "diff", "--cached", "--exit-code"]
+ [str(file_path) for file_path in file_paths],
)
# add the file to the git index
result = run(cmd, check=False, cwd=repo_dir)
# if there is no diff, return
if result.returncode == 0:
return
run(cmd, log=Log.BOTH, error_msg=f"Failed to add {file_path} file to git index")
# commit only that file
cmd = nix_shell(
["nixpkgs#git"],
[
"git",
"-C",
str(repo_dir),
"commit",
"-m",
commit_message,
]
+ [str(file_path) for file_path in file_paths],
)
# check if there is a diff
cmd = nix_shell(
["nixpkgs#git"],
["git", "-C", str(repo_dir), "diff", "--cached", "--exit-code"]
+ [str(file_path) for file_path in file_paths],
)
result = run(cmd, check=False, cwd=repo_dir)
# if there is no diff, return
if result.returncode == 0:
return
# commit only that file
cmd = nix_shell(
["nixpkgs#git"],
[
"git",
"-C",
str(repo_dir),
"commit",
"-m",
commit_message,
]
+ [str(file_path) for file_path in file_paths],
)
run(cmd, error_msg=f"Failed to commit {file_paths} to git repository {repo_dir}")
run(
cmd, error_msg=f"Failed to commit {file_paths} to git repository {repo_dir}"
)

View File

@ -11,7 +11,7 @@ from .dirs import user_history_file
@contextmanager
def _locked_open(filename: str | Path, mode: str = "r") -> Generator:
def locked_open(filename: str | Path, mode: str = "r") -> Generator:
"""
This is a context manager that provides an advisory write lock on the file specified by `filename` when entering the context, and releases the lock when leaving the context. The lock is acquired using the `fcntl` module's `LOCK_EX` flag, which applies an exclusive write lock to the file.
"""
@ -22,12 +22,12 @@ def _locked_open(filename: str | Path, mode: str = "r") -> Generator:
def write_history_file(data: Any) -> None:
with _locked_open(user_history_file(), "w+") as f:
with locked_open(user_history_file(), "w+") as f:
f.write(json.dumps(data, cls=ClanJSONEncoder, indent=4))
def read_history_file() -> list[dict]:
with _locked_open(user_history_file(), "r") as f:
with locked_open(user_history_file(), "r") as f:
content: str = f.read()
parsed: list[dict] = json.loads(content)
return parsed

View File

@ -24,7 +24,7 @@ def install_nixos(
target_host = f"{h.user or 'root'}@{h.host}"
log.info(f"target host: {target_host}")
generate_facts(machine)
generate_facts([machine])
with TemporaryDirectory() as tmpdir_:
tmpdir = Path(tmpdir_)

View File

@ -0,0 +1,31 @@
import json
from pathlib import Path
from ..cmd import run
from ..nix import nix_build, nix_config
from .machines import Machine
# Speed up evaluation when we want to evaluate all machines at once:
# one nix build of the aggregate JSON instead of one eval per machine.
def get_all_machines(flake_dir: Path) -> list[Machine]:
    """Build the flake's all-machines JSON and return one Machine per entry."""
    system = nix_config()["system"]
    out_path = run(
        nix_build([f'{flake_dir}#clanInternals.all-machines-json."{system}"'])
    ).stdout
    # nix_build prints the store path of the result; read the JSON it contains
    machines_json = json.loads(Path(out_path.rstrip()).read_text())
    return [
        Machine(name=name, flake=flake_dir, deployment_info=machine_data)
        for name, machine_data in machines_json.items()
    ]
def get_selected_machines(flake_dir: Path, machine_names: list[str]) -> list[Machine]:
    """Return a Machine object for each explicitly requested machine name."""
    return [Machine(name=name, flake=flake_dir) for name in machine_names]

View File

@ -0,0 +1,26 @@
from collections.abc import Callable
from typing import TypeVar
from ..ssh import Host, HostGroup, HostResult
from .machines import Machine
T = TypeVar("T")
class MachineGroup:
    """Adapts a HostGroup so parallel callbacks receive Machine objects, not Hosts."""

    def __init__(self, machines: list[Machine]) -> None:
        # Each target_host carries its Machine in host.meta["machine"].
        self.group = HostGroup([machine.target_host for machine in machines])

    def run_function(
        self, func: Callable[[Machine], T], check: bool = True
    ) -> list[HostResult[T]]:
        """
        Function to run for each host in the group in parallel
        @func the function to call
        """

        def call_with_machine(host: Host) -> T:
            # Translate the Host back to its Machine before invoking the callback.
            return func(host.meta["machine"])

        return self.group.run_function(call_with_machine, check=check)

View File

@ -5,15 +5,15 @@ import os
import shlex
import subprocess
import sys
from pathlib import Path
from ..cmd import run
from ..errors import ClanError
from ..facts.generate import generate_facts
from ..facts.upload import upload_secrets
from ..machines.machines import Machine
from ..nix import nix_build, nix_command, nix_config, nix_metadata
from ..ssh import Host, HostGroup, HostKeyCheck, parse_deployment_address
from ..nix import nix_command, nix_metadata
from ..ssh import HostKeyCheck
from .inventory import get_all_machines, get_selected_machines
from .machine_group import MachineGroup
log = logging.getLogger(__name__)
@ -86,31 +86,31 @@ def upload_sources(
)
def deploy_nixos(hosts: HostGroup) -> None:
def deploy_nixos(machines: MachineGroup) -> None:
"""
Deploy to all hosts in parallel
"""
def deploy(h: Host) -> None:
target = f"{h.user or 'root'}@{h.host}"
ssh_arg = f"-p {h.port}" if h.port else ""
def deploy(machine: Machine) -> None:
host = machine.build_host
target = f"{host.user or 'root'}@{host.host}"
ssh_arg = f"-p {host.port}" if host.port else ""
env = os.environ.copy()
env["NIX_SSHOPTS"] = ssh_arg
machine: Machine = h.meta["machine"]
generate_facts(machine)
generate_facts([machine])
upload_secrets(machine)
path = upload_sources(".", target)
if h.host_key_check != HostKeyCheck.STRICT:
if host.host_key_check != HostKeyCheck.STRICT:
ssh_arg += " -o StrictHostKeyChecking=no"
if h.host_key_check == HostKeyCheck.NONE:
if host.host_key_check == HostKeyCheck.NONE:
ssh_arg += " -o UserKnownHostsFile=/dev/null"
ssh_arg += " -i " + h.key if h.key else ""
ssh_arg += " -i " + host.key if host.key else ""
extra_args = h.meta.get("extra_args", [])
extra_args = host.meta.get("extra_args", [])
cmd = [
"nixos-rebuild",
"switch",
@ -127,82 +127,55 @@ def deploy_nixos(hosts: HostGroup) -> None:
"--flake",
f"{path}#{machine.name}",
]
if target_host := h.meta.get("target_host"):
if target_host := host.meta.get("target_host"):
target_host = f"{target_host.user or 'root'}@{target_host.host}"
cmd.extend(["--target-host", target_host])
ret = h.run(cmd, check=False)
ret = host.run(cmd, check=False)
# re-retry switch if the first time fails
if ret.returncode != 0:
ret = h.run(cmd)
ret = host.run(cmd)
hosts.run_function(deploy)
machines.run_function(deploy)
# function to speed up eval if we want to evaluate all machines
def get_all_machines(clan_dir: Path) -> HostGroup:
config = nix_config()
system = config["system"]
machines_json = run(
nix_build([f'{clan_dir}#clanInternals.all-machines-json."{system}"'])
).stdout
machines = json.loads(Path(machines_json.rstrip()).read_text())
hosts = []
ignored_machines = []
for name, machine_data in machines.items():
if machine_data.get("requireExplicitUpdate", False):
continue
machine = Machine(name=name, flake=clan_dir, deployment_info=machine_data)
try:
hosts.append(machine.build_host)
except ClanError:
ignored_machines.append(name)
continue
if not hosts and ignored_machines != []:
print(
"WARNING: No machines to update. The following defined machines were ignored because they do not have `clan.networking.targetHost` nixos option set:",
file=sys.stderr,
)
for machine in ignored_machines:
print(machine, file=sys.stderr)
# very hacky. would be better to do a MachinesGroup instead
return HostGroup(hosts)
def get_selected_machines(machine_names: list[str], flake_dir: Path) -> HostGroup:
hosts = []
for name in machine_names:
machine = Machine(name=name, flake=flake_dir)
hosts.append(machine.build_host)
return HostGroup(hosts)
# FIXME: we want some kind of inventory here.
def update(args: argparse.Namespace) -> None:
if args.flake is None:
raise ClanError("Could not find clan flake toplevel directory")
machines = []
if len(args.machines) == 1 and args.target_host is not None:
machine = Machine(name=args.machines[0], flake=args.flake)
machine.target_host_address = args.target_host
host = parse_deployment_address(
args.machines[0],
args.target_host,
meta={"machine": machine},
)
machines = HostGroup([host])
machines.append(machine)
elif args.target_host is not None:
print("target host can only be specified for a single machine")
exit(1)
else:
if len(args.machines) == 0:
machines = get_all_machines(args.flake)
else:
machines = get_selected_machines(args.machines, args.flake)
ignored_machines = []
for machine in get_all_machines(args.flake):
if machine.deployment_info.get("requireExplicitUpdate", False):
continue
try:
machine.build_host
except ClanError: # check if we have a build host set
ignored_machines.append(machine)
continue
deploy_nixos(machines)
machines.append(machine)
if not machines and ignored_machines != []:
print(
"WARNING: No machines to update. The following defined machines were ignored because they do not have `clan.networking.targetHost` nixos option set:",
file=sys.stderr,
)
for machine in ignored_machines:
print(machine, file=sys.stderr)
else:
machines = get_selected_machines(args.flake, args.machines)
deploy_nixos(MachineGroup(machines))
def register_update_parser(parser: argparse.ArgumentParser) -> None:

View File

@ -18,6 +18,8 @@ from shlex import quote
from threading import Thread
from typing import IO, Any, Generic, TypeVar
from ..errors import ClanError
# https://no-color.org
DISABLE_COLOR = not sys.stderr.isatty() or os.environ.get("NO_COLOR", "") != ""
@ -285,7 +287,7 @@ class Host:
elif stdout == subprocess.PIPE:
stdout_read, stdout_write = stack.enter_context(_pipe())
else:
raise Exception(f"unsupported value for stdout parameter: {stdout}")
raise ClanError(f"unsupported value for stdout parameter: {stdout}")
if stderr is None:
stderr_read = None
@ -293,7 +295,7 @@ class Host:
elif stderr == subprocess.PIPE:
stderr_read, stderr_write = stack.enter_context(_pipe())
else:
raise Exception(f"unsupported value for stderr parameter: {stderr}")
raise ClanError(f"unsupported value for stderr parameter: {stderr}")
env = os.environ.copy()
env.update(extra_env)
@ -610,7 +612,7 @@ class HostGroup:
)
errors += 1
if errors > 0:
raise Exception(
raise ClanError(
f"{errors} hosts failed with an error. Check the logs above"
)

View File

@ -69,7 +69,7 @@ def get_secrets(
secret_facts_module = importlib.import_module(machine.secret_facts_module)
secret_facts_store = secret_facts_module.SecretStore(machine=machine)
generate_facts(machine)
generate_facts([machine])
secret_facts_store.upload(secrets_dir)
return secrets_dir

View File

@ -59,8 +59,8 @@ def test_generate_secret(
age_key_mtime = age_key.lstat().st_mtime_ns
secret1_mtime = identity_secret.lstat().st_mtime_ns
# test idempotency
cli.run(["facts", "generate", "vm1"])
# test idempotency for vm1 and also generate for vm2
cli.run(["facts", "generate"])
assert age_key.lstat().st_mtime_ns == age_key_mtime
assert identity_secret.lstat().st_mtime_ns == secret1_mtime
@ -68,7 +68,6 @@ def test_generate_secret(
secrets_folder / "vm1-zerotier-identity-secret" / "machines" / "vm1"
).exists()
cli.run(["facts", "generate", "vm2"])
assert has_secret(test_flake_with_core.path, "vm2-age.key")
assert has_secret(test_flake_with_core.path, "vm2-zerotier-identity-secret")
ip = machine_get_fact(test_flake_with_core.path, "vm1", "zerotier-ip")