get rid of task_manager
All checks were successful
checks-impure / test (pull_request) Successful in 1m7s
checks / test (pull_request) Successful in 1m36s

This commit is contained in:
Jörg Thalheim 2023-12-06 15:36:36 +01:00
parent 9576047adb
commit 7bc54cb524
4 changed files with 177 additions and 370 deletions

View File

@ -1,203 +0,0 @@
import logging
import os
import queue
import select
import shlex
import subprocess
import sys
import threading
import traceback
from collections.abc import Iterator
from enum import Enum
from pathlib import Path
from typing import Any, TypeVar
from uuid import UUID, uuid4
from .custom_logger import ThreadFormatter, get_caller
from .deal import deal
from .errors import ClanError
class Command:
def __init__(self, log: logging.Logger) -> None:
self.log: logging.Logger = log
self.p: subprocess.Popen[str] | None = None
self._output: queue.SimpleQueue[str | None] = queue.SimpleQueue()
self.returncode: int | None = None
self.done: bool = False
self.stdout: list[str] = []
self.stderr: list[str] = []
def close_queue(self) -> None:
if self.p is not None:
self.returncode = self.p.returncode
self._output.put(None)
self.done = True
def run(
self,
cmd: list[str],
env: dict[str, str] | None = None,
cwd: Path | None = None,
name: str = "command",
) -> None:
self.running = True
self.log.debug(f"Command: {shlex.join(cmd)}")
self.log.debug(f"Caller: {get_caller()}")
cwd_res = None
if cwd is not None:
if not cwd.exists():
raise ClanError(f"Working directory {cwd} does not exist")
if not cwd.is_dir():
raise ClanError(f"Working directory {cwd} is not a directory")
cwd_res = cwd.resolve()
self.log.debug(f"Working directory: {cwd_res}")
self.p = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
cwd=cwd_res,
env=env,
)
assert self.p.stdout is not None and self.p.stderr is not None
os.set_blocking(self.p.stdout.fileno(), False)
os.set_blocking(self.p.stderr.fileno(), False)
while self.p.poll() is None:
# Check if stderr is ready to be read from
rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 1)
for fd in rlist:
try:
for line in fd:
self.log.debug(f"[{name}] {line.rstrip()}")
if fd == self.p.stderr:
self.stderr.append(line)
else:
self.stdout.append(line)
self._output.put(line)
except BlockingIOError:
continue
if self.p.returncode != 0:
raise ClanError(f"Failed to run command: {shlex.join(cmd)}")
class TaskStatus(str, Enum):
    """State of a task; each member's string value equals its name."""

    # Initial state assigned when a task object is constructed.
    NOTSTARTED = "NOTSTARTED"
    # Set while the task body is executing.
    RUNNING = "RUNNING"
    # The task body returned without raising.
    FINISHED = "FINISHED"
    # The task body raised an exception.
    FAILED = "FAILED"
