cli: fix remaining typing errors

This commit is contained in:
Jörg Thalheim 2023-09-27 11:45:07 +02:00 committed by Mic92
parent 244ae37144
commit 4317e681cf
3 changed files with 10 additions and 9 deletions

View File

@ -2,7 +2,7 @@ import asyncio
import json
import logging
import shlex
from typing import Annotated
from typing import Annotated, Iterator
from uuid import UUID
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Request, status
@ -34,12 +34,10 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
class NixBuildException(HTTPException):
def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]):
self.uuid = uuid
def __init__(self, msg: str, loc: list = ["body", "flake_attr"]):
detail = [
{
"loc": loc,
"uuid": str(uuid),
"msg": msg,
"type": "value_error",
}
@ -65,7 +63,7 @@ class BuildVmTask(BaseTask):
vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm"
self.log.debug(f"vm_path: {vm_path}")
self.run_cmd(vm_path)
self.run_cmd([vm_path])
self.finished = True
except Exception as e:
self.failed = True
@ -103,7 +101,6 @@ async def inspect_vm(
if proc.returncode != 0:
raise NixBuildException(
""
f"""
Failed to evaluate vm from '{flake_url}#{flake_attr}'.
command: {shlex.join(cmd)}
@ -127,7 +124,7 @@ async def get_status(uuid: UUID) -> VmStatusResponse:
@router.get("/api/vms/{uuid}/logs")
async def get_logs(uuid: UUID) -> StreamingResponse:
# Generator function that yields log lines as they are available
def stream_logs():
def stream_logs() -> Iterator[str]:
task = get_task(uuid)
for proc in task.procs:

View File

@ -5,6 +5,7 @@ import select
import shlex
import subprocess
import threading
from typing import Any
from uuid import UUID, uuid4
@ -105,14 +106,14 @@ def get_task(uuid: UUID) -> BaseTask:
return POOL[uuid]
def register_task(task: BaseTask, *kwargs) -> UUID:
def register_task(task: type, *args: Any) -> UUID:
global POOL
if not issubclass(task, BaseTask):
raise TypeError("task must be a subclass of BaseTask")
uuid = uuid4()
inst_task = task(uuid, *kwargs)
inst_task = task(uuid, *args)
POOL[uuid] = inst_task
inst_task.start()
return uuid

View File

@ -2,6 +2,7 @@ from pathlib import Path
import pytest
from api import TestClient
from httpx import SyncByteStream
# @pytest.mark.impure
# def test_inspect(api: TestClient, test_flake_with_core: Path) -> None:
@ -41,6 +42,7 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
response = api.get(f"/api/vms/{uuid}/logs")
print("=========FLAKE LOGS==========")
assert isinstance(response.stream, SyncByteStream)
for line in response.stream:
assert line != b"", "Failed to get vm logs"
print(line.decode("utf-8"), end="")
@ -48,6 +50,7 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
assert response.status_code == 200, "Failed to get vm logs"
response = api.get(f"/api/vms/{uuid}/logs")
assert isinstance(response.stream, SyncByteStream)
print("=========VM LOGS==========")
for line in response.stream:
assert line != b"", "Failed to get vm logs"