1
0
forked from clan/clan-core

Revert "clan-cli: cmd.py uses pseudo terminal now. Remove tty.py. Refactor password_store.py to use cmd.py."

This reverts commit ba86b49952.
This commit is contained in:
lassulus 2024-06-03 12:23:56 +02:00
parent dbad63f155
commit 578162425d
5 changed files with 86 additions and 84 deletions

3
.gitignore vendored
View File

@ -1,5 +1,4 @@
.direnv .direnv
***/.vscode
***/.hypothesis ***/.hypothesis
out.log out.log
.coverage.* .coverage.*
@ -36,4 +35,4 @@ repo
# node # node
node_modules node_modules
dist dist
.webui .webui

View File

@ -1,6 +1,5 @@
import logging import logging
import os import os
import pty
import select import select
import shlex import shlex
import subprocess import subprocess
@ -9,6 +8,7 @@ import weakref
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import IO, Any
from .custom_logger import get_caller from .custom_logger import get_caller
from .errors import ClanCmdError, CmdOut from .errors import ClanCmdError, CmdOut
@ -23,6 +23,42 @@ class Log(Enum):
NONE = 4 NONE = 4
def handle_output(process: subprocess.Popen, log: Log) -> tuple[str, str]:
rlist = [process.stdout, process.stderr]
stdout_buf = b""
stderr_buf = b""
while len(rlist) != 0:
r, _, _ = select.select(rlist, [], [], 0.1)
if len(r) == 0: # timeout in select
if process.poll() is None:
continue
# Process has exited
break
def handle_fd(fd: IO[Any] | None) -> bytes:
if fd and fd in r:
read = os.read(fd.fileno(), 4096)
if len(read) != 0:
return read
rlist.remove(fd)
return b""
ret = handle_fd(process.stdout)
if ret and log in [Log.STDOUT, Log.BOTH]:
sys.stdout.buffer.write(ret)
sys.stdout.flush()
stdout_buf += ret
ret = handle_fd(process.stderr)
if ret and log in [Log.STDERR, Log.BOTH]:
sys.stderr.buffer.write(ret)
sys.stderr.flush()
stderr_buf += ret
return stdout_buf.decode("utf-8", "replace"), stderr_buf.decode("utf-8", "replace")
class TimeTable: class TimeTable:
""" """
This class is used to store the time taken by each command This class is used to store the time taken by each command
@ -78,91 +114,38 @@ def run(
) )
else: else:
glog.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}") glog.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}")
# Create pseudo-terminals for stdout/stderr and stdin
stdout_master_fd, stdout_slave_fd = pty.openpty()
stderr_master_fd, stderr_slave_fd = pty.openpty()
tstart = datetime.now() tstart = datetime.now()
proc = subprocess.Popen( # Start the subprocess
process = subprocess.Popen(
cmd, cmd,
preexec_fn=os.setsid,
stdin=stdout_slave_fd,
stdout=stdout_slave_fd,
stderr=stderr_slave_fd,
close_fds=True,
env=env,
cwd=str(cwd), cwd=str(cwd),
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) )
stdout_buf, stderr_buf = handle_output(process, log)
os.close(stdout_slave_fd) # Close slave FD in parent
os.close(stderr_slave_fd) # Close slave FD in parent
stdout_file = sys.stdout
stderr_file = sys.stderr
stdout_buf = b""
stderr_buf = b""
if input: if input:
written_b = os.write(stdout_master_fd, input) process.communicate(input)
else:
if written_b != len(input): process.wait()
raise ValueError("Could not write all input to subprocess")
rlist = [stdout_master_fd, stderr_master_fd]
def handle_fd(fd: int | None) -> bytes:
if fd and fd in r:
try:
read = os.read(fd, 4096)
if len(read) != 0:
return read
except OSError:
pass
rlist.remove(fd)
return b""
while len(rlist) != 0:
r, w, e = select.select(rlist, [], [], 0.1)
if len(r) == 0: # timeout in select
if proc.poll() is None:
continue
# Process has exited
break
ret = handle_fd(stdout_master_fd)
stdout_buf += ret
if ret and log in [Log.STDOUT, Log.BOTH]:
stdout_file.buffer.write(ret)
stdout_file.flush()
ret = handle_fd(stderr_master_fd)
stderr_buf += ret
if ret and log in [Log.STDERR, Log.BOTH]:
stderr_file.buffer.write(ret)
stderr_file.flush()
os.close(stdout_master_fd)
os.close(stderr_master_fd)
proc.wait()
tend = datetime.now() tend = datetime.now()
global TIME_TABLE global TIME_TABLE
TIME_TABLE.add(shlex.join(cmd), tend - tstart) TIME_TABLE.add(shlex.join(cmd), tend - tstart)
# Wait for the subprocess to finish # Wait for the subprocess to finish
cmd_out = CmdOut( cmd_out = CmdOut(
stdout=stdout_buf.decode("utf-8", "replace"), stdout=stdout_buf,
stderr=stderr_buf.decode("utf-8", "replace"), stderr=stderr_buf,
cwd=cwd, cwd=cwd,
command=shlex.join(cmd), command=shlex.join(cmd),
returncode=proc.returncode, returncode=process.returncode,
msg=error_msg, msg=error_msg,
) )
if check and proc.returncode != 0: if check and process.returncode != 0:
raise ClanCmdError(cmd_out) raise ClanCmdError(cmd_out)
return cmd_out return cmd_out

