Fixed failing tests
Some checks failed
checks-impure / test (pull_request) Failing after 12s
checks / test (pull_request) Successful in 1m21s

This commit is contained in:
Luis Hebendanz 2023-10-03 11:51:31 +02:00
parent 814d81c1d2
commit 653ad99b22
4 changed files with 52 additions and 58 deletions

View File

@ -1,25 +1,25 @@
import argparse
import asyncio
from typing import Any, Iterator
from uuid import UUID
import threading
import queue
from fastapi.responses import StreamingResponse
from ..dirs import get_clan_flake_toplevel
from ..webui.routers import vms
from ..webui.schemas import VmConfig
from typing import Any, Iterator
from fastapi.responses import StreamingResponse
import pdb
def read_stream_response(stream: StreamingResponse) -> Iterator[Any]:
iterator = stream.body_iterator
while True:
try:
tem = asyncio.run(iterator.__anext__())
tem = asyncio.run(iterator.__anext__()) # type: ignore
except StopAsyncIteration:
break
yield tem
def create(args: argparse.Namespace) -> None:
clan_dir = get_clan_flake_toplevel().as_posix()
vm = VmConfig(
@ -34,13 +34,12 @@ def create(args: argparse.Namespace) -> None:
print(res.json())
uuid = UUID(res.uuid)
res = asyncio.run(vms.get_vm_logs(uuid))
stream = asyncio.run(vms.get_vm_logs(uuid))
for line in read_stream_response(res):
for line in read_stream_response(stream):
print(line)
def register_create_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument("machine", type=str)
parser.set_defaults(func=create)

View File

@ -1,9 +1,8 @@
import json
import logging
import tempfile
import time
from pathlib import Path
from typing import Annotated, Iterator, Iterable
from typing import Annotated, Iterator
from uuid import UUID
from fastapi import APIRouter, Body
@ -11,7 +10,7 @@ from fastapi.responses import StreamingResponse
from ...nix import nix_build, nix_eval, nix_shell
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, get_task, register_task, CmdState
from ..task_manager import BaseTask, CmdState, get_task, register_task
from .utils import run_cmd
log = logging.getLogger(__name__)
@ -39,7 +38,7 @@ class BuildVmTask(BaseTask):
super().__init__(uuid)
self.vm = vm
def get_vm_create_info(self, cmds: Iterable[CmdState]) -> dict:
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
clan_dir = self.vm.flake_url
machine = self.vm.flake_attr
cmd = next(cmds)
@ -71,31 +70,36 @@ class BuildVmTask(BaseTask):
disk_img = f"{tmpdir_}/disk.img"
cmd = next(cmds)
cmd.run(nix_shell(
["qemu"],
cmd.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
disk_img,
"1024M",
],
)
)
cmd = next(cmds)
cmd.run(
[
"qemu-img",
"create",
"-f",
"raw",
"mkfs.ext4",
"-L",
"nixos",
disk_img,
"1024M",
],
))
]
)
cmd = next(cmds)
cmd.run([
"mkfs.ext4",
"-L",
"nixos",
disk_img,
])
cmd = next(cmds)
cmd.run(nix_shell(
["qemu"],
[
# fmt: off
cmd.run(
nix_shell(
["qemu"],
[
# fmt: off
"qemu-kvm",
"-name", machine,
"-m", f'{vm_config["memorySize"]}M',
@ -113,9 +117,10 @@ class BuildVmTask(BaseTask):
"-kernel", f'{vm_config["toplevel"]}/kernel',
"-initrd", vm_config["initrd"],
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on
],
))
# fmt: on
],
)
)
@router.post("/api/vms/inspect")
@ -144,8 +149,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
def stream_logs() -> Iterator[str]:
task = get_task(uuid)
for line in task.logs_iter():
yield line
yield from task.logs_iter()
return StreamingResponse(
content=stream_logs(),

View File

@ -5,14 +5,14 @@ import select
import shlex
import subprocess
import threading
from typing import Any, Iterable, Iterator
from typing import Any, Iterator
from uuid import UUID, uuid4
class CmdState:
def __init__(self, log: logging.Logger) -> None:
self.log: logging.Logger = log
self.p: subprocess.Popen = None
self.p: subprocess.Popen | None = None
self.stdout: list[str] = []
self.stderr: list[str] = []
self._output: queue.SimpleQueue = queue.SimpleQueue()
@ -51,7 +51,7 @@ class CmdState:
assert self.p.stderr is not None
line = self.p.stderr.readline()
if line != "":
line = line.strip('\n')
line = line.strip("\n")
self.stderr.append(line)
self.log.debug("stderr: %s", line)
self._output.put(line)
@ -60,7 +60,7 @@ class CmdState:
assert self.p.stdout is not None
line = self.p.stdout.readline()
if line != "":
line = line.strip('\n')
line = line.strip("\n")
self.stdout.append(line)
self.log.debug("stdout: %s", line)
self._output.put(line)
@ -93,14 +93,14 @@ class BaseTask(threading.Thread):
for proc in self.procs:
proc.close_queue()
self.failed = True
self.finished = True
self.log.exception(e)
finally:
self.finished = True
def task_run(self) -> None:
raise NotImplementedError
## TODO: If two clients are connected to the same task,
## TODO: If two clients are connected to the same task,
def logs_iter(self) -> Iterator[str]:
with self.logs_lock:
for proc in self.procs:
@ -120,7 +120,7 @@ class BaseTask(threading.Thread):
break
yield line
def register_cmds(self, num_cmds: int) -> Iterable[CmdState]:
def register_cmds(self, num_cmds: int) -> Iterator[CmdState]:
for i in range(num_cmds):
cmd = CmdState(self.log)
self.procs.append(cmd)

View File

@ -74,20 +74,11 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
assert response.status_code == 200, "Failed to get vm status"
response = api.get(f"/api/vms/{uuid}/logs")
print("=========FLAKE LOGS==========")
assert isinstance(response.stream, SyncByteStream)
for line in response.stream:
assert line != b"", "Failed to get vm logs"
print(line.decode("utf-8"), end="")
print("=========END LOGS==========")
assert response.status_code == 200, "Failed to get vm logs"
response = api.get(f"/api/vms/{uuid}/logs")
assert isinstance(response.stream, SyncByteStream)
print("=========VM LOGS==========")
assert isinstance(response.stream, SyncByteStream)
for line in response.stream:
assert line != b"", "Failed to get vm logs"
print(line.decode("utf-8"), end="")
print(line.decode("utf-8"))
print("=========END LOGS==========")
assert response.status_code == 200, "Failed to get vm logs"