task_manager: fix race conditions

This commit is contained in:
lassulus 2023-10-04 17:41:20 +02:00
parent 827fcbfe46
commit fe1a3f0541
3 changed files with 48 additions and 44 deletions

View File

@ -21,7 +21,8 @@ class Command:
self._output: queue.SimpleQueue = queue.SimpleQueue()
self.returncode: int | None = None
self.done: bool = False
self.lines: list[str] = []
self.stdout: list[str] = []
self.stderr: list[str] = []
def close_queue(self) -> None:
if self.p is not None:
@ -31,36 +32,36 @@ class Command:
def run(self, cmd: list[str]) -> None:
    """Run ``cmd`` as a subprocess and capture its output line by line.

    Each line read from the child is appended to ``self.stdout`` or
    ``self.stderr``, depending on the pipe it came from, and is also put on
    ``self._output``. ``self.close_queue()`` is always called on the way
    out, whether the command succeeded or not.

    Args:
        cmd: Program and arguments, passed to ``subprocess.Popen`` as a list.

    Raises:
        ClanError: If the command exits with a non-zero return code.
    """
    self.running = True
    try:
        self.log.debug(f"Running command: {shlex.join(cmd)}")
        self.p = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
        )
        proc = self.p
        assert proc.stdout is not None and proc.stderr is not None
        # Non-blocking pipes: iterating them stops as soon as no data is
        # buffered instead of waiting for the child to close its end.
        os.set_blocking(proc.stdout.fileno(), False)
        os.set_blocking(proc.stderr.fileno(), False)

        pipes = [proc.stderr, proc.stdout]
        exited = False
        while not exited:
            # Sample the exit status *before* reading. Once the child has
            # exited we still do one last pass over both pipes, so output
            # written right before termination is not dropped.
            exited = proc.poll() is not None
            if exited:
                ready = pipes
            else:
                # A short timeout keeps the loop from spinning at full CPU
                # while the child is silent.
                ready, _, _ = select.select(pipes, [], [], 0.1)
            for pipe in ready:
                if pipe is proc.stderr:
                    label, sink = "stderr", self.stderr
                else:
                    label, sink = "stdout", self.stdout
                try:
                    for line in pipe:
                        self.log.debug("%s: %s", label, line)
                        sink.append(line)
                        self._output.put(line)
                except BlockingIOError:
                    continue

        if proc.returncode != 0:
            raise ClanError(f"Failed to run command: {shlex.join(cmd)}")

        self.log.debug("Successfully ran command")
    finally:
        self.close_queue()
class TaskStatus(str, Enum):
@ -71,7 +72,7 @@ class TaskStatus(str, Enum):
class BaseTask:
def __init__(self, uuid: UUID) -> None:
def __init__(self, uuid: UUID, num_cmds: int) -> None:
# constructor
self.uuid: UUID = uuid
self.log = logging.getLogger(__name__)
@ -80,6 +81,10 @@ class BaseTask:
self.logs_lock = threading.Lock()
self.error: Exception | None = None
for _ in range(num_cmds):
cmd = Command(self.log)
self.procs.append(cmd)
def _run(self) -> None:
self.status = TaskStatus.RUNNING
try:
@ -87,13 +92,14 @@ class BaseTask:
except Exception as e:
# FIXME: fix exception handling here
traceback.print_exception(*sys.exc_info())
for proc in self.procs:
proc.close_queue()
self.error = e
self.log.exception(e)
self.status = TaskStatus.FAILED
else:
self.status = TaskStatus.FINISHED
finally:
for proc in self.procs:
proc.close_queue()
def run(self) -> None:
    """Perform the task's work; concrete task classes override this.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
@ -106,19 +112,16 @@ class BaseTask:
return
# process has finished
if proc.done:
for line in proc.lines:
for line in proc.stdout:
yield line
for line in proc.stderr:
yield line
else:
while line := proc._output.get():
yield line
def register_commands(self, num_cmds: int) -> Iterator[Command]:
    """Create ``num_cmds`` new commands, then iterate over all of them.

    Every new ``Command`` is built with this task's logger and appended to
    ``self.procs``. The generator then yields each entry of ``self.procs``,
    including any that were already present before this call.
    """
    self.procs.extend(Command(self.log) for _ in range(num_cmds))
    yield from self.procs
def commands(self) -> Iterator[Command]:
    """Yield each command stored in ``self.procs``, in list order."""
    for registered in self.procs:
        yield registered
# TODO: We need to test concurrency
@ -157,6 +160,6 @@ def create_task(task_type: Type[T], *args: Any) -> T:
uuid = uuid4()
task = task_type(uuid, *args)
threading.Thread(target=task._run).start()
POOL[uuid] = task
threading.Thread(target=task._run).start()
return task

View File

@ -15,7 +15,7 @@ from .inspect import VmConfig, inspect_vm
class BuildVmTask(BaseTask):
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
    """Create a VM build task.

    Args:
        uuid: Identifier forwarded to ``BaseTask``.
        vm: VM configuration; stored on ``self.vm``.
    """
    # Initialise the base class exactly once, passing the number of
    # commands it should set up. A call without ``num_cmds`` would not
    # match the base constructor's signature.
    super().__init__(uuid, num_cmds=4)
    self.vm = vm
def get_vm_create_info(self, cmds: Iterator[Command]) -> dict:
@ -30,13 +30,13 @@ class BuildVmTask(BaseTask):
]
)
)
vm_json = "".join(cmd.lines)
vm_json = "".join(cmd.stdout)
self.log.debug(f"VM JSON path: {vm_json}")
with open(vm_json.strip()) as f:
return json.load(f)
def run(self) -> None:
cmds = self.register_commands(4)
cmds = self.commands()
machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}")

View File

@ -28,7 +28,8 @@ async def inspect_vm(
async def get_vm_status(uuid: UUID) -> VmStatusResponse:
    """Report the status of the task identified by ``uuid``.

    Returns:
        A ``VmStatusResponse`` with the task's status and its error rendered
        as a string, or ``None`` as the error when the task has none.
    """
    task = get_task(uuid)
    log.debug(msg=f"error: {task.error}, task.status: {task.status}")
    # str(None) is the text "None", which would look like a real error
    # message to the client, so pass None through untouched.
    error = None if task.error is None else str(task.error)
    return VmStatusResponse(status=task.status, error=error)
@router.get("/api/vms/{uuid}/logs")