diff --git a/pkgs/clan-cli/clan_cli/clan_uri.py b/pkgs/clan-cli/clan_cli/clan_uri.py index f0a8121f..cca37d77 100644 --- a/pkgs/clan-cli/clan_cli/clan_uri.py +++ b/pkgs/clan-cli/clan_cli/clan_uri.py @@ -4,7 +4,7 @@ import urllib.parse from dataclasses import dataclass from enum import Enum, member from pathlib import Path -from typing import Dict +from typing import Dict, Self from .errors import ClanError @@ -41,8 +41,7 @@ class ClanScheme(Enum): # so make sure there are no conflicts with other webservices @dataclass class ClanParameters: - flake_attr: str | None - machine: str | None + flake_attr: str = "defaultVM" # Define the ClanURI class @@ -60,30 +59,25 @@ class ClanURI: self._components = urllib.parse.urlparse(self._nested_uri) # Parse the query string into a dictionary - self._query = urllib.parse.parse_qs(self._components.query) + query = urllib.parse.parse_qs(self._components.query) - params: Dict[str, str | None] = {} + params: Dict[str, str] = {} for field in dataclasses.fields(ClanParameters): - if field.name in self._query: - # Check if the field type is a list - if issubclass(field.type, list): - setattr(params, field.name, self._query[field.name]) - # Check if the field type is a single value - else: - values = self._query[field.name] - if len(values) > 1: - raise ClanError( - "Multiple values for parameter: {}".format(field.name) - ) - setattr(params, field.name, values[0]) + if field.name in query: + values = query[field.name] + if len(values) > 1: + raise ClanError( + "Multiple values for parameter: {}".format(field.name) + ) + params[field.name] = values[0] # Remove the field from the query dictionary # clan uri and nested uri share one namespace for query parameters # we need to make sure there are no conflicts - del self._query[field.name] - else: - params[field.name] = None + del query[field.name] + new_query = urllib.parse.urlencode(query, doseq=True) + self._components = self._components._replace(query=new_query) self.params = ClanParameters(**params) # Use the match statement to check the scheme and create a ClanScheme member with the value @@ -98,3 +92,9 @@ class ClanURI: raise ClanError( "Unsupported scheme: {}".format(self._components.scheme) ) + + @classmethod + def from_path(cls, path: Path, params: ClanParameters) -> Self: # noqa + urlparams = urllib.parse.urlencode(params.__dict__) + + return cls("clan://file://{}?{}".format(path, urlparams)) diff --git a/pkgs/clan-cli/tests/test_clan_uri.py b/pkgs/clan-cli/tests/test_clan_uri.py index 3a4aac59..e2b2e5e2 100644 --- a/pkgs/clan-cli/tests/test_clan_uri.py +++ b/pkgs/clan-cli/tests/test_clan_uri.py @@ -33,9 +33,9 @@ def test_is_remote() -> None: assert False -def remote_with_clanparams() -> None: +def test_remote_with_clanparams() -> None: # Create a ClanURI object from a remote URI with parameters - uri = ClanURI("clan://https://example.com?flake_attr=defaultVM") + uri = ClanURI("clan://https://example.com") assert uri.params.flake_attr == "defaultVM" @@ -46,17 +46,13 @@ def remote_with_clanparams() -> None: assert False -def remote_with_all_params() -> None: +def test_remote_with_all_params() -> None: # Create a ClanURI object from a remote URI with parameters - uri = ClanURI( - "clan://https://example.com?flake_attr=defaultVM&machine=vm1&password=1234" - ) - - assert uri.params.flake_attr == "defaultVM" - assert uri.params.machine == "vm1" + uri = ClanURI("clan://https://example.com?flake_attr=myVM&password=1234") + assert uri.params.flake_attr == "myVM" match uri.scheme: case ClanScheme.HTTPS.value(url): - assert url == "https://example.com&password=1234" # type: ignore + assert url == "https://example.com?password=1234" # type: ignore case _: assert False diff --git a/pkgs/clan-vm-manager/clan_vm_manager/models.py b/pkgs/clan-vm-manager/clan_vm_manager/models.py index 6dc21519..6c493814 100644 --- a/pkgs/clan-vm-manager/clan_vm_manager/models.py +++ b/pkgs/clan-vm-manager/clan_vm_manager/models.py @@ -68,9 +68,11 @@ class VMBase: vm = asyncio.run( vms.run.inspect_vm(flake_url=self._path, flake_attr="defaultVM") ) - task = vms.run.run_vm(vm) - for line in task.log_lines(): - print(line, end="") + vms.run.run_vm(vm) + + +# for line in task.log_lines(): +# print(line, end="") @dataclass(frozen=True)