Merge pull request 'Restructuring CLI to use API' (#387) from Qubasa-main into main
Some checks reported warnings
checks / test (push) Successful in 31s
assets1 / test (push) Successful in 7s
checks-impure / test (push) Has been cancelled

Reviewed-on: #387
This commit is contained in:
Mic92 2023-10-03 15:34:33 +00:00
commit de3084066c
27 changed files with 622 additions and 416 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
.direnv
democlan
result*
/pkgs/clan-cli/clan_cli/nixpkgs
/pkgs/clan-cli/clan_cli/webui/assets

View File

@ -12,6 +12,15 @@
"justMyCode": false,
"args": [ "--reload", "--no-open", "--log-level", "debug" ],
},
{
"name": "Clan Cli VMs",
"type": "python",
"request": "launch",
"module": "clan_cli",
"justMyCode": false,
"args": [ "vms" ],
}
]
}

View File

@ -1,12 +1,15 @@
import argparse
import logging
import sys
from types import ModuleType
from typing import Optional
from . import config, create, machines, secrets, vms, webui
from . import config, create, custom_logger, machines, secrets, vms, webui
from .errors import ClanError
from .ssh import cli as ssh_cli
log = logging.getLogger(__name__)
argcomplete: Optional[ModuleType] = None
try:
import argcomplete # type: ignore[no-redef]
@ -62,14 +65,20 @@ def create_parser(prog: Optional[str] = None) -> argparse.ArgumentParser:
def main() -> None:
parser = create_parser()
args = parser.parse_args()
if args.debug:
custom_logger.register(logging.DEBUG)
log.debug("Debug logging enabled")
else:
custom_logger.register(logging.INFO)
if not hasattr(args, "func"):
log.error("No argparse function registered")
return
try:
args.func(args)
except ClanError as e:
if args.debug:
raise
print(f"{sys.argv[0]}: {e}")
log.exception(e)
sys.exit(1)

View File

@ -0,0 +1,4 @@
from . import main
if __name__ == "__main__":
main()

View File

