move out vm logic out of controller

This commit is contained in:
Jörg Thalheim 2023-10-03 16:47:14 +02:00
parent dbe289f702
commit ff11340507
10 changed files with 203 additions and 218 deletions

View File

@ -0,0 +1,28 @@
import asyncio
import logging
import shlex
from .errors import ClanError
log = logging.getLogger(__name__)
async def run(cmd: list[str]) -> bytes:
log.debug(f"$: {shlex.join(cmd)}")
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise ClanError(
f"""
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
return stdout

View File

@ -1,26 +1,114 @@
import argparse
import asyncio
from typing import Any, Iterator
import json
import tempfile
from pathlib import Path
from typing import Iterator
from uuid import UUID
from fastapi.responses import StreamingResponse
from ..dirs import get_clan_flake_toplevel
from ..webui.routers import vms
from ..webui.schemas import VmConfig
from ..nix import nix_build, nix_shell
from ..task_manager import BaseTask, CmdState, get_task, register_task
from .inspect import VmConfig
def read_stream_response(stream: StreamingResponse) -> Iterator[Any]:
iterator = stream.body_iterator
while True:
try:
tem = asyncio.run(iterator.__anext__()) # type: ignore
except StopAsyncIteration:
break
yield tem
class BuildVmTask(BaseTask):
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
super().__init__(uuid)
self.vm = vm
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
clan_dir = self.vm.flake_url
machine = self.vm.flake_attr
cmd = next(cmds)
cmd.run(
nix_build(
[
# f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this
f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.create'
]
)
)
vm_json = "".join(cmd.stdout)
self.log.debug(f"VM JSON path: {vm_json}")
with open(vm_json) as f:
return json.load(f)
def task_run(self) -> None:
cmds = self.register_cmds(4)
machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}")
# TODO: We should get this from the vm argument
vm_config = self.get_vm_create_info(cmds)
with tempfile.TemporaryDirectory() as tmpdir_:
xchg_dir = Path(tmpdir_) / "xchg"
xchg_dir.mkdir()
disk_img = f"{tmpdir_}/disk.img"
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
disk_img,
"1024M",
],
)
)
cmd = next(cmds)
cmd.run(
nix_shell(
["e2fsprogs"],
[
"mkfs.ext4",
"-L",
"nixos",
disk_img,
],
)
)
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
# fmt: off
"qemu-kvm",
"-name", machine,
"-m", f'{vm_config["memorySize"]}M',
"-smp", str(vm_config["cores"]),
"-device", "virtio-rng-pci",
"-net", "nic,netdev=user.0,model=virtio", "-netdev", "user,id=user.0",
"-virtfs", "local,path=/nix/store,security_model=none,mount_tag=nix-store",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=shared",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=xchg",
"-drive", f'cache=writeback,file={disk_img},format=raw,id=drive1,if=none,index=1,werror=report',
"-device", "virtio-blk-pci,bootindex=1,drive=drive1,serial=root",
"-device", "virtio-keyboard",
"-usb",
"-device", "usb-tablet,bus=usb-bus.0",
"-kernel", f'{vm_config["toplevel"]}/kernel',
"-initrd", vm_config["initrd"],
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on
],
)
)
def create(args: argparse.Namespace) -> None:
def create_vm(vm: VmConfig) -> UUID:
return register_task(BuildVmTask, vm)
def create_command(args: argparse.Namespace) -> None:
clan_dir = get_clan_flake_toplevel().as_posix()
vm = VmConfig(
flake_url=clan_dir,
@ -30,17 +118,12 @@ def create(args: argparse.Namespace) -> None:
memory_size=0,
)
res = asyncio.run(vms.create_vm(vm))
print(res.json())
uuid = UUID(res.uuid)
stream = asyncio.run(vms.get_vm_logs(uuid))
for line in read_stream_response(stream):
uuid = create_vm(vm)
task = get_task(uuid)
for line in task.logs_iter():
print(line, end="")
print("")
def register_create_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument("machine", type=str)
parser.set_defaults(func=create)
parser.set_defaults(func=create_command)

View File

@ -1,16 +1,42 @@
import argparse
import asyncio
import json
from pydantic import BaseModel
from ..async_cmd import run
from ..dirs import get_clan_flake_toplevel
from ..webui.routers import vms
from ..nix import nix_eval
def inspect(args: argparse.Namespace) -> None:
class VmConfig(BaseModel):
flake_url: str
flake_attr: str
cores: int
memory_size: int
graphics: bool
async def inspect_vm(flake_url: str, flake_attr: str) -> VmConfig:
cmd = nix_eval(
[
f"{flake_url}#nixosConfigurations.{json.dumps(flake_attr)}.config.system.clan.vm.config"
]
)
stdout = await run(cmd)
data = json.loads(stdout)
return VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
def inspect_command(args: argparse.Namespace) -> None:
clan_dir = get_clan_flake_toplevel().as_posix()
res = asyncio.run(vms.inspect_vm(flake_url=clan_dir, flake_attr=args.machine))
print(res.json())
res = asyncio.run(inspect_vm(flake_url=clan_dir, flake_attr=args.machine))
print("Cores:", res.cores)
print("Memory size:", res.memory_size)
print("Graphics:", res.graphics)
def register_inspect_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument("machine", type=str)
parser.set_defaults(func=inspect)
parser.set_defaults(func=inspect_command)

View File

@ -5,8 +5,10 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
from ..errors import ClanError
from .assets import asset_path
from .routers import flake, health, machines, root, utils, vms
from .error_handlers import clan_error_handler
from .routers import flake, health, machines, root, vms
origins = [
"http://localhost:3000",
@ -32,9 +34,7 @@ def setup_app() -> FastAPI:
# Needs to be last in register. Because of wildcard route
app.include_router(root.router)
app.add_exception_handler(
utils.NixBuildException, utils.nix_build_exception_handler
)
app.add_exception_handler(ClanError, clan_error_handler)
app.mount("/static", StaticFiles(directory=asset_path()), name="static")

View File

@ -0,0 +1,23 @@
import logging
from fastapi import Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from ..errors import ClanError
log = logging.getLogger(__name__)
def clan_error_handler(request: Request, exc: ClanError) -> JSONResponse:
log.error("ClanError: %s", exc)
detail = [
{
"loc": [],
"msg": str(exc),
}
]
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder(dict(detail=detail)),
)

View File

@ -6,15 +6,15 @@ from fastapi import APIRouter, HTTPException
from clan_cli.webui.schemas import FlakeAction, FlakeAttrResponse, FlakeResponse
from ...async_cmd import run
from ...nix import nix_command, nix_flake_show
from .utils import run_cmd
router = APIRouter()
async def get_attrs(url: str) -> list[str]:
cmd = nix_flake_show(url)
stdout = await run_cmd(cmd)
stdout = await run(cmd)
data: dict[str, dict] = {}
try:
@ -45,7 +45,7 @@ async def inspect_flake(
# Extract the flake from the given URL
# We do this by running 'nix flake prefetch {url} --json'
cmd = nix_command(["flake", "prefetch", url, "--json", "--refresh"])
stdout = await run_cmd(cmd)
stdout = await run(cmd)
data: dict[str, str] = json.loads(stdout)
if data.get("storePath") is None:

View File

@ -1,54 +0,0 @@
import asyncio
import logging
import shlex
from fastapi import HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
log = logging.getLogger(__name__)
class NixBuildException(HTTPException):
def __init__(self, msg: str, loc: list = ["body", "flake_attr"]):
detail = [
{
"loc": loc,
"msg": msg,
"type": "value_error",
}
]
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail
)
def nix_build_exception_handler(
request: Request, exc: NixBuildException
) -> JSONResponse:
log.error("NixBuildException: %s", exc)
return JSONResponse(
status_code=exc.status_code,
content=jsonable_encoder(dict(detail=exc.detail)),
)
async def run_cmd(cmd: list[str]) -> bytes:
log.debug(f"Running command: {shlex.join(cmd)}")
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise NixBuildException(
f"""
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
return stdout

View File

@ -1,7 +1,4 @@
import json
import logging
import tempfile
from pathlib import Path
from typing import Annotated, Iterator
from uuid import UUID
@ -11,131 +8,20 @@ from fastapi.responses import StreamingResponse
from clan_cli.webui.routers.flake import get_attrs
from ...nix import nix_build, nix_eval, nix_shell
from ...task_manager import get_task
from ...vms import create, inspect
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, CmdState, get_task, register_task
from .utils import run_cmd
log = logging.getLogger(__name__)
router = APIRouter()
def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]:
return nix_eval(
[
f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.clan.vm.config"
]
)
def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
return nix_build(
[
f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.build.vm"
]
)
class BuildVmTask(BaseTask):
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
super().__init__(uuid)
self.vm = vm
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
clan_dir = self.vm.flake_url
machine = self.vm.flake_attr
cmd = next(cmds)
cmd.run(
nix_build(
[
# f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this
f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.create'
]
)
)
vm_json = "".join(cmd.stdout)
self.log.debug(f"VM JSON path: {vm_json}")
with open(vm_json) as f:
return json.load(f)
def task_run(self) -> None:
cmds = self.register_cmds(4)
machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}")
# TODO: We should get this from the vm argument
vm_config = self.get_vm_create_info(cmds)
with tempfile.TemporaryDirectory() as tmpdir_:
xchg_dir = Path(tmpdir_) / "xchg"
xchg_dir.mkdir()
disk_img = f"{tmpdir_}/disk.img"
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
disk_img,
"1024M",
],
)
)
cmd = next(cmds)
cmd.run(
[
"mkfs.ext4",
"-L",
"nixos",
disk_img,
]
)
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
# fmt: off
"qemu-kvm",
"-name", machine,
"-m", f'{vm_config["memorySize"]}M',
"-smp", str(vm_config["cores"]),
"-device", "virtio-rng-pci",
"-net", "nic,netdev=user.0,model=virtio", "-netdev", "user,id=user.0",
"-virtfs", "local,path=/nix/store,security_model=none,mount_tag=nix-store",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=shared",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=xchg",
"-drive", f'cache=writeback,file={disk_img},format=raw,id=drive1,if=none,index=1,werror=report',
"-device", "virtio-blk-pci,bootindex=1,drive=drive1,serial=root",
"-device", "virtio-keyboard",
"-usb",
"-device", "usb-tablet,bus=usb-bus.0",
"-kernel", f'{vm_config["toplevel"]}/kernel',
"-initrd", vm_config["initrd"],
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on
],
)
)
@router.post("/api/vms/inspect")
async def inspect_vm(
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
) -> VmInspectResponse:
cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
stdout = await run_cmd(cmd)
data = json.loads(stdout)
return VmInspectResponse(
config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
)
config = await inspect.inspect_vm(flake_url, flake_attr)
return VmInspectResponse(config=config)
@router.get("/api/vms/{uuid}/status")
@ -168,5 +54,5 @@ async def create_vm(vm: Annotated[VmConfig, Body()]) -> VmCreateResponse:
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Provided attribute '{vm.flake_attr}' does not exist.",
)
uuid = register_task(BuildVmTask, vm)
uuid = create.create_vm(vm)
return VmCreateResponse(uuid=str(uuid))

View File

@ -3,6 +3,8 @@ from typing import List
from pydantic import BaseModel, Field
from ..vms.inspect import VmConfig
class Status(Enum):
ONLINE = "online"
@ -35,15 +37,6 @@ class SchemaResponse(BaseModel):
schema_: dict = Field(alias="schema")
class VmConfig(BaseModel):
flake_url: str
flake_attr: str
cores: int
memory_size: int
graphics: bool
class VmStatusResponse(BaseModel):
returncode: list[int | None]
running: bool