Extracted threadpool to task_manager.py

This commit is contained in:
Luis Hebendanz 2023-09-26 19:36:01 +02:00 committed by Mic92
parent 04f3547be0
commit c2fb42e953
4 changed files with 185 additions and 150 deletions

View File

@ -28,8 +28,11 @@ def setup_app() -> FastAPI:
app.include_router(flake.router)
app.include_router(health.router)
app.include_router(machines.router)
app.include_router(root.router)
app.include_router(vms.router)
# Needs to be last in register. Because of wildcard route
app.include_router(root.router)
app.add_exception_handler(vms.NixBuildException, vms.nix_build_exception_handler)
app.mount("/static", StaticFiles(directory=asset_path()), name="static")
@ -37,6 +40,7 @@ def setup_app() -> FastAPI:
for route in app.routes:
if isinstance(route, APIRoute):
route.operation_id = route.name # in this case, 'read_items'
log.debug(f"Registered route: {route}")
return app

View File

@ -1,11 +1,10 @@
import asyncio
import json
import logging
import os
import select
import queue
import shlex
import io
import subprocess
import pipes
import threading
from typing import Annotated, AsyncIterator
from uuid import UUID, uuid4
@ -23,13 +22,10 @@ from fastapi.responses import JSONResponse, StreamingResponse
from ...nix import nix_build, nix_eval
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, get_task, register_task
# Logging setup
log = logging.getLogger(__name__)
router = APIRouter()
app = FastAPI()
def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]:
@ -48,35 +44,6 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
)
@router.post("/api/vms/inspect")
async def inspect_vm(
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
) -> VmInspectResponse:
cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
proc = await asyncio.create_subprocess_exec(
cmd[0],
*cmd[1:],
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise NixBuildException(
f"""
Failed to evaluate vm from '{flake_url}#{flake_attr}'.
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
data = json.loads(stdout)
return VmInspectResponse(
config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
)
class NixBuildException(HTTPException):
def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]):
self.uuid = uuid
@ -93,146 +60,97 @@ class NixBuildException(HTTPException):
)
class ProcessState:
def __init__(self, proc: subprocess.Popen):
self.proc: subprocess.Process = proc
self.stdout: list[str] = []
self.stderr: list[str] = []
self.returncode: int | None = None
self.done: bool = False
class BuildVmTask(BaseTask):
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
super().__init__(uuid)
self.vm = vm
class BuildVM(threading.Thread):
def __init__(self, vm: VmConfig, uuid: UUID):
# calling parent class constructor
threading.Thread.__init__(self)
# constructor
self.vm: VmConfig = vm
self.uuid: UUID = uuid
self.log = logging.getLogger(__name__)
self.procs: list[ProcessState] = []
self.failed: bool = False
self.finished: bool = False
def run(self):
def run(self) -> None:
try:
self.log.debug(f"BuildVM with uuid {self.uuid} started")
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
proc = self.run_cmd(cmd)
out = proc.stdout
self.log.debug(f"out: {out}")
self.log.debug(f"stdout: {proc.stdout}")
vm_path = f"{''.join(out)}/bin/run-nixos-vm"
vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm"
self.log.debug(f"vm_path: {vm_path}")
self.run_cmd(vm_path)
self.run_cmd(vm_path)
self.finished = True
except Exception as e:
self.failed = True
self.finished = True
log.exception(e)
def run_cmd(self, cmd: list[str]) -> ProcessState:
cwd = os.getcwd()
log.debug(f"Working directory: {cwd}")
log.debug(f"Running command: {shlex.join(cmd)}")
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
cwd=cwd,
)
state = ProcessState(process)
self.procs.append(state)
while process.poll() is None:
# Check if stderr is ready to be read from
rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0)
if process.stderr in rlist:
line = process.stderr.readline()
state.stderr.append(line)
if process.stdout in rlist:
line = process.stdout.readline()
state.stdout.append(line)
state.returncode = process.returncode
state.done = True
if process.returncode != 0:
raise NixBuildException(
self.uuid, f"Failed to run command: {shlex.join(cmd)}"
)
log.debug("Successfully ran command")
return state
class VmTaskPool:
def __init__(self) -> None:
self.lock: threading.RLock = threading.RLock()
self.pool: dict[UUID, BuildVM] = {}
def __getitem__(self, uuid: str | UUID) -> BuildVM:
with self.lock:
if type(uuid) is UUID:
return self.pool[uuid]
else:
uuid = UUID(uuid)
return self.pool[uuid]
def __setitem__(self, uuid: UUID, vm: BuildVM) -> None:
with self.lock:
if uuid in self.pool:
raise KeyError(f"VM with uuid {uuid} already exists")
if type(uuid) is not UUID:
raise TypeError("uuid must be of type UUID")
self.pool[uuid] = vm
POOL: VmTaskPool = VmTaskPool()
def nix_build_exception_handler(
request: Request, exc: NixBuildException
) -> JSONResponse:
log.error("NixBuildException: %s", exc)
# del POOL[exc.uuid]
return JSONResponse(
status_code=exc.status_code,
content=jsonable_encoder(dict(detail=exc.detail)),
)
##################################
# #
# ======== VM ROUTES ======== #
# #
##################################
@router.post("/api/vms/inspect")
async def inspect_vm(
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
) -> VmInspectResponse:
cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
proc = await asyncio.create_subprocess_exec(
cmd[0],
*cmd[1:],
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise NixBuildException(
""
f"""
Failed to evaluate vm from '{flake_url}#{flake_attr}'.
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
data = json.loads(stdout)
return VmInspectResponse(
config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
)
@router.get("/api/vms/{uuid}/status")
async def get_status(uuid: str) -> VmStatusResponse:
global POOL
handle = POOL[uuid]
if handle.process.poll() is None:
return VmStatusResponse(running=True, status=0)
else:
return VmStatusResponse(running=False, status=handle.process.returncode)
task = get_task(uuid)
return VmStatusResponse(running=not task.finished, status=0)
@router.get("/api/vms/{uuid}/logs")
async def get_logs(uuid: str) -> StreamingResponse:
async def stream_logs() -> AsyncIterator[str]:
global POOL
handle = POOL[uuid]
for proc in handle.procs.values():
while True:
if proc.stdout.empty() and proc.stderr.empty() and not proc.done:
await asyncio.sleep(0.1)
continue
if proc.stdout.empty() and proc.stderr.empty() and proc.done:
break
task = get_task(uuid)
for proc in task.procs:
if proc.done:
for line in proc.stderr:
yield line
for line in proc.stdout:
yield line
for line in proc.stderr:
else:
while True:
if proc.output_pipe.empty() and proc.done:
break
line = await proc.output_pipe.get()
yield line
return StreamingResponse(
@ -240,14 +158,9 @@ async def get_logs(uuid: str) -> StreamingResponse:
media_type="text/plain",
)
@router.post("/api/vms/create")
async def create_vm(
vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks
) -> VmCreateResponse:
global POOL
uuid = uuid4()
handle = BuildVM(vm, uuid)
handle.start()
POOL[uuid] = handle
uuid = register_task(BuildVmTask, vm)
return VmCreateResponse(uuid=str(uuid))

View File

@ -0,0 +1,114 @@
import logging
import os
import queue
import select
import shlex
import subprocess
import threading
from uuid import UUID, uuid4
class CmdState:
def __init__(self, proc: subprocess.Popen) -> None:
self.proc: subprocess.Process = proc
self.stdout: list[str] = []
self.stderr: list[str] = []
self.output_pipe: asyncio.Queue = asyncio.Queue()
self.returncode: int | None = None
self.done: bool = False
class BaseTask(threading.Thread):
def __init__(self, uuid: UUID) -> None:
# calling parent class constructor
threading.Thread.__init__(self)
# constructor
self.uuid: UUID = uuid
self.log = logging.getLogger(__name__)
self.procs: list[CmdState] = []
self.failed: bool = False
self.finished: bool = False
def run(self) -> None:
self.finished = True
def run_cmd(self, cmd: list[str]) -> CmdState:
cwd = os.getcwd()
self.log.debug(f"Working directory: {cwd}")
self.log.debug(f"Running command: {shlex.join(cmd)}")
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
cwd=cwd,
)
state = CmdState(process)
self.procs.append(state)
while process.poll() is None:
# Check if stderr is ready to be read from
rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0)
if process.stderr in rlist:
line = process.stderr.readline()
if line != "":
state.stderr.append(line.strip('\n'))
state.output_pipe.put_nowait(line)
if process.stdout in rlist:
line = process.stdout.readline()
if line != "":
state.stdout.append(line.strip('\n'))
state.output_pipe.put_nowait(line)
state.returncode = process.returncode
state.done = True
if process.returncode != 0:
raise RuntimeError(
f"Failed to run command: {shlex.join(cmd)}"
)
self.log.debug("Successfully ran command")
return state
class TaskPool:
def __init__(self) -> None:
self.lock: threading.RLock = threading.RLock()
self.pool: dict[UUID, BaseTask] = {}
def __getitem__(self, uuid: str | UUID) -> BaseTask:
with self.lock:
if type(uuid) is UUID:
return self.pool[uuid]
else:
uuid = UUID(uuid)
return self.pool[uuid]
def __setitem__(self, uuid: UUID, vm: BaseTask) -> None:
with self.lock:
if uuid in self.pool:
raise KeyError(f"VM with uuid {uuid} already exists")
if type(uuid) is not UUID:
raise TypeError("uuid must be of type UUID")
self.pool[uuid] = vm
POOL: TaskPool = TaskPool()
def get_task(uuid: UUID) -> BaseTask:
global POOL
return POOL[uuid]
def register_task(task: BaseTask, *kwargs) -> UUID:
global POOL
if not issubclass(task, BaseTask):
raise TypeError("task must be a subclass of BaseTask")
uuid = uuid4()
inst_task = task(uuid, *kwargs)
POOL[uuid] = inst_task
inst_task.start()
return uuid

View File

@ -31,14 +31,18 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
graphics=True,
),
)
assert response.status_code == 200, "Failed to inspect vm"
assert response.status_code == 200, "Failed to create vm"
uuid = response.json()["uuid"]
assert len(uuid) == 36
assert uuid.count("-") == 4
response = api.get(f"/api/vms/{uuid}/status")
assert response.status_code == 200, "Failed to get vm status"
response = api.get(f"/api/vms/{uuid}/logs")
print("=========LOGS==========")
for line in response.stream:
print(line)
assert response.status_code == 200, "Failed to get vm status"
assert response.status_code == 200, "Failed to get vm logs"