@ -0,0 +1,28 @@
import asyncio
import logging
import shlex
from .errors import ClanError
log = logging.getLogger(__name__)
async def run(cmd: list[str]) -> bytes:
log.debug(f"$: {shlex.join(cmd)}")
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise ClanError(
f"""
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
return stdout

View File

@ -1,12 +1,16 @@
import argparse
import logging
import os
from .folders import machines_folder
from .types import validate_hostname
log = logging.getLogger(__name__)
def list_machines() -> list[str]:
path = machines_folder()
log.debug(f"Listing machines in {path}")
if not path.exists():
return []
objs: list[str] = []

View File

@ -1,14 +1,18 @@
import argparse
import logging
import os
import shlex
import subprocess
import sys
from pathlib import Path
from clan_cli.errors import ClanError
from ..dirs import get_clan_flake_toplevel, module_root
from ..dirs import get_clan_flake_toplevel
from ..nix import nix_build, nix_config
log = logging.getLogger(__name__)
def build_generate_script(machine: str, clan_dir: Path) -> str:
config = nix_config()
@ -31,7 +35,8 @@ def build_generate_script(machine: str, clan_dir: Path) -> str:
def run_generate_secrets(secret_generator_script: str, clan_dir: Path) -> None:
env = os.environ.copy()
env["CLAN_DIR"] = str(clan_dir)
env["PYTHONPATH"] = str(module_root().parent) # TODO do this in the clanCore module
env["PYTHONPATH"] = ":".join(sys.path) # TODO do this in the clanCore module
print(f"generating secrets... {secret_generator_script}")
proc = subprocess.run(
[secret_generator_script],
@ -39,6 +44,8 @@ def run_generate_secrets(secret_generator_script: str, clan_dir: Path) -> None:
)
if proc.returncode != 0:
log.error("stdout: %s", proc.stdout)
log.error("stderr: %s", proc.stderr)
raise ClanError("failed to generate secrets")
else:
print("successfully generated secrets")

View File

@ -1,16 +1,20 @@
import argparse
import json
import logging
import os
import shlex
import subprocess
import sys
from pathlib import Path
from tempfile import TemporaryDirectory
from ..dirs import get_clan_flake_toplevel, module_root
from ..dirs import get_clan_flake_toplevel
from ..errors import ClanError
from ..nix import nix_build, nix_config, nix_shell
from ..ssh import parse_deployment_address
log = logging.getLogger(__name__)
def build_upload_script(machine: str, clan_dir: Path) -> str:
config = nix_config()
@ -53,7 +57,7 @@ def run_upload_secrets(
) -> None:
env = os.environ.copy()
env["CLAN_DIR"] = str(clan_dir)
env["PYTHONPATH"] = str(module_root().parent) # TODO do this in the clanCore module
env["PYTHONPATH"] = ":".join(sys.path) # TODO do this in the clanCore module
print(f"uploading secrets... {flake_attr}")
with TemporaryDirectory() as tempdir_:
tempdir = Path(tempdir_)
@ -67,6 +71,8 @@ def run_upload_secrets(
)
if proc.returncode != 0:
log.error("Stdout: %s", proc.stdout)
log.error("Stderr: %s", proc.stderr)
raise ClanError("failed to upload secrets")
h = parse_deployment_address(flake_attr, target)

View File

@ -0,0 +1,169 @@
import logging
import os
import queue
import select
import shlex
import subprocess
import threading
from typing import Any, Iterator
from uuid import UUID, uuid4
class CmdState:
def __init__(self, log: logging.Logger) -> None:
self.log: logging.Logger = log
self.p: subprocess.Popen | None = None
self.stdout: list[str] = []
self.stderr: list[str] = []
self._output: queue.SimpleQueue = queue.SimpleQueue()
self.returncode: int | None = None
self.done: bool = False
self.running: bool = False
self.cmd_str: str | None = None
self.workdir: str | None = None
def close_queue(self) -> None:
if self.p is not None:
self.returncode = self.p.returncode
self._output.put(None)
self.running = False
self.done = True
def run(self, cmd: list[str]) -> None:
self.running = True
try:
self.cmd_str = shlex.join(cmd)
self.workdir = os.getcwd()
self.log.debug(f"Working directory: {self.workdir}")
self.log.debug(f"Running command: {shlex.join(cmd)}")
self.p = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
cwd=self.workdir,
)
while self.p.poll() is None:
# Check if stderr is ready to be read from
rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0)
if self.p.stderr in rlist:
assert self.p.stderr is not None
line = self.p.stderr.readline()
if line != "":
line = line.strip("\n")
self.stderr.append(line)
self.log.debug("stderr: %s", line)
self._output.put(line + "\n")
if self.p.stdout in rlist:
assert self.p.stdout is not None
line = self.p.stdout.readline()
if line != "":
line = line.strip("\n")
self.stdout.append(line)
self.log.debug("stdout: %s", line)
self._output.put(line + "\n")
if self.p.returncode != 0:
raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")
self.log.debug("Successfully ran command")
finally:
self.close_queue()
class BaseTask(threading.Thread):
def __init__(self, uuid: UUID) -> None:
# calling parent class constructor
threading.Thread.__init__(self)
# constructor
self.uuid: UUID = uuid
self.log = logging.getLogger(__name__)
self.procs: list[CmdState] = []
self.failed: bool = False
self.finished: bool = False
self.logs_lock = threading.Lock()
def run(self) -> None:
try:
self.task_run()
except Exception as e:
for proc in self.procs:
proc.close_queue()
self.failed = True
self.log.exception(e)
finally:
self.finished = True
def task_run(self) -> None:
raise NotImplementedError
## TODO: If two clients are connected to the same task,
def logs_iter(self) -> Iterator[str]:
with self.logs_lock:
for proc in self.procs:
if self.finished:
self.log.debug("log iter: Task is finished")
break
if proc.done:
for line in proc.stderr:
yield line + "\n"
for line in proc.stdout:
yield line + "\n"
continue
while True:
out = proc._output
line = out.get()
if line is None:
break
yield line
def register_cmds(self, num_cmds: int) -> Iterator[CmdState]:
for i in range(num_cmds):
cmd = CmdState(self.log)
self.procs.append(cmd)
for cmd in self.procs:
yield cmd
# TODO: We need to test concurrency
class TaskPool:
def __init__(self) -> None:
self.lock: threading.RLock = threading.RLock()
self.pool: dict[UUID, BaseTask] = {}
def __getitem__(self, uuid: UUID) -> BaseTask:
with self.lock:
return self.pool[uuid]
def __setitem__(self, uuid: UUID, task: BaseTask) -> None:
with self.lock:
if uuid in self.pool:
raise KeyError(f"Task with uuid {uuid} already exists")
if type(uuid) is not UUID:
raise TypeError("uuid must be of type UUID")
self.pool[uuid] = task
POOL: TaskPool = TaskPool()
def get_task(uuid: UUID) -> BaseTask:
global POOL
return POOL[uuid]
def register_task(task: type, *args: Any) -> UUID:
global POOL
if not issubclass(task, BaseTask):
raise TypeError("task must be a subclass of BaseTask")
uuid = uuid4()
inst_task = task(uuid, *args)
POOL[uuid] = inst_task
inst_task.start()
return uuid

View File

@ -1,101 +1,129 @@
import argparse
import json
import subprocess
import tempfile
from pathlib import Path
from typing import Iterator
from uuid import UUID
from ..dirs import get_clan_flake_toplevel
from ..nix import nix_build, nix_shell
from ..task_manager import BaseTask, CmdState, get_task, register_task
from .inspect import VmConfig
def get_vm_create_info(machine: str) -> dict:
class BuildVmTask(BaseTask):
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
super().__init__(uuid)
self.vm = vm
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
clan_dir = self.vm.flake_url
machine = self.vm.flake_attr
cmd = next(cmds)
cmd.run(
nix_build(
[
# f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this
f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.create'
]
)
)
vm_json = "".join(cmd.stdout)
self.log.debug(f"VM JSON path: {vm_json}")
with open(vm_json) as f:
return json.load(f)
def task_run(self) -> None:
cmds = self.register_cmds(4)
machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}")
# TODO: We should get this from the vm argument
vm_config = self.get_vm_create_info(cmds)
with tempfile.TemporaryDirectory() as tmpdir_:
xchg_dir = Path(tmpdir_) / "xchg"
xchg_dir.mkdir()
disk_img = f"{tmpdir_}/disk.img"
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
disk_img,
"1024M",
],
)
)
cmd = next(cmds)
cmd.run(
nix_shell(
["e2fsprogs"],
[
"mkfs.ext4",
"-L",
"nixos",
disk_img,
],
)
)
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
# fmt: off
"qemu-kvm",
"-name", machine,
"-m", f'{vm_config["memorySize"]}M',
"-smp", str(vm_config["cores"]),
"-device", "virtio-rng-pci",
"-net", "nic,netdev=user.0,model=virtio", "-netdev", "user,id=user.0",
"-virtfs", "local,path=/nix/store,security_model=none,mount_tag=nix-store",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=shared",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=xchg",
"-drive", f'cache=writeback,file={disk_img},format=raw,id=drive1,if=none,index=1,werror=report',
"-device", "virtio-blk-pci,bootindex=1,drive=drive1,serial=root",
"-device", "virtio-keyboard",
"-usb",
"-device", "usb-tablet,bus=usb-bus.0",
"-kernel", f'{vm_config["toplevel"]}/kernel',
"-initrd", vm_config["initrd"],
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on
],
)
)
def create_vm(vm: VmConfig) -> UUID:
return register_task(BuildVmTask, vm)
def create_command(args: argparse.Namespace) -> None:
clan_dir = get_clan_flake_toplevel().as_posix()
vm = VmConfig(
flake_url=clan_dir,
flake_attr=args.machine,
cores=0,
graphics=False,
memory_size=0,
)
# config = nix_config()
# system = config["system"]
vm_json = subprocess.run(
nix_build(
[
# f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this
f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.create'
]
),
stdout=subprocess.PIPE,
check=True,
text=True,
).stdout.strip()
with open(vm_json) as f:
return json.load(f)
def create(args: argparse.Namespace) -> None:
print(f"Creating VM for {args.machine}")
machine = args.machine
vm_config = get_vm_create_info(machine)
with tempfile.TemporaryDirectory() as tmpdir_:
xchg_dir = Path(tmpdir_) / "xchg"
xchg_dir.mkdir()
disk_img = f"{tmpdir_}/disk.img"
subprocess.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
disk_img,
"1024M",
],
),
stdout=subprocess.PIPE,
check=True,
text=True,
)
subprocess.run(
[
"mkfs.ext4",
"-L",
"nixos",
disk_img,
],
stdout=subprocess.PIPE,
check=True,
text=True,
)
subprocess.run(
nix_shell(
["qemu"],
[
# fmt: off
"qemu-kvm",
"-name", machine,
"-m", f'{vm_config["memorySize"]}M',
"-smp", str(vm_config["cores"]),
"-device", "virtio-rng-pci",
"-net", "nic,netdev=user.0,model=virtio", "-netdev", "user,id=user.0",
"-virtfs", "local,path=/nix/store,security_model=none,mount_tag=nix-store",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=shared",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=xchg",
"-drive", f'cache=writeback,file={disk_img},format=raw,id=drive1,if=none,index=1,werror=report',
"-device", "virtio-blk-pci,bootindex=1,drive=drive1,serial=root",
"-device", "virtio-keyboard",
"-usb",
"-device", "usb-tablet,bus=usb-bus.0",
"-kernel", f'{vm_config["toplevel"]}/kernel',
"-initrd", vm_config["initrd"],
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on
],
),
stdout=subprocess.PIPE,
check=True,
text=True,
)
uuid = create_vm(vm)
task = get_task(uuid)
for line in task.logs_iter():
print(line, end="")
def register_create_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument("machine", type=str)
parser.set_defaults(func=create)
parser.set_defaults(func=create_command)

View File

@ -1,38 +1,42 @@
import argparse
import asyncio
import json
import subprocess
from pydantic import BaseModel
from ..async_cmd import run
from ..dirs import get_clan_flake_toplevel
from ..nix import nix_eval
def get_vm_inspect_info(machine: str) -> dict:
clan_dir = get_clan_flake_toplevel().as_posix()
class VmConfig(BaseModel):
flake_url: str
flake_attr: str
# config = nix_config()
# system = config["system"]
cores: int
memory_size: int
graphics: bool
return json.loads(
subprocess.run(
nix_eval(
[
# f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation' # TODO use this
f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.config'
]
),
stdout=subprocess.PIPE,
check=True,
text=True,
).stdout
async def inspect_vm(flake_url: str, flake_attr: str) -> VmConfig:
cmd = nix_eval(
[
f"{flake_url}#nixosConfigurations.{json.dumps(flake_attr)}.config.system.clan.vm.config"
]
)
stdout = await run(cmd)
data = json.loads(stdout)
return VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
def inspect(args: argparse.Namespace) -> None:
print(f"Creating VM for {args.machine}")
machine = args.machine
print(get_vm_inspect_info(machine))
def inspect_command(args: argparse.Namespace) -> None:
clan_dir = get_clan_flake_toplevel().as_posix()
res = asyncio.run(inspect_vm(flake_url=clan_dir, flake_attr=args.machine))
print("Cores:", res.cores)
print("Memory size:", res.memory_size)
print("Graphics:", res.graphics)
def register_inspect_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument("machine", type=str)
parser.set_defaults(func=inspect)
parser.set_defaults(func=inspect_command)

View File

@ -45,6 +45,8 @@ def register_parser(parser: argparse.ArgumentParser) -> None:
help="Log level",
choices=["critical", "error", "warning", "info", "debug", "trace"],
)
# Set the args.func variable in args
if start_server is None:
parser.set_defaults(func=fastapi_is_not_installed)
else:

View File

@ -5,6 +5,11 @@ from . import register_parser
if __name__ == "__main__":
# this is use in our integration test
parser = argparse.ArgumentParser()
# call the register_parser function, which adds arguments to the parser
register_parser(parser)
args = parser.parse_args()
# call the function that is stored
# in the func attribute of args, and pass args as the argument
# look into register_parser to see how this is done
args.func(args)

View File

@ -5,9 +5,10 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
from .. import custom_logger
from ..errors import ClanError
from .assets import asset_path
from .routers import flake, health, machines, root, utils, vms
from .error_handlers import clan_error_handler
from .routers import flake, health, machines, root, vms
origins = [
"http://localhost:3000",
@ -33,9 +34,7 @@ def setup_app() -> FastAPI:
# Needs to be last in register. Because of wildcard route
app.include_router(root.router)
app.add_exception_handler(
utils.NixBuildException, utils.nix_build_exception_handler
)
app.add_exception_handler(ClanError, clan_error_handler)
app.mount("/static", StaticFiles(directory=asset_path()), name="static")
@ -43,15 +42,11 @@ def setup_app() -> FastAPI:
if isinstance(route, APIRoute):
route.operation_id = route.name # in this case, 'read_items'
log.debug(f"Registered route: {route}")
for i in app.exception_handlers.items():
log.debug(f"Registered exception handler: {i}")
return app
# TODO: How do I get the log level from the command line in here?
custom_logger.register(logging.DEBUG)
app = setup_app()
for i in app.exception_handlers.items():
log.info(f"Registered exception handler: {i}")
log.warning("log warn")
log.debug("log debug")

View File

@ -1,7 +1,39 @@
import functools
import logging
from pathlib import Path
log = logging.getLogger(__name__)
def get_hash(string: str) -> str:
"""
This function takes a string like '/nix/store/kkvk20b8zh8aafdnfjp6dnf062x19732-source'
and returns the hash part 'kkvk20b8zh8aafdnfjp6dnf062x19732' after '/nix/store/' and before '-source'.
"""
# Split the string by '/' and get the last element
last_element = string.split("/")[-1]
# Split the last element by '-' and get the first element
hash_part = last_element.split("-")[0]
# Return the hash part
return hash_part
def check_divergence(path: Path) -> None:
p = path.resolve()
log.info("Absolute web asset path: %s", p)
if not p.is_dir():
raise FileNotFoundError(p)
# Get the hash part of the path
gh = get_hash(str(p))
log.debug(f"Serving webui asset with hash {gh}")
@functools.cache
def asset_path() -> Path:
return Path(__file__).parent / "assets"
path = Path(__file__).parent / "assets"
log.debug("Serving assets from: %s", path)
check_divergence(path)
return path

View File

@ -0,0 +1,23 @@
import logging
from fastapi import Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from ..errors import ClanError
log = logging.getLogger(__name__)
def clan_error_handler(request: Request, exc: ClanError) -> JSONResponse:
log.error("ClanError: %s", exc)
detail = [
{
"loc": [],
"msg": str(exc),
}
]
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder(dict(detail=detail)),
)

View File

@ -6,15 +6,15 @@ from fastapi import APIRouter, HTTPException
from clan_cli.webui.schemas import FlakeAction, FlakeAttrResponse, FlakeResponse
from ...async_cmd import run
from ...nix import nix_command, nix_flake_show
from .utils import run_cmd
router = APIRouter()
async def get_attrs(url: str) -> list[str]:
cmd = nix_flake_show(url)
stdout = await run_cmd(cmd)
stdout = await run(cmd)
data: dict[str, dict] = {}
try:
@ -45,7 +45,7 @@ async def inspect_flake(
# Extract the flake from the given URL
# We do this by running 'nix flake prefetch {url} --json'
cmd = nix_command(["flake", "prefetch", url, "--json", "--refresh"])
stdout = await run_cmd(cmd)
stdout = await run(cmd)
data: dict[str, str] = json.loads(stdout)
if data.get("storePath") is None:

View File

@ -1,3 +1,4 @@
import logging
import os
from mimetypes import guess_type
from pathlib import Path
@ -8,6 +9,8 @@ from ..assets import asset_path
router = APIRouter()
log = logging.getLogger(__name__)
@router.get("/{path_name:path}")
async def root(path_name: str) -> Response:
@ -16,6 +19,7 @@ async def root(path_name: str) -> Response:
filename = Path(os.path.normpath(asset_path() / path_name))
if not filename.is_relative_to(asset_path()):
log.error("Prevented directory traversal: %s", filename)
# prevent directory traversal
return Response(status_code=403)
@ -23,6 +27,7 @@ async def root(path_name: str) -> Response:
if filename.suffix == "":
filename = filename.with_suffix(".html")
if not filename.is_file():
log.error("File not found: %s", filename)
return Response(status_code=404)
content_type, _ = guess_type(filename)

View File

@ -1,54 +0,0 @@
import asyncio
import logging
import shlex
from fastapi import HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
log = logging.getLogger(__name__)
class NixBuildException(HTTPException):
def __init__(self, msg: str, loc: list = ["body", "flake_attr"]):
detail = [
{
"loc": loc,
"msg": msg,
"type": "value_error",
}
]
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail
)
def nix_build_exception_handler(
request: Request, exc: NixBuildException
) -> JSONResponse:
log.error("NixBuildException: %s", exc)
return JSONResponse(
status_code=exc.status_code,
content=jsonable_encoder(dict(detail=exc.detail)),
)
async def run_cmd(cmd: list[str]) -> bytes:
log.debug(f"Running command: {shlex.join(cmd)}")
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise NixBuildException(
f"""
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
return stdout

View File

@ -1,73 +1,27 @@
import json
import logging
from typing import Annotated, Iterator
from uuid import UUID
from fastapi import APIRouter, BackgroundTasks, Body, status
from fastapi import APIRouter, Body, status
from fastapi.exceptions import HTTPException
from fastapi.responses import StreamingResponse
from clan_cli.webui.routers.flake import get_attrs
from ...nix import nix_build, nix_eval
from ...task_manager import get_task
from ...vms import create, inspect
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, get_task, register_task
from .utils import run_cmd
log = logging.getLogger(__name__)
router = APIRouter()
def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]:
return nix_eval(
[
f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.clan.vm.config"
]
)
def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
return nix_build(
[
f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.build.vm"
]
)
class BuildVmTask(BaseTask):
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
super().__init__(uuid)
self.vm = vm
def run(self) -> None:
try:
self.log.debug(f"BuildVM with uuid {self.uuid} started")
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
proc = self.run_cmd(cmd)
self.log.debug(f"stdout: {proc.stdout}")
vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm"
self.log.debug(f"vm_path: {vm_path}")
self.run_cmd([vm_path])
self.finished = True
except Exception as e:
self.failed = True
self.finished = True
log.exception(e)
@router.post("/api/vms/inspect")
async def inspect_vm(
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
) -> VmInspectResponse:
cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
stdout = await run_cmd(cmd)
data = json.loads(stdout)
return VmInspectResponse(
config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
)
config = await inspect.inspect_vm(flake_url, flake_attr)
return VmInspectResponse(config=config)
@router.get("/api/vms/{uuid}/status")
@ -84,21 +38,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
def stream_logs() -> Iterator[str]:
task = get_task(uuid)
for proc in task.procs:
if proc.done:
log.debug("stream logs and proc is done")
for line in proc.stderr:
yield line + "\n"
for line in proc.stdout:
yield line + "\n"
continue
while True:
out = proc.output
line = out.get()
if line is None:
log.debug("stream logs and line is None")
break
yield line
yield from task.logs_iter()
return StreamingResponse(
content=stream_logs(),
@ -107,14 +47,12 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
@router.post("/api/vms/create")
async def create_vm(
vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks
) -> VmCreateResponse:
async def create_vm(vm: Annotated[VmConfig, Body()]) -> VmCreateResponse:
flake_attrs = await get_attrs(vm.flake_url)
if vm.flake_attr not in flake_attrs:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Provided attribute '{vm.flake_attr}' does not exist.",
)
uuid = register_task(BuildVmTask, vm)
uuid = create.create_vm(vm)
return VmCreateResponse(uuid=str(uuid))

