get rid of task_manager
All checks were successful
checks-impure / test (pull_request) Successful in 1m7s
checks / test (pull_request) Successful in 1m36s

This commit is contained in:
Jörg Thalheim 2023-12-06 15:36:36 +01:00
parent 9576047adb
commit 7bc54cb524
4 changed files with 177 additions and 370 deletions

View File

@ -1,203 +0,0 @@
import logging
import os
import queue
import select
import shlex
import subprocess
import sys
import threading
import traceback
from collections.abc import Iterator
from enum import Enum
from pathlib import Path
from typing import Any, TypeVar
from uuid import UUID, uuid4
from .custom_logger import ThreadFormatter, get_caller
from .deal import deal
from .errors import ClanError
class Command:
def __init__(self, log: logging.Logger) -> None:
self.log: logging.Logger = log
self.p: subprocess.Popen[str] | None = None
self._output: queue.SimpleQueue[str | None] = queue.SimpleQueue()
self.returncode: int | None = None
self.done: bool = False
self.stdout: list[str] = []
self.stderr: list[str] = []
def close_queue(self) -> None:
if self.p is not None:
self.returncode = self.p.returncode
self._output.put(None)
self.done = True
def run(
self,
cmd: list[str],
env: dict[str, str] | None = None,
cwd: Path | None = None,
name: str = "command",
) -> None:
self.running = True
self.log.debug(f"Command: {shlex.join(cmd)}")
self.log.debug(f"Caller: {get_caller()}")
cwd_res = None
if cwd is not None:
if not cwd.exists():
raise ClanError(f"Working directory {cwd} does not exist")
if not cwd.is_dir():
raise ClanError(f"Working directory {cwd} is not a directory")
cwd_res = cwd.resolve()
self.log.debug(f"Working directory: {cwd_res}")
self.p = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
cwd=cwd_res,
env=env,
)
assert self.p.stdout is not None and self.p.stderr is not None
os.set_blocking(self.p.stdout.fileno(), False)
os.set_blocking(self.p.stderr.fileno(), False)
while self.p.poll() is None:
# Check if stderr is ready to be read from
rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 1)
for fd in rlist:
try:
for line in fd:
self.log.debug(f"[{name}] {line.rstrip()}")
if fd == self.p.stderr:
self.stderr.append(line)
else:
self.stdout.append(line)
self._output.put(line)
except BlockingIOError:
continue
if self.p.returncode != 0:
raise ClanError(f"Failed to run command: {shlex.join(cmd)}")
class TaskStatus(str, Enum):
NOTSTARTED = "NOTSTARTED"
RUNNING = "RUNNING"
FINISHED = "FINISHED"
FAILED = "FAILED"
class BaseTask:
def __init__(self, uuid: UUID, num_cmds: int) -> None:
# constructor
self.uuid: UUID = uuid
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
handler.setFormatter(ThreadFormatter())
logger = logging.getLogger(__name__)
logger.addHandler(handler)
self.log = logger
self.log = logger
self.procs: list[Command] = []
self.status = TaskStatus.NOTSTARTED
self.logs_lock = threading.Lock()
self.error: Exception | None = None
for _ in range(num_cmds):
cmd = Command(self.log)
self.procs.append(cmd)
def _run(self) -> None:
self.status = TaskStatus.RUNNING
try:
self.run()
# TODO: We need to check, if too many commands have been initialized,
# but not run. This would deadlock the log_lines() function.
# Idea: Run next(cmds) and check if it raises StopIteration if not,
# we have too many commands
except Exception as e:
# FIXME: fix exception handling here
traceback.print_exception(*sys.exc_info())
self.error = e
self.log.exception(e)
self.status = TaskStatus.FAILED
else:
self.status = TaskStatus.FINISHED
finally:
for proc in self.procs:
proc.close_queue()
def run(self) -> None:
raise NotImplementedError
## TODO: Test when two clients are connected to the same task
def log_lines(self) -> Iterator[str]:
with self.logs_lock:
for proc in self.procs:
if self.status == TaskStatus.FINISHED:
return
# process has finished
if proc.done:
for line in proc.stdout:
yield line
for line in proc.stderr:
yield line
else:
while maybe_line := proc._output.get():
yield maybe_line
def commands(self) -> Iterator[Command]:
yield from self.procs
# TODO: We need to test concurrency
class TaskPool:
def __init__(self) -> None:
self.lock: threading.RLock = threading.RLock()
self.pool: dict[UUID, BaseTask] = {}
def __getitem__(self, uuid: UUID) -> BaseTask:
with self.lock:
if uuid not in self.pool:
raise ClanError(f"Task with uuid {uuid} does not exist")
return self.pool[uuid]
def __setitem__(self, uuid: UUID, task: BaseTask) -> None:
with self.lock:
if uuid in self.pool:
raise KeyError(f"Task with uuid {uuid} already exists")
if type(uuid) is not UUID:
raise TypeError("uuid must be of type UUID")
self.pool[uuid] = task
POOL: TaskPool = TaskPool()
@deal.raises(ClanError)
def get_task(uuid: UUID) -> BaseTask:
global POOL
return POOL[uuid]
T = TypeVar("T", bound="BaseTask")
@deal.raises(ClanError)
def create_task(task_type: type[T], *args: Any) -> T:
global POOL
# check if task_type is a callable
if not callable(task_type):
raise ClanError("task_type must be callable")
uuid = uuid4()
task = task_type(uuid, *args)
POOL[uuid] = task
threading.Thread(target=task._run).start()
return task