class BaseTask:
    """Base class for a job made of several commands, executed via ``_run``.

    Subclasses override ``run`` and take their pre-allocated ``Command``
    objects from ``commands()``. Output can be consumed with ``log_lines``.
    """

    def __init__(self, uuid: UUID, num_cmds: int) -> None:
        """Create the task and allocate ``num_cmds`` command slots."""
        self.uuid: UUID = uuid

        logger = logging.getLogger(__name__)
        # The module logger is shared by every task. Attach the stream handler
        # only once; otherwise each new task adds another handler and every
        # message is emitted once per task ever created.
        already_attached = any(
            isinstance(h.formatter, ThreadFormatter) for h in logger.handlers
        )
        if not already_attached:
            handler = logging.StreamHandler()
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(ThreadFormatter())
            logger.addHandler(handler)
        self.log = logger

        self.procs: list[Command] = [Command(self.log) for _ in range(num_cmds)]
        self.status = TaskStatus.NOTSTARTED
        self.logs_lock = threading.Lock()
        self.error: Exception | None = None

    def _run(self) -> None:
        """Execute ``run`` and record the outcome in ``status`` / ``error``.

        Whatever the outcome, every command's output queue is closed at the
        end so that consumers blocked in ``log_lines`` are released.
        """
        self.status = TaskStatus.RUNNING
        try:
            self.run()
            # TODO: We need to check, if too many commands have been initialized,
            # but not run. This would deadlock the log_lines() function.
            # Idea: Run next(cmds) and check if it raises StopIteration if not,
            # we have too many commands
        except Exception as e:
            # FIXME: fix exception handling here
            traceback.print_exception(*sys.exc_info())
            self.error = e
            self.log.exception(e)
            self.status = TaskStatus.FAILED
        else:
            self.status = TaskStatus.FINISHED
        finally:
            for proc in self.procs:
                proc.close_queue()

    def run(self) -> None:
        """Task body; must be provided by subclasses."""
        raise NotImplementedError

    ## TODO: Test when two clients are connected to the same task
    def log_lines(self) -> Iterator[str]:
        """Yield captured output lines of the task's commands, in order.

        For a command that is already done, its recorded stdout lines are
        yielded followed by its stderr lines. Otherwise lines are taken from
        the command's queue until the end-of-output sentinel arrives.
        """
        with self.logs_lock:
            for proc in self.procs:
                # NOTE(review): once the task is FINISHED nothing further is
                # yielded, including output of commands not yet visited.
                # Confirm this is intended before relying on complete logs.
                if self.status == TaskStatus.FINISHED:
                    return
                # process has finished
                if proc.done:
                    yield from proc.stdout
                    yield from proc.stderr
                else:
                    # Blocks until a line or the ``None`` sentinel is queued.
                    while maybe_line := proc._output.get():
                        yield maybe_line

    def commands(self) -> Iterator[Command]:
        """Iterate over the pre-allocated command slots."""
        yield from self.procs
# TODO: We need to test concurrency
class TaskPool:
    """Lock-protected mapping from task UUID to task object."""

    def __init__(self) -> None:
        self.lock: threading.RLock = threading.RLock()
        self.pool: dict[UUID, "BaseTask"] = {}

    def __getitem__(self, uuid: UUID) -> "BaseTask":
        """Return the task stored under ``uuid``.

        Raises:
            ClanError: If no task with that uuid is stored.
        """
        with self.lock:
            try:
                return self.pool[uuid]
            except KeyError:
                raise ClanError(f"Task with uuid {uuid} does not exist") from None

    def __setitem__(self, uuid: UUID, task: "BaseTask") -> None:
        """Store ``task`` under ``uuid``; existing entries are never replaced.

        Raises:
            TypeError: If ``uuid`` is not a ``UUID`` instance.
            KeyError: If a task is already stored under ``uuid``.
        """
        # isinstance (rather than an exact type comparison) also accepts
        # subclasses of UUID, and validating first rejects bad keys early.
        if not isinstance(uuid, UUID):
            raise TypeError("uuid must be of type UUID")
        with self.lock:
            if uuid in self.pool:
                raise KeyError(f"Task with uuid {uuid} already exists")
            self.pool[uuid] = task
# Module-level registry holding every task registered by create_task.
POOL: TaskPool = TaskPool()


@deal.raises(ClanError)
def get_task(uuid: UUID) -> BaseTask:
    """Return the task registered under ``uuid`` in the module-level pool.

    Raises:
        ClanError: If the pool has no task with that uuid.
    """
    return POOL[uuid]
T = TypeVar("T", bound="BaseTask")


