Improving endpoint

This commit is contained in:
Luis Hebendanz 2023-09-25 20:09:27 +02:00 committed by Mic92
parent d16bb5db26
commit f6c8b963c1
5 changed files with 156 additions and 50 deletions

View File

@ -1,9 +1,9 @@
import datetime
import logging
from typing import Any
class CustomFormatter(logging.Formatter):
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
red = "\x1b[31;20m"
@ -11,7 +11,8 @@ class CustomFormatter(logging.Formatter):
green = "\u001b[32m"
blue = "\u001b[34m"
def format_str(color):
@staticmethod
def format_str(color: str) -> str:
reset = "\x1b[0m"
return f"{color}%(levelname)s{reset}:(%(filename)s:%(lineno)d): %(message)s"
@ -20,24 +21,23 @@ class CustomFormatter(logging.Formatter):
logging.INFO: format_str(green),
logging.WARNING: format_str(yellow),
logging.ERROR: format_str(red),
logging.CRITICAL: format_str(bold_red)
logging.CRITICAL: format_str(bold_red),
}
def formatTime(self, record,datefmt=None):
def format_time(self, record: Any, datefmt: Any = None) -> str:
now = datetime.datetime.now()
now = now.strftime("%H:%M:%S")
return now
def format(self, record):
def format(self, record: Any) -> str:
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt)
formatter.formatTime = self.formatTime
formatter.formatTime = self.format_time
return formatter.format(record)
def register(level):
def register(level: Any) -> None:
ch = logging.StreamHandler()
ch.setLevel(level)
ch.setFormatter(CustomFormatter())
logging.basicConfig(level=level, handlers=[ch])

View File

@ -1,10 +1,11 @@
import logging
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
import logging
from .. import custom_logger
from .. import custom_logger
from .assets import asset_path
from .routers import flake, health, machines, root, vms
@ -39,7 +40,7 @@ def setup_app() -> FastAPI:
return app
#TODO: How do I get the log level from the command line in here?
# TODO: How do I get the log level from the command line in here?
custom_logger.register(logging.DEBUG)
app = setup_app()

View File