View File

@ -1,19 +1,22 @@
import argparse
import asyncio
import json
import logging
import os
import shlex
import subprocess
import sys
import tempfile
from collections.abc import Iterator
from pathlib import Path
from uuid import UUID
from typing import IO
from ..dirs import module_root
from ..errors import ClanError
from ..nix import nix_build, nix_config, nix_eval, nix_shell
from ..task_manager import BaseTask, Command, create_task
from .inspect import VmConfig, inspect_vm
log = logging.getLogger(__name__)
def qemu_command(
vm: VmConfig,
@ -87,162 +90,189 @@ def qemu_command(
return command
class BuildVmTask(BaseTask):
def __init__(self, uuid: UUID, vm: VmConfig, nix_options: list[str] = []) -> None:
super().__init__(uuid, num_cmds=7)
self.vm = vm
self.nix_options = nix_options
def get_vm_create_info(vm: VmConfig, nix_options: list[str]) -> dict[str, str]:
config = nix_config()
system = config["system"]
def get_vm_create_info(self, cmds: Iterator[Command]) -> dict[str, str]:
config = nix_config()
system = config["system"]
clan_dir = self.vm.flake_url
machine = self.vm.flake_attr
cmd = next(cmds)
cmd.run(
nix_build(
[
f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.system.clan.vm.create',
*self.nix_options,
]
),
name="buildvm",
clan_dir = vm.flake_url
machine = vm.flake_attr
cmd = nix_build(
[
f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.system.clan.vm.create',
*nix_options,
]
)
proc = subprocess.run(
cmd,
check=False,
stdout=subprocess.PIPE,
text=True,
)
if proc.returncode != 0:
raise ClanError(
f"Failed to build vm config: {shlex.join(cmd)} failed with: {proc.returncode}"
)
vm_json = "".join(cmd.stdout).strip()
self.log.debug(f"VM JSON path: {vm_json}")
with open(vm_json) as f:
return json.load(f)
try:
return json.loads(Path(proc.stdout.strip()).read_text())
except json.JSONDecodeError as e:
raise ClanError(f"Failed to parse vm config: {e}")
def get_clan_name(self, cmds: Iterator[Command]) -> str:
clan_dir = self.vm.flake_url
cmd = next(cmds)
cmd.run(
nix_eval([f"{clan_dir}#clanInternals.clanName"]) + self.nix_options,
name="clanname",
def get_clan_name(vm: VmConfig, nix_options: list[str]) -> str:
clan_dir = vm.flake_url
cmd = nix_eval([f"{clan_dir}#clanInternals.clanName"]) + nix_options
proc = subprocess.run(
cmd,
stdout=subprocess.PIPE,
check=False,
text=True,
)
if proc.returncode != 0:
raise ClanError(
f"Failed to get clan name: {shlex.join(cmd)} failed with: {proc.returncode}"
)
clan_name = cmd.stdout[0].strip().strip('"')
return clan_name
def run(self) -> None:
cmds = self.commands()
machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}")
# TODO: We should get this from the vm argument
nixos_config = self.get_vm_create_info(cmds)
clan_name = self.get_clan_name(cmds)
self.log.debug(f"Building VM for clan name: {clan_name}")
flake_dir = Path(self.vm.flake_url)
flake_dir.mkdir(exist_ok=True)
with tempfile.TemporaryDirectory() as tmpdir_:
tmpdir = Path(tmpdir_)
xchg_dir = tmpdir / "xchg"
xchg_dir.mkdir(exist_ok=True)
secrets_dir = tmpdir / "secrets"
secrets_dir.mkdir(exist_ok=True)
disk_img = tmpdir / "disk.img"
spice_socket = tmpdir / "spice.sock"
env = os.environ.copy()
env["CLAN_DIR"] = str(self.vm.flake_url)
env["PYTHONPATH"] = str(
":".join(sys.path)
) # TODO do this in the clanCore module
env["SECRETS_DIR"] = str(secrets_dir)
# Only generate secrets for local clans
if isinstance(self.vm.flake_url, Path) and self.vm.flake_url.is_dir():
cmd = next(cmds)
if Path(self.vm.flake_url).is_dir():
cmd.run(
[nixos_config["generateSecrets"], clan_name],
env=env,
name="generateSecrets",
)
else:
self.log.warning("won't generate secrets for non local clan")
cmd = next(cmds)
cmd.run(
[nixos_config["uploadSecrets"]],
env=env,
name="uploadSecrets",
)
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
str(disk_img),
"1024M",
],
),
name="createDisk",
)
cmd = next(cmds)
cmd.run(
nix_shell(
["e2fsprogs"],
[
"mkfs.ext4",
"-L",
"nixos",
str(disk_img),
],
),
name="formatDisk",
)
cmd = next(cmds)
qemu_cmd = qemu_command(
self.vm,
nixos_config,
xchg_dir=xchg_dir,
secrets_dir=secrets_dir,
disk_img=disk_img,
spice_socket=spice_socket,
)
print("$ " + shlex.join(qemu_cmd))
packages = ["qemu"]
if self.vm.graphics:
packages.append("virt-viewer")
env = os.environ.copy()
remote_viewer_mimetypes = module_root() / "vms" / "mimetypes"
env[
"XDG_DATA_DIRS"
] = f"{remote_viewer_mimetypes}:{env.get('XDG_DATA_DIRS', '')}"
print(env["XDG_DATA_DIRS"])
cmd.run(nix_shell(packages, qemu_cmd), name="qemu", env=env)
return proc.stdout.strip().strip('"')
def run_vm(
vm: VmConfig, nix_options: list[str] = [], env: dict[str, str] = {}
) -> BuildVmTask:
return create_task(BuildVmTask, vm, nix_options)
vm: VmConfig, nix_options: list[str] = [], log_fd: IO[str] | None = None
) -> None:
"""
log_fd can be used to stream the output of all commands to a UI
"""
machine = vm.flake_attr
log.debug(f"Creating VM for {machine}")
# TODO: We should get this from the vm argument
nixos_config = get_vm_create_info(vm, nix_options)
clan_name = get_clan_name(vm, nix_options)
log.debug(f"Building VM for clan name: {clan_name}")
flake_dir = Path(vm.flake_url)
flake_dir.mkdir(exist_ok=True)
with tempfile.TemporaryDirectory() as tmpdir_:
tmpdir = Path(tmpdir_)
xchg_dir = tmpdir / "xchg"
xchg_dir.mkdir(exist_ok=True)
secrets_dir = tmpdir / "secrets"
secrets_dir.mkdir(exist_ok=True)
disk_img = tmpdir / "disk.img"
spice_socket = tmpdir / "spice.sock"
env = os.environ.copy()
env["CLAN_DIR"] = str(vm.flake_url)
env["PYTHONPATH"] = str(
":".join(sys.path)
) # TODO do this in the clanCore module
env["SECRETS_DIR"] = str(secrets_dir)
# Only generate secrets for local clans
if isinstance(vm.flake_url, Path) and vm.flake_url.is_dir():
if Path(vm.flake_url).is_dir():
subprocess.run(
[nixos_config["generateSecrets"], clan_name],
env=env,
check=False,
stdout=log_fd,
stderr=log_fd,
)
else:
log.warning("won't generate secrets for non local clan")
cmd = [nixos_config["uploadSecrets"]]
res = subprocess.run(
cmd,
env=env,
check=False,
stdout=log_fd,
stderr=log_fd,
)
if res.returncode != 0:
raise ClanError(
f"Failed to upload secrets: {shlex.join(cmd)} failed with {res.returncode}"
)
cmd = nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
str(disk_img),
"1024M",
],
)
res = subprocess.run(
cmd,
check=False,
stdout=log_fd,
stderr=log_fd,
)
if res.returncode != 0:
raise ClanError(
f"Failed to create disk image: {shlex.join(cmd)} failed with {res.returncode}"
)
cmd = nix_shell(
["e2fsprogs"],
[
"mkfs.ext4",
"-L",
"nixos",
str(disk_img),
],
)
res = subprocess.run(
cmd,
check=False,
stdout=log_fd,
stderr=log_fd,
)
if res.returncode != 0:
raise ClanError(
f"Failed to create ext4 filesystem: {shlex.join(cmd)} failed with {res.returncode}"
)
qemu_cmd = qemu_command(
vm,
nixos_config,
xchg_dir=xchg_dir,
secrets_dir=secrets_dir,
disk_img=disk_img,
spice_socket=spice_socket,
)
print("$ " + shlex.join(qemu_cmd))
packages = ["qemu"]
if vm.graphics:
packages.append("virt-viewer")
env = os.environ.copy()
remote_viewer_mimetypes = module_root() / "vms" / "mimetypes"
env[
"XDG_DATA_DIRS"
] = f"{remote_viewer_mimetypes}:{env.get('XDG_DATA_DIRS', '')}"
print(env["XDG_DATA_DIRS"])
res = subprocess.run(
nix_shell(packages, qemu_cmd),
env=env,
check=False,
stdout=log_fd,
stderr=log_fd,
)
if res.returncode != 0:
raise ClanError(f"qemu failed with {res.returncode}")
def run_command(args: argparse.Namespace) -> None:
flake_url = args.flake_url or args.flake
vm = asyncio.run(inspect_vm(flake_url=flake_url, flake_attr=args.machine))
task = run_vm(vm, args.option)
for line in task.log_lines():
print(line, end="")
run_vm(vm, args.option)
def register_run_parser(parser: argparse.ArgumentParser) -> None:

View File

@ -3,7 +3,6 @@ from enum import Enum
from pydantic import BaseModel, Extra, Field
from ..async_cmd import CmdOut
from ..task_manager import TaskStatus
class Status(Enum):
@ -47,15 +46,6 @@ class VerifyMachineResponse(BaseModel):
error: str | None
class VmStatusResponse(BaseModel):
error: str | None
status: TaskStatus
class VmCreateResponse(BaseModel):
uuid: str
class FlakeAttrResponse(BaseModel):
flake_attrs: list[str]

View File

@ -1,16 +1,6 @@
import deal
from clan_cli import nix, task_manager
@deal.cases(task_manager.get_task)
def test_get_task(case: deal.TestCase) -> None:
case()
@deal.cases(task_manager.create_task)
def test_create_task(case: deal.TestCase) -> None:
case()
from clan_cli import nix
@deal.cases(nix.nix_command)