@deal.raises(ClanError)
def create_task(task_type: type[T], *args: Any) -> T:
    """Instantiate a task, register it in the pool and start it in a thread.

    Args:
        task_type: Callable invoked as ``task_type(uuid, *args)``.
        *args: Extra positional arguments forwarded to ``task_type``.

    Returns:
        The newly created task; its ``_run`` method is already executing in
        a separate thread.

    Raises:
        ClanError: If ``task_type`` is not callable.
    """
    # check if task_type is a callable
    if not callable(task_type):
        raise ClanError("task_type must be callable")

    task_id = uuid4()
    new_task = task_type(task_id, *args)
    POOL[task_id] = new_task

    worker = threading.Thread(target=new_task._run)
    worker.start()
    return new_task

View File

@ -1,19 +1,22 @@
import argparse import argparse
import asyncio import asyncio
import json import json
import logging
import os import os
import shlex import shlex
import subprocess
import sys import sys
import tempfile import tempfile
from collections.abc import Iterator
from pathlib import Path from pathlib import Path
from uuid import UUID from typing import IO
from ..dirs import module_root from ..dirs import module_root
from ..errors import ClanError
from ..nix import nix_build, nix_config, nix_eval, nix_shell from ..nix import nix_build, nix_config, nix_eval, nix_shell
from ..task_manager import BaseTask, Command, create_task
from .inspect import VmConfig, inspect_vm from .inspect import VmConfig, inspect_vm
log = logging.getLogger(__name__)
def qemu_command( def qemu_command(
vm: VmConfig, vm: VmConfig,
@ -87,162 +90,189 @@ def qemu_command(
return command return command
class BuildVmTask(BaseTask): def get_vm_create_info(vm: VmConfig, nix_options: list[str]) -> dict[str, str]:
def __init__(self, uuid: UUID, vm: VmConfig, nix_options: list[str] = []) -> None: config = nix_config()
super().__init__(uuid, num_cmds=7) system = config["system"]
self.vm = vm
self.nix_options = nix_options
def get_vm_create_info(self, cmds: Iterator[Command]) -> dict[str, str]: clan_dir = vm.flake_url
config = nix_config() machine = vm.flake_attr
system = config["system"] cmd = nix_build(
[
clan_dir = self.vm.flake_url f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.system.clan.vm.create',
machine = self.vm.flake_attr *nix_options,
cmd = next(cmds) ]
cmd.run( )
nix_build( proc = subprocess.run(
[ cmd,
f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.system.clan.vm.create', check=False,
*self.nix_options, stdout=subprocess.PIPE,
] text=True,
), )
name="buildvm", if proc.returncode != 0:
raise ClanError(
f"Failed to build vm config: {shlex.join(cmd)} failed with: {proc.returncode}"
) )
vm_json = "".join(cmd.stdout).strip() try:
self.log.debug(f"VM JSON path: {vm_json}") return json.loads(Path(proc.stdout.strip()).read_text())
with open(vm_json) as f: except json.JSONDecodeError as e:
return json.load(f) raise ClanError(f"Failed to parse vm config: {e}")
def get_clan_name(self, cmds: Iterator[Command]) -> str:
clan_dir = self.vm.flake_url def get_clan_name(vm: VmConfig, nix_options: list[str]) -> str:
cmd = next(cmds) clan_dir = vm.flake_url
cmd.run( cmd = nix_eval([f"{clan_dir}#clanInternals.clanName"]) + nix_options
nix_eval([f"{clan_dir}#clanInternals.clanName"]) + self.nix_options, proc = subprocess.run(
name="clanname", cmd,
stdout=subprocess.PIPE,
check=False,
text=True,
)
if proc.returncode != 0:
raise ClanError(
f"Failed to get clan name: {shlex.join(cmd)} failed with: {proc.returncode}"
) )
clan_name = cmd.stdout[0].strip().strip('"') return proc.stdout.strip().strip('"')
return clan_name
def run(self) -> None:
cmds = self.commands()
machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}")
# TODO: We should get this from the vm argument
nixos_config = self.get_vm_create_info(cmds)
clan_name = self.get_clan_name(cmds)
self.log.debug(f"Building VM for clan name: {clan_name}")
flake_dir = Path(self.vm.flake_url)
flake_dir.mkdir(exist_ok=True)
with tempfile.TemporaryDirectory() as tmpdir_:
tmpdir = Path(tmpdir_)
xchg_dir = tmpdir / "xchg"
xchg_dir.mkdir(exist_ok=True)
secrets_dir = tmpdir / "secrets"
secrets_dir.mkdir(exist_ok=True)
disk_img = tmpdir / "disk.img"
spice_socket = tmpdir / "spice.sock"
env = os.environ.copy()
env["CLAN_DIR"] = str(self.vm.flake_url)
env["PYTHONPATH"] = str(
":".join(sys.path)
) # TODO do this in the clanCore module
env["SECRETS_DIR"] = str(secrets_dir)
# Only generate secrets for local clans
if isinstance(self.vm.flake_url, Path) and self.vm.flake_url.is_dir():
cmd = next(cmds)
if Path(self.vm.flake_url).is_dir():
cmd.run(
[nixos_config["generateSecrets"], clan_name],
env=env,
name="generateSecrets",
)
else:
self.log.warning("won't generate secrets for non local clan")
cmd = next(cmds)
cmd.run(
[nixos_config["uploadSecrets"]],
env=env,
name="uploadSecrets",
)
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
str(disk_img),
"1024M",
],
),
name="createDisk",
)
cmd = next(cmds)
cmd.run(
nix_shell(
["e2fsprogs"],
[
"mkfs.ext4",
"-L",
"nixos",
str(disk_img),
],
),
name="formatDisk",
)
cmd = next(cmds)
qemu_cmd = qemu_command(
self.vm,
nixos_config,
xchg_dir=xchg_dir,
secrets_dir=secrets_dir,
disk_img=disk_img,
spice_socket=spice_socket,
)
print("$ " + shlex.join(qemu_cmd))
packages = ["qemu"]
if self.vm.graphics:
packages.append("virt-viewer")
env = os.environ.copy()
remote_viewer_mimetypes = module_root() / "vms" / "mimetypes"
env[
"XDG_DATA_DIRS"
] = f"{remote_viewer_mimetypes}:{env.get('XDG_DATA_DIRS', '')}"
print(env["XDG_DATA_DIRS"])
cmd.run(nix_shell(packages, qemu_cmd), name="qemu", env=env)
def run_vm( def run_vm(
vm: VmConfig, nix_options: list[str] = [], env: dict[str, str] = {} vm: VmConfig, nix_options: list[str] = [], log_fd: IO[str] | None = None
) -> BuildVmTask: ) -> None:
return create_task(BuildVmTask, vm, nix_options) """
log_fd can be used to stream the output of all commands to a UI
"""
machine = vm.flake_attr
log.debug(f"Creating VM for {machine}")
# TODO: We should get this from the vm argument
nixos_config = get_vm_create_info(vm, nix_options)
clan_name = get_clan_name(vm, nix_options)
log.debug(f"Building VM for clan name: {clan_name}")
flake_dir = Path(vm.flake_url)
flake_dir.mkdir(exist_ok=True)
with tempfile.TemporaryDirectory() as tmpdir_:
tmpdir = Path(tmpdir_)
xchg_dir = tmpdir / "xchg"
xchg_dir.mkdir(exist_ok=True)
secrets_dir = tmpdir / "secrets"
secrets_dir.mkdir(exist_ok=True)
disk_img = tmpdir / "disk.img"
spice_socket = tmpdir / "spice.sock"
env = os.environ.copy()
env["CLAN_DIR"] = str(vm.flake_url)
env["PYTHONPATH"] = str(
":".join(sys.path)
) # TODO do this in the clanCore module
env["SECRETS_DIR"] = str(secrets_dir)
# Only generate secrets for local clans
if isinstance(vm.flake_url, Path) and vm.flake_url.is_dir():
if Path(vm.flake_url).is_dir():
subprocess.run(
[nixos_config["generateSecrets"], clan_name],
env=env,
check=False,
stdout=log_fd,
stderr=log_fd,
)
else:
log.warning("won't generate secrets for non local clan")
cmd = [nixos_config["uploadSecrets"]]
res = subprocess.run(
cmd,
env=env,
check=False,
stdout=log_fd,
stderr=log_fd,
)
if res.returncode != 0:
raise ClanError(
f"Failed to upload secrets: {shlex.join(cmd)} failed with {res.returncode}"
)
cmd = nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
str(disk_img),
"1024M",
],
)
res = subprocess.run(
cmd,
check=False,
stdout=log_fd,
stderr=log_fd,
)
if res.returncode != 0:
raise ClanError(
f"Failed to create disk image: {shlex.join(cmd)} failed with {res.returncode}"
)
cmd = nix_shell(
["e2fsprogs"],
[
"mkfs.ext4",
"-L",
"nixos",
str(disk_img),
],
)
res = subprocess.run(
cmd,
check=False,
stdout=log_fd,
stderr=log_fd,
)
if res.returncode != 0:
raise ClanError(
f"Failed to create ext4 filesystem: {shlex.join(cmd)} failed with {res.returncode}"
)
qemu_cmd = qemu_command(
vm,
nixos_config,
xchg_dir=xchg_dir,
secrets_dir=secrets_dir,
disk_img=disk_img,
spice_socket=spice_socket,
)
print("$ " + shlex.join(qemu_cmd))
packages = ["qemu"]
if vm.graphics:
packages.append("virt-viewer")
env = os.environ.copy()
remote_viewer_mimetypes = module_root() / "vms" / "mimetypes"
env[
"XDG_DATA_DIRS"
] = f"{remote_viewer_mimetypes}:{env.get('XDG_DATA_DIRS', '')}"
print(env["XDG_DATA_DIRS"])
res = subprocess.run(
nix_shell(packages, qemu_cmd),
env=env,
check=False,
stdout=log_fd,
stderr=log_fd,
)
if res.returncode != 0:
raise ClanError(f"qemu failed with {res.returncode}")
def run_command(args: argparse.Namespace) -> None: def run_command(args: argparse.Namespace) -> None:
flake_url = args.flake_url or args.flake flake_url = args.flake_url or args.flake
vm = asyncio.run(inspect_vm(flake_url=flake_url, flake_attr=args.machine)) vm = asyncio.run(inspect_vm(flake_url=flake_url, flake_attr=args.machine))
task = run_vm(vm, args.option) run_vm(vm, args.option)
for line in task.log_lines():
print(line, end="")
def register_run_parser(parser: argparse.ArgumentParser) -> None: def register_run_parser(parser: argparse.ArgumentParser) -> None:

View File

@ -3,7 +3,6 @@ from enum import Enum
from pydantic import BaseModel, Extra, Field from pydantic import BaseModel, Extra, Field
from ..async_cmd import CmdOut from ..async_cmd import CmdOut
from ..task_manager import TaskStatus
class Status(Enum): class Status(Enum):
@ -47,15 +46,6 @@ class VerifyMachineResponse(BaseModel):
error: str | None error: str | None
class VmStatusResponse(BaseModel):
error: str | None
status: TaskStatus
class VmCreateResponse(BaseModel):
uuid: str
class FlakeAttrResponse(BaseModel): class FlakeAttrResponse(BaseModel):
flake_attrs: list[str] flake_attrs: list[str]

View File

@ -1,16 +1,6 @@
import deal import deal
from clan_cli import nix, task_manager from clan_cli import nix
@deal.cases(task_manager.get_task)
def test_get_task(case: deal.TestCase) -> None:
case()
@deal.cases(task_manager.create_task)
def test_create_task(case: deal.TestCase) -> None:
case()
@deal.cases(nix.nix_command) @deal.cases(nix.nix_command)