View File

@ -3,6 +3,8 @@ from typing import List
from pydantic import BaseModel, Field
from ..vms.inspect import VmConfig
class Status(Enum):
ONLINE = "online"
@ -35,15 +37,6 @@ class SchemaResponse(BaseModel):
schema_: dict = Field(alias="schema")
class VmConfig(BaseModel):
flake_url: str
flake_attr: str
cores: int
memory_size: int
graphics: bool
class VmStatusResponse(BaseModel):
returncode: list[int | None]
running: bool

View File

@ -1,6 +1,11 @@
import argparse
import logging
import multiprocessing as mp
import os
import socket
import subprocess
import sys
import syslog
import time
import urllib.request
import webbrowser
@ -90,3 +95,98 @@ def start_server(args: argparse.Namespace) -> None:
access_log=args.log_level == "debug",
headers=headers,
)
# Define a function that takes the path of the file socket as input and returns True if it is served, False otherwise
def is_served(file_socket: Path) -> bool:
# Create a Unix stream socket
client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
# Try to connect to the file socket
try:
client.connect(str(file_socket))
# Connection succeeded, return True
return True
except OSError:
# Connection failed, return False
return False
finally:
# Close the client socket
client.close()
def set_out_to_syslog() -> None: # type: ignore
# Define some constants for convenience
log_levels = {
"emerg": syslog.LOG_EMERG,
"alert": syslog.LOG_ALERT,
"crit": syslog.LOG_CRIT,
"err": syslog.LOG_ERR,
"warning": syslog.LOG_WARNING,
"notice": syslog.LOG_NOTICE,
"info": syslog.LOG_INFO,
"debug": syslog.LOG_DEBUG,
}
facility = syslog.LOG_USER # Use user facility for custom applications
# Open a connection to the system logger
syslog.openlog("clan-cli", 0, facility) # Use "myapp" as the prefix for messages
# Define a custom write function that sends messages to syslog
def write(message: str) -> int:
# Strip the newline character from the message
message = message.rstrip("\n")
# Check if the message is not empty
if message:
# Send the message to syslog with the appropriate level
if message.startswith("ERROR:"):
# Use error level for messages that start with "ERROR:"
syslog.syslog(log_levels["err"], message)
else:
# Use info level for other messages
syslog.syslog(log_levels["info"], message)
return 0
# Assign the custom write function to sys.stdout and sys.stderr
setattr(sys.stdout, "write", write)
setattr(sys.stderr, "write", write)
# Define a dummy flush function to prevent errors
def flush() -> None:
pass
# Assign the dummy flush function to sys.stdout and sys.stderr
setattr(sys.stdout, "flush", flush)
setattr(sys.stderr, "flush", flush)
def _run_socketfile(socket_file: Path, debug: bool) -> None:
set_out_to_syslog()
run(
"clan_cli.webui.app:app",
uds=str(socket_file),
access_log=debug,
reload=False,
log_level="debug" if debug else "info",
)
@contextmanager
def api_server(debug: bool) -> Iterator[Path]:
runtime_dir = os.getenv("XDG_RUNTIME_DIR")
if runtime_dir is None:
raise RuntimeError("XDG_RUNTIME_DIR not set")
socket_path = Path(runtime_dir) / "clan.sock"
socket_path = socket_path.resolve()
log.debug("Socketfile lies at %s", socket_path)
if not is_served(socket_path):
log.debug("Starting api server...")
mp.set_start_method(method="spawn")
proc = mp.Process(target=_run_socketfile, args=(socket_path, debug))
proc.start()
else:
log.info("Api server is already running on %s", socket_path)
yield socket_path
proc.terminate()

