improve task manager to report exceptions better
All checks were successful
checks-impure / test (pull_request) Successful in 19s
checks / test (pull_request) Successful in 31s

This commit is contained in:
Jörg Thalheim 2023-10-04 16:44:26 +02:00
parent ff1fb784e7
commit 04ba80f614
5 changed files with 70 additions and 79 deletions

View File

@ -4,125 +4,117 @@ import queue
import select
import shlex
import subprocess
import sys
import threading
import traceback
from enum import Enum
from typing import Any, Iterator, Type, TypeVar
from uuid import UUID, uuid4
from .errors import ClanError
class CmdState:
class Command:
def __init__(self, log: logging.Logger) -> None:
self.log: logging.Logger = log
self.p: subprocess.Popen | None = None
self.stdout: list[str] = []
self.stderr: list[str] = []
self._output: queue.SimpleQueue = queue.SimpleQueue()
self.returncode: int | None = None
self.done: bool = False
self.running: bool = False
self.cmd_str: str | None = None
self.workdir: str | None = None
self.lines: list[str] = []
def close_queue(self) -> None:
if self.p is not None:
self.returncode = self.p.returncode
self._output.put(None)
self.running = False
self.done = True
def run(self, cmd: list[str]) -> None:
self.running = True
try:
self.cmd_str = shlex.join(cmd)
self.workdir = os.getcwd()
self.log.debug(f"Working directory: {self.workdir}")
self.log.debug(f"Running command: {shlex.join(cmd)}")
self.p = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
cwd=self.workdir,
)
assert self.p.stdout is not None and self.p.stderr is not None
os.set_blocking(self.p.stdout.fileno(), False)
os.set_blocking(self.p.stderr.fileno(), False)
while self.p.poll() is None:
# Check if stderr is ready to be read from
rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0)
if self.p.stderr in rlist:
assert self.p.stderr is not None
line = self.p.stderr.readline()
if line != "":
line = line.strip("\n")
self.stderr.append(line)
self.log.debug("stderr: %s", line)
self._output.put(line + "\n")
if self.p.stdout in rlist:
assert self.p.stdout is not None
line = self.p.stdout.readline()
if line != "":
line = line.strip("\n")
self.stdout.append(line)
self.log.debug("stdout: %s", line)
self._output.put(line + "\n")
for fd in rlist:
try:
for line in fd:
self.log.debug("stdout: %s", line)
self.lines.append(line)
self._output.put(line)
except BlockingIOError:
continue
if self.p.returncode != 0:
raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")
raise ClanError(f"Failed to run command: {shlex.join(cmd)}")
self.log.debug("Successfully ran command")
finally:
self.close_queue()
class BaseTask(threading.Thread):
def __init__(self, uuid: UUID) -> None:
# calling parent class constructor
threading.Thread.__init__(self)
class TaskStatus(str, Enum):
NOTSTARTED = "NOTSTARTED"
RUNNING = "RUNNING"
FINISHED = "FINISHED"
FAILED = "FAILED"
class BaseTask:
def __init__(self, uuid: UUID) -> None:
# constructor
self.uuid: UUID = uuid
self.log = logging.getLogger(__name__)
self.procs: list[CmdState] = []
self.failed: bool = False
self.finished: bool = False
self.procs: list[Command] = []
self.status = TaskStatus.NOTSTARTED
self.logs_lock = threading.Lock()
self.error: Exception | None = None
def run(self) -> None:
def _run(self) -> None:
self.status = TaskStatus.RUNNING
try:
self.task_run()
self.run()
except Exception as e:
# FIXME: fix exception handling here
traceback.print_exception(*sys.exc_info())
for proc in self.procs:
proc.close_queue()
self.failed = True
self.error = e
self.log.exception(e)
finally:
self.finished = True
self.status = TaskStatus.FAILED
else:
self.status = TaskStatus.FINISHED
def task_run(self) -> None:
def run(self) -> None:
raise NotImplementedError
## TODO: If two clients are connected to the same task,
def logs_iter(self) -> Iterator[str]:
def log_lines(self) -> Iterator[str]:
with self.logs_lock:
for proc in self.procs:
if self.finished:
self.log.debug("log iter: Task is finished")
break
if self.status == TaskStatus.FINISHED:
return
# process has finished
if proc.done:
for line in proc.stderr:
yield line + "\n"
for line in proc.stdout:
yield line + "\n"
continue
while True:
out = proc._output
line = out.get()
if line is None:
break
yield line
for line in proc.lines:
yield line
else:
while line := proc._output.get():
yield line
def register_cmds(self, num_cmds: int) -> Iterator[CmdState]:
for i in range(num_cmds):
cmd = CmdState(self.log)
def register_commands(self, num_cmds: int) -> Iterator[Command]:
for _ in range(num_cmds):
cmd = Command(self.log)
self.procs.append(cmd)
for cmd in self.procs:
@ -165,6 +157,6 @@ def create_task(task_type: Type[T], *args: Any) -> T:
uuid = uuid4()
task = task_type(uuid, *args)
threading.Thread(target=task._run).start()
POOL[uuid] = task
task.start()
return task

