tests: rewrite port allocation function

This commit is contained in:
Jörg Thalheim 2023-08-27 09:34:36 +02:00
parent 11dd70bf43
commit 81d02bb218
3 changed files with 52 additions and 38 deletions

View File

@ -1,46 +1,55 @@
#!/usr/bin/env python3
import contextlib
import socket
from typing import Callable
import pytest
NEXT_PORT = 10000
def _unused_port(socket_type: int) -> int:
"""Find an unused localhost port from 1024-65535 and return it."""
with contextlib.closing(socket.socket(type=socket_type)) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]
def check_port(port: int) -> bool:
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
with tcp, udp:
try:
tcp.bind(("127.0.0.1", port))
udp.bind(("127.0.0.1", port))
return True
except socket.error:
return False
PortFunction = Callable[[], int]
def check_port_range(port_range: range) -> bool:
for port in port_range:
if not check_port(port):
return False
return True
@pytest.fixture(scope="session")
def unused_tcp_port() -> PortFunction:
"""A function, producing different unused TCP ports."""
produced = set()
def factory() -> int:
"""Return an unused port."""
port = _unused_port(socket.SOCK_STREAM)
while port in produced:
port = _unused_port(socket.SOCK_STREAM)
produced.add(port)
return port
return factory
class Ports:
def allocate(self, num: int) -> int:
"""
Allocates
"""
global NEXT_PORT
while NEXT_PORT + num <= 65535:
start = NEXT_PORT
NEXT_PORT += num
if not check_port_range(range(start, NEXT_PORT)):
continue
return start
raise Exception("cannot find enough free port")
@pytest.fixture(scope="session")
def unused_udp_port() -> PortFunction:
"""A function, producing different unused UDP ports."""
produced = set()
def factory() -> int:
"""Return an unused port."""
port = _unused_port(socket.SOCK_DGRAM)
@pytest.fixture
def ports() -> Ports:
return Ports()
while port in produced:
port = _unused_port(socket.SOCK_DGRAM)
produced.add(port)
return port
return factory

View File

@ -11,7 +11,7 @@ import pytest
if TYPE_CHECKING:
from command import Command
from ports import Ports
from ports import PortFunction
class Sshd:
@ -104,8 +104,13 @@ exec {bash} -l "${{@}}"
@pytest.fixture
def sshd(sshd_config: SshdConfig, command: "Command", ports: "Ports") -> Iterator[Sshd]:
port = ports.allocate(1)
def sshd(
sshd_config: SshdConfig, command: "Command", unused_tcp_port: "PortFunction"
) -> Iterator[Sshd]:
import subprocess
subprocess.run(["echo", "hello"], check=True)
port = unused_tcp_port()
sshd = shutil.which("sshd")
assert sshd is not None, "no sshd binary found"
env = {}

View File

@ -5,11 +5,11 @@ import subprocess
import sys
from pathlib import Path
from ports import Ports
from ports import PortFunction
def test_start_server(ports: Ports, temporary_dir: Path) -> None:
port = ports.allocate(1)
def test_start_server(unused_tcp_port: PortFunction, temporary_dir: Path) -> None:
port = unused_tcp_port()
fifo = temporary_dir / "fifo"
os.mkfifo(fifo)