View File

@ -1,119 +0,0 @@
import logging
import os
import queue
import select
import shlex
import subprocess
import threading
from typing import Any
from uuid import UUID, uuid4
class CmdState:
def __init__(self, proc: subprocess.Popen) -> None:
global LOOP
self.proc: subprocess.Popen = proc
self.stdout: list[str] = []
self.stderr: list[str] = []
self.output: queue.SimpleQueue = queue.SimpleQueue()
self.returncode: int | None = None
self.done: bool = False
class BaseTask(threading.Thread):
def __init__(self, uuid: UUID) -> None:
# calling parent class constructor
threading.Thread.__init__(self)
# constructor
self.uuid: UUID = uuid
self.log = logging.getLogger(__name__)
self.procs: list[CmdState] = []
self.failed: bool = False
self.finished: bool = False
def run(self) -> None:
self.finished = True
def run_cmd(self, cmd: list[str]) -> CmdState:
cwd = os.getcwd()
self.log.debug(f"Working directory: {cwd}")
self.log.debug(f"Running command: {shlex.join(cmd)}")
p = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
# shell=True,
cwd=cwd,
)
self.procs.append(CmdState(p))
p_state = self.procs[-1]
while p.poll() is None:
# Check if stderr is ready to be read from
rlist, _, _ = select.select([p.stderr, p.stdout], [], [], 0)
if p.stderr in rlist:
assert p.stderr is not None
line = p.stderr.readline()
if line != "":
p_state.stderr.append(line.strip("\n"))
self.log.debug(f"stderr: {line}")
p_state.output.put(line)
if p.stdout in rlist:
assert p.stdout is not None
line = p.stdout.readline()
if line != "":
p_state.stdout.append(line.strip("\n"))
self.log.debug(f"stdout: {line}")
p_state.output.put(line)
p_state.returncode = p.returncode
p_state.output.put(None)
p_state.done = True
if p.returncode != 0:
raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")
self.log.debug("Successfully ran command")
return p_state
class TaskPool:
def __init__(self) -> None:
self.lock: threading.RLock = threading.RLock()
self.pool: dict[UUID, BaseTask] = {}
def __getitem__(self, uuid: UUID) -> BaseTask:
with self.lock:
return self.pool[uuid]
def __setitem__(self, uuid: UUID, task: BaseTask) -> None:
with self.lock:
if uuid in self.pool:
raise KeyError(f"Task with uuid {uuid} already exists")
if type(uuid) is not UUID:
raise TypeError("uuid must be of type UUID")
self.pool[uuid] = task
POOL: TaskPool = TaskPool()
def get_task(uuid: UUID) -> BaseTask:
global POOL
return POOL[uuid]
def register_task(task: type, *args: Any) -> UUID:
global POOL
if not issubclass(task, BaseTask):
raise TypeError("task must be a subclass of BaseTask")
uuid = uuid4()
inst_task = task(uuid, *args)
POOL[uuid] = inst_task
inst_task.start()
return uuid

