CLI: Restructured TaskManager and log collection

Luis Hebendanz 2023-10-02 18:36:50 +02:00
parent 6640c78089
commit 814d81c1d2
3 changed files with 136 additions and 97 deletions

View File

@@ -1,10 +1,24 @@
 import argparse
 import asyncio
 from uuid import UUID
+import threading
+import queue
 from ..dirs import get_clan_flake_toplevel
 from ..webui.routers import vms
 from ..webui.schemas import VmConfig
+from typing import Any, Iterator
+from fastapi.responses import StreamingResponse
+import pdb
+
+def read_stream_response(stream: StreamingResponse) -> Iterator[Any]:
+    iterator = stream.body_iterator
+    while True:
+        try:
+            tem = asyncio.run(iterator.__anext__())
+        except StopAsyncIteration:
+            break
+        yield tem
 
 
 def create(args: argparse.Namespace) -> None:
     clan_dir = get_clan_flake_toplevel().as_posix()
@@ -18,6 +32,13 @@ def create(args: argparse.Namespace) -> None:
     res = asyncio.run(vms.create_vm(vm))
     print(res.json())
+    uuid = UUID(res.uuid)
+
+    res = asyncio.run(vms.get_vm_logs(uuid))
+    for line in read_stream_response(res):
+        print(line)
 
 
 def register_create_parser(parser: argparse.ArgumentParser) -> None:

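Note: read_stream_response above drives the StreamingResponse's async body_iterator from synchronous CLI code by starting a fresh event loop per chunk via asyncio.run(). A self-contained sketch of the same sync-over-async bridge that keeps a single private loop alive instead (hypothetical helper, not part of this commit):

import asyncio
from typing import Any, AsyncIterator, Iterator

def iterate_sync(aiter: AsyncIterator[Any]) -> Iterator[Any]:
    # Pull chunks from an async iterator on one private event loop,
    # yielding them to plain synchronous code.
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(aiter.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()

async def numbers() -> AsyncIterator[int]:
    for i in range(3):
        yield i

print(list(iterate_sync(numbers())))  # -> [0, 1, 2]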
View File

@@ -1,8 +1,9 @@
 import json
 import logging
 import tempfile
+import time
 from pathlib import Path
-from typing import Annotated, Iterator
+from typing import Annotated, Iterator, Iterable
 from uuid import UUID
 
 from fastapi import APIRouter, Body
@@ -10,7 +11,7 @@ from fastapi.responses import StreamingResponse
 from ...nix import nix_build, nix_eval, nix_shell
 from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
-from ..task_manager import BaseTask, get_task, register_task
+from ..task_manager import BaseTask, get_task, register_task, CmdState
 from .utils import run_cmd
 
 log = logging.getLogger(__name__)
@@ -38,10 +39,11 @@ class BuildVmTask(BaseTask):
         super().__init__(uuid)
         self.vm = vm
 
-    def get_vm_create_info(self) -> dict:
+    def get_vm_create_info(self, cmds: Iterable[CmdState]) -> dict:
         clan_dir = self.vm.flake_url
         machine = self.vm.flake_attr
-        cmd_state = self.run_cmd(
+        cmd = next(cmds)
+        cmd.run(
             nix_build(
                 [
                     # f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this
@@ -49,41 +51,48 @@ class BuildVmTask(BaseTask):
                 ]
             )
         )
-        vm_json = "".join(cmd_state.stdout)
+        vm_json = "".join(cmd.stdout)
         self.log.debug(f"VM JSON path: {vm_json}")
         with open(vm_json) as f:
             return json.load(f)
 
     def task_run(self) -> None:
+        cmds = self.register_cmds(4)
+
         machine = self.vm.flake_attr
         self.log.debug(f"Creating VM for {machine}")
-        vm_config = self.get_vm_create_info()
+
+        # TODO: We should get this from the vm argument
+        vm_config = self.get_vm_create_info(cmds)
 
         with tempfile.TemporaryDirectory() as tmpdir_:
             xchg_dir = Path(tmpdir_) / "xchg"
             xchg_dir.mkdir()
 
             disk_img = f"{tmpdir_}/disk.img"
-            cmd = nix_shell(
+            cmd = next(cmds)
+            cmd.run(nix_shell(
                 ["qemu"],
                 [
-                    "qemu" "qemu-img",
+                    "qemu-img",
                     "create",
                     "-f",
                     "raw",
                     disk_img,
                     "1024M",
                 ],
-            )
-            self.run_cmd(cmd)
+            ))
 
-            cmd = [
+            cmd = next(cmds)
+            cmd.run([
                 "mkfs.ext4",
                 "-L",
                 "nixos",
                 disk_img,
-            ]
-            self.run_cmd(cmd)
+            ])
 
-            cmd = nix_shell(
+            cmd = next(cmds)
+            cmd.run(nix_shell(
                 ["qemu"],
                 [
                     # fmt: off
@@ -106,26 +115,7 @@ class BuildVmTask(BaseTask):
                     "-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
                     # fmt: on
                 ],
-            )
-            self.run_cmd(cmd)
-
-    # def run(self) -> None:
-    #     try:
-    #         self.log.debug(f"BuildVM with uuid {self.uuid} started")
-    #         cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
-
-    #         proc = self.run_cmd(cmd)
-    #         self.log.debug(f"stdout: {proc.stdout}")
-
-    #         vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm"
-    #         self.log.debug(f"vm_path: {vm_path}")
-
-    #         self.run_cmd([vm_path])
-    #         self.finished = True
-    #     except Exception as e:
-    #         self.failed = True
-    #         self.finished = True
-    #         log.exception(e)
+            ))
 
 
 @router.post("/api/vms/inspect")
@@ -154,21 +144,8 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
     def stream_logs() -> Iterator[str]:
         task = get_task(uuid)
-        for proc in task.procs:
-            if proc.done:
-                log.debug("stream logs and proc is done")
-                for line in proc.stderr:
-                    yield line + "\n"
-                for line in proc.stdout:
-                    yield line + "\n"
-                continue
-            while True:
-                out = proc.output
-                line = out.get()
-                if line is None:
-                    log.debug("stream logs and line is None")
-                    break
-                yield line
+
+        for line in task.logs_iter():
+            yield line
 
     return StreamingResponse(
         content=stream_logs(),

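For orientation, a minimal sketch (not part of this commit; import path assumed) of how a task drives commands after this restructuring, mirroring BuildVmTask above: task_run() asks register_cmds() for one CmdState per command, pulls them with next(), and runs each command through CmdState.run(); get_vm_logs() then streams the collected output to the client via task.logs_iter().

from clan_cli.webui.task_manager import BaseTask


class EchoTask(BaseTask):
    def task_run(self) -> None:
        # One pre-registered CmdState slot per command we intend to run.
        cmds = self.register_cmds(2)

        cmd = next(cmds)
        cmd.run(["echo", "hello"])  # blocks; fills cmd.stdout/cmd.stderr and the log queue

        cmd = next(cmds)
        cmd.run(["echo", "world"])

# While the task thread is running, a client can consume task.logs_iter(),
# which is exactly what stream_logs() inside get_vm_logs() does above.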
View File

@@ -5,19 +5,72 @@ import select
 import shlex
 import subprocess
 import threading
-from typing import Any
+from typing import Any, Iterable, Iterator
 from uuid import UUID, uuid4
 
 
 class CmdState:
-    def __init__(self, proc: subprocess.Popen) -> None:
-        global LOOP
-        self.proc: subprocess.Popen = proc
+    def __init__(self, log: logging.Logger) -> None:
+        self.log: logging.Logger = log
+        self.p: subprocess.Popen = None
         self.stdout: list[str] = []
         self.stderr: list[str] = []
-        self.output: queue.SimpleQueue = queue.SimpleQueue()
+        self._output: queue.SimpleQueue = queue.SimpleQueue()
         self.returncode: int | None = None
         self.done: bool = False
+        self.running: bool = False
+        self.cmd_str: str | None = None
+        self.workdir: str | None = None
+
+    def close_queue(self) -> None:
+        if self.p is not None:
+            self.returncode = self.p.returncode
+        self._output.put(None)
+        self.running = False
+        self.done = True
+
+    def run(self, cmd: list[str]) -> None:
+        self.running = True
+        try:
+            self.cmd_str = shlex.join(cmd)
+            self.workdir = os.getcwd()
+            self.log.debug(f"Working directory: {self.workdir}")
+            self.log.debug(f"Running command: {shlex.join(cmd)}")
+            self.p = subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                encoding="utf-8",
+                cwd=self.workdir,
+            )
+
+            while self.p.poll() is None:
+                # Check if stderr is ready to be read from
+                rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0)
+                if self.p.stderr in rlist:
+                    assert self.p.stderr is not None
+                    line = self.p.stderr.readline()
+                    if line != "":
+                        line = line.strip('\n')
+                        self.stderr.append(line)
+                        self.log.debug("stderr: %s", line)
+                        self._output.put(line)
+
+                if self.p.stdout in rlist:
+                    assert self.p.stdout is not None
+                    line = self.p.stdout.readline()
+                    if line != "":
+                        line = line.strip('\n')
+                        self.stdout.append(line)
+                        self.log.debug("stdout: %s", line)
+                        self._output.put(line)
+
+            if self.p.returncode != 0:
+                raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")
+
+            self.log.debug("Successfully ran command")
+        finally:
+            self.close_queue()
 
 
 class BaseTask(threading.Thread):
@@ -31,64 +84,52 @@ class BaseTask(threading.Thread):
         self.procs: list[CmdState] = []
         self.failed: bool = False
         self.finished: bool = False
+        self.logs_lock = threading.Lock()
 
     def run(self) -> None:
         try:
             self.task_run()
         except Exception as e:
+            for proc in self.procs:
+                proc.close_queue()
             self.failed = True
+            self.log.exception(e)
+        finally:
             self.finished = True
-            self.log.exception(e)
 
     def task_run(self) -> None:
         raise NotImplementedError
 
-    def run_cmd(self, cmd: list[str]) -> CmdState:
-        cwd = os.getcwd()
-        self.log.debug(f"Working directory: {cwd}")
-        self.log.debug(f"Running command: {shlex.join(cmd)}")
-        p = subprocess.Popen(
-            cmd,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-            encoding="utf-8",
-            # shell=True,
-            cwd=cwd,
-        )
-        self.procs.append(CmdState(p))
-        p_state = self.procs[-1]
+    ## TODO: If two clients are connected to the same task,
+    def logs_iter(self) -> Iterator[str]:
+        with self.logs_lock:
+            for proc in self.procs:
+                if self.finished:
+                    self.log.debug("log iter: Task is finished")
+                    break
+                if proc.done:
+                    for line in proc.stderr:
+                        yield line
+                    for line in proc.stdout:
+                        yield line
+                    continue
+                while True:
+                    out = proc._output
+                    line = out.get()
+                    if line is None:
+                        break
+                    yield line
 
-        while p.poll() is None:
-            # Check if stderr is ready to be read from
-            rlist, _, _ = select.select([p.stderr, p.stdout], [], [], 0)
-            if p.stderr in rlist:
-                assert p.stderr is not None
-                line = p.stderr.readline()
-                if line != "":
-                    p_state.stderr.append(line.strip("\n"))
-                    self.log.debug(f"stderr: {line}")
-                    p_state.output.put(line)
+    def register_cmds(self, num_cmds: int) -> Iterable[CmdState]:
+        for i in range(num_cmds):
+            cmd = CmdState(self.log)
+            self.procs.append(cmd)
 
-            if p.stdout in rlist:
-                assert p.stdout is not None
-                line = p.stdout.readline()
-                if line != "":
-                    p_state.stdout.append(line.strip("\n"))
-                    self.log.debug(f"stdout: {line}")
-                    p_state.output.put(line)
-
-        p_state.returncode = p.returncode
-        p_state.output.put(None)
-        p_state.done = True
-
-        if p.returncode != 0:
-            raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")
-
-        self.log.debug("Successfully ran command")
-        return p_state
+        for cmd in self.procs:
+            yield cmd
 
 
+# TODO: We need to test concurrency
 class TaskPool:
     def __init__(self) -> None:
         self.lock: threading.RLock = threading.RLock()
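The log hand-off introduced here follows a simple sentinel protocol: CmdState.run() puts every stdout/stderr line on its SimpleQueue, close_queue() pushes a final None, and BaseTask.logs_iter() blocks on get() until it sees that None. A stripped-down illustration of the same protocol (standalone sketch, not part of this commit):

import queue
import threading

q: queue.SimpleQueue = queue.SimpleQueue()

def producer() -> None:
    for line in ["building...", "done"]:
        q.put(line)  # what CmdState.run() does for each output line
    q.put(None)      # what close_queue() does once the command exits

threading.Thread(target=producer).start()

while True:
    line = q.get()   # blocks, like the inner loop of logs_iter()
    if line is None:
        break
    print(line)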