View File

@ -9,7 +9,7 @@ from uuid import UUID
from ..dirs import get_clan_flake_toplevel
from ..nix import nix_build, nix_shell
from ..task_manager import BaseTask, CmdState, create_task
from ..task_manager import BaseTask, Command, create_task
from .inspect import VmConfig, inspect_vm
@ -18,7 +18,7 @@ class BuildVmTask(BaseTask):
super().__init__(uuid)
self.vm = vm
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
def get_vm_create_info(self, cmds: Iterator[Command]) -> dict:
clan_dir = self.vm.flake_url
machine = self.vm.flake_attr
cmd = next(cmds)
@ -30,13 +30,13 @@ class BuildVmTask(BaseTask):
]
)
)
vm_json = "".join(cmd.stdout)
vm_json = "".join(cmd.lines)
self.log.debug(f"VM JSON path: {vm_json}")
with open(vm_json) as f:
with open(vm_json.strip()) as f:
return json.load(f)
def task_run(self) -> None:
cmds = self.register_cmds(4)
def run(self) -> None:
cmds = self.register_commands(4)
machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}")
@ -121,7 +121,7 @@ def create_command(args: argparse.Namespace) -> None:
vm = asyncio.run(inspect_vm(flake_url=clan_dir, flake_attr=args.machine))
task = create_vm(vm)
for line in task.logs_iter():
for line in task.log_lines():
print(line, end="")

View File

@ -27,9 +27,8 @@ async def inspect_vm(
@router.get("/api/vms/{uuid}/status")
async def get_vm_status(uuid: UUID) -> VmStatusResponse:
task = get_task(uuid)
status: list[int | None] = list(map(lambda x: x.returncode, task.procs))
log.debug(msg=f"returncodes: {status}. task.finished: {task.finished}")
return VmStatusResponse(running=not task.finished, returncode=status)
log.debug(msg=f"error: {task.error}, task.status: {task.status}")
return VmStatusResponse(status=task.status, error=str(task.error))
@router.get("/api/vms/{uuid}/logs")
@ -38,7 +37,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
def stream_logs() -> Iterator[str]:
task = get_task(uuid)
yield from task.logs_iter()
yield from task.log_lines()
return StreamingResponse(
content=stream_logs(),

View File

@ -3,6 +3,7 @@ from typing import List
from pydantic import BaseModel, Field
from ..task_manager import TaskStatus
from ..vms.inspect import VmConfig
@ -38,8 +39,8 @@ class SchemaResponse(BaseModel):
class VmStatusResponse(BaseModel):
returncode: list[int | None]
running: bool
error: str | None
status: TaskStatus
class VmCreateResponse(BaseModel):

View File

@ -58,14 +58,13 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
print("=========VM LOGS==========")
assert isinstance(response.stream, SyncByteStream)
for line in response.stream:
assert line != b"", "Failed to get vm logs"
print(line.decode("utf-8"))
print("=========END LOGS==========")
assert response.status_code == 200, "Failed to get vm logs"
response = api.get(f"/api/vms/{uuid}/status")
assert response.status_code == 200, "Failed to get vm status"
returncodes = response.json()["returncode"]
assert response.json()["running"] is False, "VM is still running. Should be stopped"
for exit_code in returncodes:
assert exit_code == 0, "One VM failed with exit code != 0"
data = response.json()
assert (
data["status"] == "FINISHED"
), f"Expected to be finished, but got {data['status']} ({data})"