View File

@ -29,6 +29,7 @@
, copyDesktopItems
, qemu
, gnupg
, e2fsprogs
}:
let
@ -63,6 +64,7 @@ let
sops
git
qemu
e2fsprogs
];
runtimeDependenciesAsSet = builtins.listToAttrs (builtins.map (p: lib.nameValuePair (lib.getName p.name) p) runtimeDependencies);

View File

@ -1,3 +1,4 @@
import json
from pathlib import Path
import pytest
@ -28,3 +29,23 @@ def test_inspect_err(api: TestClient) -> None:
data = response.json()
print("Data: ", data)
assert data.get("detail")
@pytest.mark.impure
def test_inspect_flake(api: TestClient, test_flake_with_core: Path) -> None:
params = {"url": str(test_flake_with_core)}
response = api.get(
"/api/flake",
params=params,
)
assert response.status_code == 200, "Failed to inspect vm"
data = response.json()
print("Data: ", json.dumps(data, indent=2))
assert data.get("content") is not None
actions = data.get("actions")
assert actions is not None
assert len(actions) == 2
assert actions[0].get("id") == "vms/inspect"
assert actions[0].get("uri") == "api/vms/inspect"
assert actions[1].get("id") == "vms/create"
assert actions[1].get("uri") == "api/vms/create"