@ -2,16 +2,27 @@ import asyncio
import json
import logging
import os
import select
import queue
import shlex
import uuid
import subprocess
import threading
from typing import Annotated, AsyncIterator
from uuid import UUID, uuid4
from fastapi import APIRouter, Body, FastAPI, HTTPException, Request, status, BackgroundTasks
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
HTTPException,
Request,
status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, StreamingResponse
from ...nix import nix_build, nix_eval
from ..schemas import VmConfig, VmInspectResponse, VmCreateResponse
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
# Logging setup
log = logging.getLogger(__name__)
@ -37,8 +48,6 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
)
@router.post("/api/vms/inspect")
async def inspect_vm(
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
@ -68,9 +77,8 @@ command output:
)
class NixBuildException(HTTPException):
def __init__(self, uuid: uuid.UUID, msg: str,loc: list = ["body", "flake_attr"]):
def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]):
self.uuid = uuid
detail = [
{
@ -85,74 +93,161 @@ class NixBuildException(HTTPException):
)
import threading
import subprocess
import uuid
class ProcessState:
def __init__(self, proc: subprocess.Popen):
self.proc: subprocess.Process = proc
self.stdout: list[str] = []
self.stderr: list[str] = []
self.returncode: int | None = None
self.done: bool = False
class BuildVM(threading.Thread):
def __init__(self, vm: VmConfig, uuid: uuid.UUID):
def __init__(self, vm: VmConfig, uuid: UUID):
# calling parent class constructor
threading.Thread.__init__(self)
# constructor
self.vm: VmConfig = vm
self.uuid: uuid.UUID = uuid
self.uuid: UUID = uuid
self.log = logging.getLogger(__name__)
self.process: subprocess.Popen = None
self.procs: list[ProcessState] = []
self.failed: bool = False
self.finished: bool = False
def run(self):
self.log.debug(f"BuildVM with uuid {self.uuid} started")
try:
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
(out, err) = self.run_cmd(cmd)
vm_path = f'{out.strip()}/bin/run-nixos-vm'
self.log.debug(f"BuildVM with uuid {self.uuid} started")
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
self.log.debug(f"vm_path: {vm_path}")
proc = self.run_cmd(cmd)
out = proc.stdout
self.log.debug(f"out: {out}")
(out, err) = self.run_cmd(vm_path)
vm_path = f"{''.join(out)}/bin/run-nixos-vm"
self.log.debug(f"vm_path: {vm_path}")
self.run_cmd(vm_path)
self.finished = True
except Exception as e:
self.failed = True
self.finished = True
log.exception(e)
def run_cmd(self, cmd: list[str]):
cwd=os.getcwd()
def run_cmd(self, cmd: list[str]) -> ProcessState:
cwd = os.getcwd()
log.debug(f"Working directory: {cwd}")
log.debug(f"Running command: {shlex.join(cmd)}")
self.process = subprocess.Popen(
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
cwd=cwd,
)
state = ProcessState(process)
self.procs.append(state)
self.process.wait()
if self.process.returncode != 0:
raise NixBuildException(self.uuid, f"Failed to run command: {shlex.join(cmd)}")
while process.poll() is None:
# Check if stderr is ready to be read from
rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0)
if process.stderr in rlist:
line = process.stderr.readline()
state.stderr.append(line)
if process.stdout in rlist:
line = process.stdout.readline()
state.stdout.append(line)
log.info("Successfully ran command")
return (self.process.stdout, self.process.stderr)
state.returncode = process.returncode
state.done = True
POOL: dict[uuid.UUID, BuildVM] = {}
if process.returncode != 0:
raise NixBuildException(
self.uuid, f"Failed to run command: {shlex.join(cmd)}"
)
log.debug("Successfully ran command")
return state
class VmTaskPool:
def __init__(self) -> None:
self.lock: threading.RLock = threading.RLock()
self.pool: dict[UUID, BuildVM] = {}
def __getitem__(self, uuid: str | UUID) -> BuildVM:
with self.lock:
if type(uuid) is UUID:
return self.pool[uuid]
else:
uuid = UUID(uuid)
return self.pool[uuid]
def __setitem__(self, uuid: UUID, vm: BuildVM) -> None:
with self.lock:
if uuid in self.pool:
raise KeyError(f"VM with uuid {uuid} already exists")
if type(uuid) is not UUID:
raise TypeError("uuid must be of type UUID")
self.pool[uuid] = vm
POOL: VmTaskPool = VmTaskPool()
def nix_build_exception_handler(
request: Request, exc: NixBuildException
) -> JSONResponse:
log.error("NixBuildException: %s", exc)
del POOL[exc.uuid]
# del POOL[exc.uuid]
return JSONResponse(
status_code=exc.status_code,
content=jsonable_encoder(dict(detail=exc.detail)),
)
@router.get("/api/vms/{uuid}/status")
async def get_status(uuid: str) -> VmStatusResponse:
global POOL
handle = POOL[uuid]
if handle.process.poll() is None:
return VmStatusResponse(running=True, status=0)
else:
return VmStatusResponse(running=False, status=handle.process.returncode)
@router.get("/api/vms/{uuid}/logs")
async def get_logs(uuid: str) -> StreamingResponse:
async def stream_logs() -> AsyncIterator[str]:
global POOL
handle = POOL[uuid]
for proc in handle.procs.values():
while True:
if proc.stdout.empty() and proc.stderr.empty() and not proc.done:
await asyncio.sleep(0.1)
continue
if proc.stdout.empty() and proc.stderr.empty() and proc.done:
break
for line in proc.stdout:
yield line
for line in proc.stderr:
yield line
return StreamingResponse(
content=stream_logs(),
media_type="text/plain",
)
@router.post("/api/vms/create")
async def create_vm(vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks) -> StreamingResponse:
handle_id = uuid.uuid4()
handle = BuildVM(vm, handle_id)
async def create_vm(
vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks
) -> VmCreateResponse:
global POOL
uuid = uuid4()
handle = BuildVM(vm, uuid)
handle.start()
POOL[handle_id] = handle
return VmCreateResponse(uuid=str(handle_id))
POOL[uuid] = handle
return VmCreateResponse(uuid=str(uuid))

View File

@ -43,9 +43,16 @@ class VmConfig(BaseModel):
memory_size: int
graphics: bool
class VmStatusResponse(BaseModel):
status: int
running: bool
class VmCreateResponse(BaseModel):
uuid: str
class VmInspectResponse(BaseModel):
config: VmConfig

View File

@ -1,4 +1,5 @@
import argparse
import logging
import subprocess
import time
import urllib.request
@ -11,6 +12,8 @@ from typing import Iterator
# XXX: can we dynamically load this using nix develop?
from uvicorn import run
log = logging.getLogger(__name__)
def defer_open_browser(base_url: str) -> None:
for i in range(5):
@ -24,7 +27,7 @@ def defer_open_browser(base_url: str) -> None:
@contextmanager
def spawn_node_dev_server(host: str, port: int) -> Iterator[None]:
logger.info("Starting node dev server...")
log.info("Starting node dev server...")
path = Path(__file__).parent.parent.parent.parent / "ui"
with subprocess.Popen(
[