refactor facts command to regenerate facts for all machines

Jörg Thalheim 2024-04-15 21:06:03 +02:00
parent 060e3baa08
commit f385e0e037
10 changed files with 200 additions and 137 deletions

View File

@@ -11,6 +11,7 @@ from clan_cli.cmd import run
 from ..errors import ClanError
 from ..git import commit_files
+from ..machines.inventory import get_all_machines, get_selected_machines
 from ..machines.machines import Machine
 from ..nix import nix_shell
 from .check import check_secrets
@@ -127,53 +128,76 @@ def generate_service_facts(
     return True
 
 
-def generate_facts(
-    machine: Machine,
-    prompt: None | Callable[[str], str] = None,
+def prompt_func(text: str) -> str:
+    print(f"{text}: ")
+    return read_multiline_input()
+
+
+def _generate_facts_for_machine(
+    machine: Machine, tmpdir: Path, prompt: Callable[[str], str] = prompt_func
 ) -> bool:
+    local_temp = tmpdir / machine.name
+    local_temp.mkdir()
     secret_facts_module = importlib.import_module(machine.secret_facts_module)
     secret_facts_store = secret_facts_module.SecretStore(machine=machine)
 
     public_facts_module = importlib.import_module(machine.public_facts_module)
     public_facts_store = public_facts_module.FactStore(machine=machine)
 
-    if prompt is None:
-
-        def prompt_func(text: str) -> str:
-            print(f"{text}: ")
-            return read_multiline_input()
-
-        prompt = prompt_func
-
-    was_regenerated = False
-    with TemporaryDirectory() as tmp:
-        tmpdir = Path(tmp)
-        for service in machine.facts_data:
-            was_regenerated |= generate_service_facts(
-                machine=machine,
-                service=service,
-                secret_facts_store=secret_facts_store,
-                public_facts_store=public_facts_store,
-                tmpdir=tmpdir,
-                prompt=prompt,
-            )
-
-    if was_regenerated:
+    machine_updated = False
+
+    for service in machine.facts_data:
+        machine_updated |= generate_service_facts(
+            machine=machine,
+            service=service,
+            secret_facts_store=secret_facts_store,
+            public_facts_store=public_facts_store,
+            tmpdir=local_temp,
+            prompt=prompt,
+        )
+    if machine_updated:
         # flush caches to make sure the new secrets are available in evaluation
         machine.flush_caches()
-    else:
+    return machine_updated
+
+
+def generate_facts(
+    machines: list[Machine], prompt: Callable[[str], str] = prompt_func
+) -> bool:
+    was_regenerated = False
+    with TemporaryDirectory() as tmp:
+        tmpdir = Path(tmp)
+        for machine in machines:
+            errors = 0
+            try:
+                was_regenerated |= _generate_facts_for_machine(machine, tmpdir, prompt)
+            except Exception as exc:
+                log.error(f"Failed to generate facts for {machine.name}: {exc}")
+                errors += 1
+            if errors > 0:
+                raise ClanError(
+                    f"Failed to generate facts for {errors} hosts. Check the logs above"
+                )
+
+    if not was_regenerated:
         print("All secrets and facts are already up to date")
 
     return was_regenerated
 
 
 def generate_command(args: argparse.Namespace) -> None:
-    machine = Machine(name=args.machine, flake=args.flake)
-    generate_facts(machine)
+    if len(args.machines) == 0:
+        machines = get_all_machines(args.flake)
+    else:
+        machines = get_selected_machines(args.flake, args.machines)
+    generate_facts(machines)
 
 
 def register_generate_parser(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
-        "machine",
-        help="The machine to generate facts for",
+        "machines",
+        type=str,
+        help="machine to generate facts for. if empty, generate facts for all machines",
+        nargs="*",
+        default=[],
     )
     parser.set_defaults(func=generate_command)
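
Note: with this change the `facts generate` command accepts zero or more machine names and regenerates facts for every machine when none are given. A minimal sketch of the equivalent Python-level calls, using only names introduced in this commit (the flake path and machine names are placeholders):

    from pathlib import Path

    from clan_cli.facts.generate import generate_facts
    from clan_cli.machines.inventory import get_all_machines, get_selected_machines

    flake_dir = Path(".")  # placeholder: path to the clan flake

    # regenerate facts for every machine defined in the flake
    generate_facts(get_all_machines(flake_dir))

    # or only for an explicit selection (machine names are examples)
    generate_facts(get_selected_machines(flake_dir, ["vm1", "vm2"]))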

View File

@@ -5,6 +5,7 @@ from clan_cli.errors import ClanError
 from clan_cli.nix import nix_shell
 
 from .cmd import Log, run
+from .locked_open import locked_open
 
 
 def commit_file(
@@ -55,6 +56,7 @@ def _commit_file_to_git(
     :param commit_message: The commit message.
     :raises ClanError: If the file is not in the git repository.
     """
+    with locked_open(repo_dir / ".git" / "clan.lock", "w+"):
         for file_path in file_paths:
             cmd = nix_shell(
                 ["nixpkgs#git"],
@@ -62,7 +64,11 @@
             )
 
             # add the file to the git index
-            run(cmd, log=Log.BOTH, error_msg=f"Failed to add {file_path} file to git index")
+            run(
+                cmd,
+                log=Log.BOTH,
+                error_msg=f"Failed to add {file_path} file to git index",
+            )
 
             # check if there is a diff
             cmd = nix_shell(
@@ -89,4 +95,6 @@
             + [str(file_path) for file_path in file_paths],
         )
 
-        run(cmd, error_msg=f"Failed to commit {file_paths} to git repository {repo_dir}")
+        run(
+            cmd, error_msg=f"Failed to commit {file_paths} to git repository {repo_dir}"
+        )

View File

@@ -11,7 +11,7 @@ from .dirs import user_history_file
 
 
 @contextmanager
-def _locked_open(filename: str | Path, mode: str = "r") -> Generator:
+def locked_open(filename: str | Path, mode: str = "r") -> Generator:
     """
     This is a context manager that provides an advisory write lock on the file specified by `filename` when entering the context, and releases the lock when leaving the context. The lock is acquired using the `fcntl` module's `LOCK_EX` flag, which applies an exclusive write lock to the file.
     """
@@ -22,12 +22,12 @@ def _locked_open(filename: str | Path, mode: str = "r") -> Generator:
 
 
 def write_history_file(data: Any) -> None:
-    with _locked_open(user_history_file(), "w+") as f:
+    with locked_open(user_history_file(), "w+") as f:
         f.write(json.dumps(data, cls=ClanJSONEncoder, indent=4))
 
 
 def read_history_file() -> list[dict]:
-    with _locked_open(user_history_file(), "r") as f:
+    with locked_open(user_history_file(), "r") as f:
         content: str = f.read()
         parsed: list[dict] = json.loads(content)
         return parsed

View File

@@ -24,7 +24,7 @@ def install_nixos(
     target_host = f"{h.user or 'root'}@{h.host}"
     log.info(f"target host: {target_host}")
 
-    generate_facts(machine)
+    generate_facts([machine])
 
     with TemporaryDirectory() as tmpdir_:
         tmpdir = Path(tmpdir_)

View File

@@ -0,0 +1,31 @@
+import json
+from pathlib import Path
+
+from ..cmd import run
+from ..nix import nix_build, nix_config
+from .machines import Machine
+
+
+# function to speed up eval if we want to evaluate all machines
+def get_all_machines(flake_dir: Path) -> list[Machine]:
+    config = nix_config()
+    system = config["system"]
+    json_path = run(
+        nix_build([f'{flake_dir}#clanInternals.all-machines-json."{system}"'])
+    ).stdout
+    machines_json = json.loads(Path(json_path.rstrip()).read_text())
+
+    machines = []
+    for name, machine_data in machines_json.items():
+        machines.append(
+            Machine(name=name, flake=flake_dir, deployment_info=machine_data)
+        )
+    return machines
+
+
+def get_selected_machines(flake_dir: Path, machine_names: list[str]) -> list[Machine]:
+    machines = []
+    for name in machine_names:
+        machines.append(Machine(name=name, flake=flake_dir))
+    return machines

View File

@@ -0,0 +1,26 @@
+from collections.abc import Callable
+from typing import TypeVar
+
+from ..ssh import Host, HostGroup, HostResult
+from .machines import Machine
+
+T = TypeVar("T")
+
+
+class MachineGroup:
+    def __init__(self, machines: list[Machine]) -> None:
+        self.group = HostGroup(list(m.target_host for m in machines))
+
+    def run_function(
+        self, func: Callable[[Machine], T], check: bool = True
+    ) -> list[HostResult[T]]:
+        """
+        Function to run for each host in the group in parallel
+        @func the function to call
+        """
+
+        def wrapped_func(host: Host) -> T:
+            return func(host.meta["machine"])
+
+        return self.group.run_function(wrapped_func, check=check)
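
For illustration, a hypothetical caller of the new MachineGroup, mirroring how deploy_nixos in machines/update.py hands a per-machine function to run_function. It assumes the new file lands at clan_cli/machines/machine_group.py, as the `from .machine_group import MachineGroup` import in update.py suggests; the flake path and machine names are placeholders:

    from pathlib import Path

    from clan_cli.machines.inventory import get_selected_machines
    from clan_cli.machines.machine_group import MachineGroup
    from clan_cli.machines.machines import Machine

    flake_dir = Path(".")  # placeholder: path to the clan flake
    group = MachineGroup(get_selected_machines(flake_dir, ["vm1", "vm2"]))


    def show_name(machine: Machine) -> str:
        # any per-machine work goes here; it runs once per host, in parallel
        return machine.name


    results = group.run_function(show_name)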

View File

@@ -5,15 +5,15 @@ import os
 import shlex
 import subprocess
 import sys
-from pathlib import Path
 
-from ..cmd import run
 from ..errors import ClanError
 from ..facts.generate import generate_facts
 from ..facts.upload import upload_secrets
 from ..machines.machines import Machine
-from ..nix import nix_build, nix_command, nix_config, nix_metadata
-from ..ssh import Host, HostGroup, HostKeyCheck, parse_deployment_address
+from ..nix import nix_command, nix_metadata
+from ..ssh import HostKeyCheck
+from .inventory import get_all_machines, get_selected_machines
+from .machine_group import MachineGroup
 
 log = logging.getLogger(__name__)
@@ -86,31 +86,31 @@
     )
 
 
-def deploy_nixos(hosts: HostGroup) -> None:
+def deploy_nixos(machines: MachineGroup) -> None:
     """
     Deploy to all hosts in parallel
     """
 
-    def deploy(h: Host) -> None:
-        target = f"{h.user or 'root'}@{h.host}"
-        ssh_arg = f"-p {h.port}" if h.port else ""
+    def deploy(machine: Machine) -> None:
+        host = machine.build_host
+        target = f"{host.user or 'root'}@{host.host}"
+        ssh_arg = f"-p {host.port}" if host.port else ""
         env = os.environ.copy()
         env["NIX_SSHOPTS"] = ssh_arg
-        machine: Machine = h.meta["machine"]
 
-        generate_facts(machine)
+        generate_facts([machine])
         upload_secrets(machine)
 
         path = upload_sources(".", target)
 
-        if h.host_key_check != HostKeyCheck.STRICT:
+        if host.host_key_check != HostKeyCheck.STRICT:
             ssh_arg += " -o StrictHostKeyChecking=no"
-        if h.host_key_check == HostKeyCheck.NONE:
+        if host.host_key_check == HostKeyCheck.NONE:
             ssh_arg += " -o UserKnownHostsFile=/dev/null"
-        ssh_arg += " -i " + h.key if h.key else ""
+        ssh_arg += " -i " + host.key if host.key else ""
 
-        extra_args = h.meta.get("extra_args", [])
+        extra_args = host.meta.get("extra_args", [])
         cmd = [
             "nixos-rebuild",
             "switch",
@@ -127,82 +127,55 @@ def deploy_nixos(hosts: HostGroup) -> None:
             "--flake",
             f"{path}#{machine.name}",
         ]
-        if target_host := h.meta.get("target_host"):
+        if target_host := host.meta.get("target_host"):
             target_host = f"{target_host.user or 'root'}@{target_host.host}"
             cmd.extend(["--target-host", target_host])
-        ret = h.run(cmd, check=False)
+        ret = host.run(cmd, check=False)
         # re-retry switch if the first time fails
         if ret.returncode != 0:
-            ret = h.run(cmd)
+            ret = host.run(cmd)
 
-    hosts.run_function(deploy)
+    machines.run_function(deploy)
 
 
-# function to speedup eval if we want to evauluate all machines
-def get_all_machines(clan_dir: Path) -> HostGroup:
-    config = nix_config()
-    system = config["system"]
-    machines_json = run(
-        nix_build([f'{clan_dir}#clanInternals.all-machines-json."{system}"'])
-    ).stdout
-
-    machines = json.loads(Path(machines_json.rstrip()).read_text())
-
-    hosts = []
-    ignored_machines = []
-    for name, machine_data in machines.items():
-        if machine_data.get("requireExplicitUpdate", False):
-            continue
-
-        machine = Machine(name=name, flake=clan_dir, deployment_info=machine_data)
-        try:
-            hosts.append(machine.build_host)
-        except ClanError:
-            ignored_machines.append(name)
-            continue
-
-    if not hosts and ignored_machines != []:
-        print(
-            "WARNING: No machines to update. The following defined machines were ignored because they do not have `clan.networking.targetHost` nixos option set:",
-            file=sys.stderr,
-        )
-        for machine in ignored_machines:
-            print(machine, file=sys.stderr)
-
-    # very hacky. would be better to do a MachinesGroup instead
-    return HostGroup(hosts)
-
-
-def get_selected_machines(machine_names: list[str], flake_dir: Path) -> HostGroup:
-    hosts = []
-    for name in machine_names:
-        machine = Machine(name=name, flake=flake_dir)
-        hosts.append(machine.build_host)
-    return HostGroup(hosts)
-
-
+# FIXME: we want some kind of inventory here.
 def update(args: argparse.Namespace) -> None:
     if args.flake is None:
         raise ClanError("Could not find clan flake toplevel directory")
+    machines = []
     if len(args.machines) == 1 and args.target_host is not None:
         machine = Machine(name=args.machines[0], flake=args.flake)
         machine.target_host_address = args.target_host
-        host = parse_deployment_address(
-            args.machines[0],
-            args.target_host,
-            meta={"machine": machine},
-        )
-        machines = HostGroup([host])
+        machines.append(machine)
 
     elif args.target_host is not None:
         print("target host can only be specified for a single machine")
         exit(1)
 
     else:
         if len(args.machines) == 0:
-            machines = get_all_machines(args.flake)
+            ignored_machines = []
+            for machine in get_all_machines(args.flake):
+                if machine.deployment_info.get("requireExplicitUpdate", False):
+                    continue
+                try:
+                    machine.build_host
+                except ClanError:  # check if we have a build host set
+                    ignored_machines.append(machine)
+                    continue
+                machines.append(machine)
+
+            if not machines and ignored_machines != []:
+                print(
+                    "WARNING: No machines to update. The following defined machines were ignored because they do not have `clan.networking.targetHost` nixos option set:",
+                    file=sys.stderr,
+                )
+                for machine in ignored_machines:
+                    print(machine, file=sys.stderr)
+
         else:
-            machines = get_selected_machines(args.machines, args.flake)
+            machines = get_selected_machines(args.flake, args.machines)
 
-    deploy_nixos(machines)
+    deploy_nixos(MachineGroup(machines))
 
 
 def register_update_parser(parser: argparse.ArgumentParser) -> None:

View File

@@ -18,6 +18,8 @@ from shlex import quote
 from threading import Thread
 from typing import IO, Any, Generic, TypeVar
 
+from ..errors import ClanError
+
 # https://no-color.org
 DISABLE_COLOR = not sys.stderr.isatty() or os.environ.get("NO_COLOR", "") != ""
@@ -285,7 +287,7 @@ class Host:
         elif stdout == subprocess.PIPE:
             stdout_read, stdout_write = stack.enter_context(_pipe())
         else:
-            raise Exception(f"unsupported value for stdout parameter: {stdout}")
+            raise ClanError(f"unsupported value for stdout parameter: {stdout}")
 
         if stderr is None:
             stderr_read = None
@@ -293,7 +295,7 @@
         elif stderr == subprocess.PIPE:
             stderr_read, stderr_write = stack.enter_context(_pipe())
         else:
-            raise Exception(f"unsupported value for stderr parameter: {stderr}")
+            raise ClanError(f"unsupported value for stderr parameter: {stderr}")
 
         env = os.environ.copy()
         env.update(extra_env)
@@ -610,7 +612,7 @@ class HostGroup:
             )
             errors += 1
         if errors > 0:
-            raise Exception(
+            raise ClanError(
                 f"{errors} hosts failed with an error. Check the logs above"
             )

View File

@@ -69,7 +69,7 @@ def get_secrets(
     secret_facts_module = importlib.import_module(machine.secret_facts_module)
     secret_facts_store = secret_facts_module.SecretStore(machine=machine)
 
-    generate_facts(machine)
+    generate_facts([machine])
 
     secret_facts_store.upload(secrets_dir)
     return secrets_dir

View File

@@ -59,8 +59,8 @@ def test_generate_secret(
     age_key_mtime = age_key.lstat().st_mtime_ns
     secret1_mtime = identity_secret.lstat().st_mtime_ns
 
-    # test idempotency
-    cli.run(["facts", "generate", "vm1"])
+    # test idempotency for vm1 and also generate for vm2
+    cli.run(["facts", "generate"])
     assert age_key.lstat().st_mtime_ns == age_key_mtime
     assert identity_secret.lstat().st_mtime_ns == secret1_mtime
@@ -68,7 +68,6 @@ def test_generate_secret(
         secrets_folder / "vm1-zerotier-identity-secret" / "machines" / "vm1"
     ).exists()
 
-    cli.run(["facts", "generate", "vm2"])
     assert has_secret(test_flake_with_core.path, "vm2-age.key")
     assert has_secret(test_flake_with_core.path, "vm2-zerotier-identity-secret")
     ip = machine_get_fact(test_flake_with_core.path, "vm1", "zerotier-ip")