View File

@ -6,24 +6,6 @@ from api import TestClient
from httpx import SyncByteStream
def is_running_in_ci() -> bool:
# Check if pytest is running in GitHub Actions
if os.getenv("GITHUB_ACTIONS") == "true":
print("Running on GitHub Actions")
return True
# Check if pytest is running in Travis CI
if os.getenv("TRAVIS") == "true":
print("Running on Travis CI")
return True
# Check if pytest is running in Circle CI
if os.getenv("CIRCLECI") == "true":
print("Running on Circle CI")
return True
return False
@pytest.mark.impure
def test_inspect(api: TestClient, test_flake_with_core: Path) -> None:
response = api.post(
@ -49,10 +31,9 @@ def test_incorrect_uuid(api: TestClient) -> None:
assert response.status_code == 422, "Failed to get vm status"
@pytest.mark.skipif(not os.path.exists("/dev/kvm"), reason="Requires KVM")
@pytest.mark.impure
def test_create(api: TestClient, test_flake_with_core: Path) -> None:
if is_running_in_ci():
pytest.skip("Skipping test in CI. As it requires KVM")
print(f"flake_url: {test_flake_with_core} ")
response = api.post(
"/api/vms/create",
@ -74,20 +55,11 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
assert response.status_code == 200, "Failed to get vm status"
response = api.get(f"/api/vms/{uuid}/logs")
print("=========FLAKE LOGS==========")
assert isinstance(response.stream, SyncByteStream)
for line in response.stream:
assert line != b"", "Failed to get vm logs"
print(line.decode("utf-8"), end="")
print("=========END LOGS==========")
assert response.status_code == 200, "Failed to get vm logs"
response = api.get(f"/api/vms/{uuid}/logs")
assert isinstance(response.stream, SyncByteStream)
print("=========VM LOGS==========")
assert isinstance(response.stream, SyncByteStream)
for line in response.stream:
assert line != b"", "Failed to get vm logs"
print(line.decode("utf-8"), end="")
print(line.decode("utf-8"))
print("=========END LOGS==========")
assert response.status_code == 200, "Failed to get vm logs"

View File

@ -0,0 +1,22 @@
import os
from pathlib import Path
import pytest
from cli import Cli
no_kvm = not os.path.exists("/dev/kvm")
@pytest.mark.impure
def test_inspect(test_flake_with_core: Path, capsys: pytest.CaptureFixture) -> None:
cli = Cli()
cli.run(["vms", "inspect", "vm1"])
out = capsys.readouterr() # empty the buffer
assert "Cores" in out.out
@pytest.mark.skipif(no_kvm, reason="Requires KVM")
@pytest.mark.impure
def test_create(test_flake_with_core: Path) -> None:
cli = Cli()
cli.run(["vms", "create", "vm1"])