View File

@ -2,7 +2,7 @@ import os
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from clan_cli.cmd import run from clan_cli.cmd import Log, run
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.nix import nix_shell from clan_cli.nix import nix_shell
@ -16,13 +16,14 @@ class SecretStore(SecretStoreBase):
def set( def set(
self, service: str, name: str, value: bytes, groups: list[str] self, service: str, name: str, value: bytes, groups: list[str]
) -> Path | None: ) -> Path | None:
subprocess.run( run(
nix_shell( nix_shell(
["nixpkgs#pass"], ["nixpkgs#pass"],
["pass", "insert", "-m", f"machines/{self.machine.name}/{name}"], ["pass", "insert", "-m", f"machines/{self.machine.name}/{name}"],
), ),
input=value, input=value,
check=True, log=Log.BOTH,
error_msg=f"Failed to insert secret {name}",
) )
return None # we manage the files outside of the git repo return None # we manage the files outside of the git repo

View File

@ -1,6 +1,5 @@
import argparse import argparse
import getpass import getpass
import logging
import os import os
import shutil import shutil
import sys import sys
@ -9,6 +8,7 @@ from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import IO from typing import IO
from .. import tty
from ..errors import ClanError from ..errors import ClanError
from ..git import commit_files from ..git import commit_files
from .folders import ( from .folders import (
@ -21,13 +21,6 @@ from .folders import (
from .sops import decrypt_file, encrypt_file, ensure_sops_key, read_key, update_keys from .sops import decrypt_file, encrypt_file, ensure_sops_key, read_key, update_keys
from .types import VALID_SECRET_NAME, secret_name_type from .types import VALID_SECRET_NAME, secret_name_type
log = logging.getLogger(__name__)
def tty_is_interactive() -> bool:
"""Returns true if the current process is interactive"""
return sys.stdin.isatty() and sys.stdout.isatty()
def update_secrets( def update_secrets(
flake_dir: Path, filter_secrets: Callable[[Path], bool] = lambda _: True flake_dir: Path, filter_secrets: Callable[[Path], bool] = lambda _: True
@ -56,11 +49,11 @@ def collect_keys_for_type(folder: Path) -> set[str]:
try: try:
target = p.resolve() target = p.resolve()
except FileNotFoundError: except FileNotFoundError:
log.warn(f"Ignoring broken symlink {p}") tty.warn(f"Ignoring broken symlink {p}")
continue continue
kind = target.parent.name kind = target.parent.name
if folder.name != kind: if folder.name != kind:
log.warn(f"Expected {p} to point to {folder} but points to {target.parent}") tty.warn(f"Expected {p} to point to {folder} but points to {target.parent}")
continue continue
keys.add(read_key(target)) keys.add(read_key(target))
return keys return keys
@ -292,7 +285,7 @@ def set_command(args: argparse.Namespace) -> None:
secret_value = None secret_value = None
elif env_value: elif env_value:
secret_value = env_value secret_value = env_value
elif tty_is_interactive(): elif tty.is_interactive():
secret_value = getpass.getpass(prompt="Paste your secret: ") secret_value = getpass.getpass(prompt="Paste your secret: ")
encrypt_secret( encrypt_secret(
Path(args.flake), Path(args.flake),

View File

@ -0,0 +1,26 @@
import sys
from collections.abc import Callable
from typing import IO, Any
def is_interactive() -> bool:
"""Returns true if the current process is interactive"""
return sys.stdin.isatty() and sys.stdout.isatty()
def color_text(code: int, file: IO[Any] = sys.stdout) -> Callable[[str], None]:
"""
Print with color if stderr is a tty
"""
def wrapper(text: str) -> None:
if file.isatty():
print(f"\x1b[{code}m{text}\x1b[0m", file=file)
else:
print(text, file=file)
return wrapper
warn = color_text(91, file=sys.stderr)
info = color_text(92, file=sys.stderr)