only enable corsmiddleware when in dev mode and allow all origins
All checks were successful
checks-impure / test (pull_request) Successful in 1m37s
checks / test (pull_request) Successful in 2m48s

This commit is contained in:
Jörg Thalheim 2023-11-14 15:32:03 +01:00
parent b2bbddd1f9
commit 18627baa9c
No known key found for this signature in database
2 changed files with 36 additions and 24 deletions

View File

@ -1,4 +1,6 @@
import logging
import os
from enum import Enum
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -11,22 +13,43 @@ from .error_handlers import clan_error_handler
from .routers import clan_modules, flake, health, machines, root, vms
from .tags import tags_metadata
origins = [
"http://localhost:3000",
]
# Logging setup
log = logging.getLogger(__name__)
class EnvType(Enum):
production = "production"
development = "development"
@staticmethod
def from_environment() -> "EnvType":
t = os.environ.get("CLAN_WEBUI_ENV", "production")
try:
return EnvType[t]
except KeyError:
log.warning(f"Invalid environment type: {t}, fallback to production")
return EnvType.production
def is_production(self) -> bool:
return self == EnvType.production
def is_development(self) -> bool:
return self == EnvType.development
def setup_app() -> FastAPI:
env = EnvType.from_environment()
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
if env.is_development():
# Allow CORS in development mode for nextjs dev server
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(clan_modules.router)
app.include_router(flake.router)
app.include_router(health.router)

View File

@ -1,5 +1,6 @@
import argparse
import logging
import os
import shutil
import subprocess
import time
@ -76,6 +77,8 @@ def spawn_node_dev_server(host: IPvAnyAddress, port: int) -> Iterator[None]:
def start_server(args: argparse.Namespace) -> None:
os.environ["CLAN_WEBUI_ENV"] = "development" if args.dev else "production"
with ExitStack() as stack:
headers: list[tuple[str, str]] = []
if args.dev:
@ -85,20 +88,6 @@ def start_server(args: argparse.Namespace) -> None:
host = args.dev_host
if ":" in host:
host = f"[{host}]"
headers = [
# (
# "Access-Control-Allow-Origin",
# f"http://{host}:{args.dev_port}",
# ),
# (
# "Access-Control-Allow-Methods",
# "DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT"
# ),
# (
# "Allow",
# "DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT"
# )
]
else:
base_url = f"http://{args.host}:{args.port}"