"""Settings API routes: read and partially update the current user's settings."""

from typing import Optional

from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from ..auth_utils import get_current_user
from ..database import get_db
from ..models import User, UserSettings
from ..services import ytdlp

router = APIRouter()

# Values accepted for ``cookies_browser`` (the empty string is accepted too).
VALID_BROWSERS = {"", "chrome", "chromium", "firefox", "brave", "edge", "opera", "safari"}
# Region codes accepted inside the comma-separated ``discovery_regions`` value.
VALID_REGIONS = {"US", "SE", "GB", "DE", "JP", "FR", "CA", "AU", "BR", "IN", "KR", "MX"}

# PATCH fields that are copied onto the settings row unchanged whenever they
# are supplied; they need no extra validation and trigger no side effects.
_PLAIN_FIELDS = (
    "hide_watched_from_feed",
    "mark_watched_at_percent",
    "auto_download_on_sync",
    "theater_mode",
    "calm_mode",
    "hide_subscriber_counts",
    "autoplay_enabled",
    "feed_weight_recency",
    "feed_weight_affinity",
    "feed_weight_channel",
)


class SettingsOut(BaseModel):
    """Response schema for a user's settings.

    ``from_attributes`` is enabled so the endpoints can return a
    ``UserSettings`` object directly and have it serialized by attribute.
    """

    preferred_quality: str
    max_concurrent_downloads: int
    hide_watched_from_feed: bool
    mark_watched_at_percent: int
    auto_download_on_sync: bool
    cookies_browser: str = ""
    cookies_file: str = ""
    theater_mode: bool = False
    discovery_regions: str = "US,SE"
    calm_mode: bool = False
    hide_subscriber_counts: bool = False
    autoplay_enabled: bool = False
    feed_weight_recency: float = 5.0
    feed_weight_affinity: float = 5.0
    feed_weight_channel: float = 5.0

    model_config = {"from_attributes": True}


class SettingsPatch(BaseModel):
    """Request schema for a partial settings update.

    Every field is optional; a field left as ``None`` is not modified by
    ``update_settings``. Numeric ranges are enforced here via ``Field``.
    """

    preferred_quality: Optional[str] = None
    max_concurrent_downloads: Optional[int] = Field(None, ge=1, le=5)
    hide_watched_from_feed: Optional[bool] = None
    mark_watched_at_percent: Optional[int] = Field(None, ge=50, le=100)
    auto_download_on_sync: Optional[bool] = None
    cookies_browser: Optional[str] = None
    cookies_file: Optional[str] = None
    theater_mode: Optional[bool] = None
    discovery_regions: Optional[str] = None
    calm_mode: Optional[bool] = None
    hide_subscriber_counts: Optional[bool] = None
    autoplay_enabled: Optional[bool] = None
    feed_weight_recency: Optional[float] = Field(None, ge=0.0, le=10.0)
    feed_weight_affinity: Optional[float] = Field(None, ge=0.0, le=10.0)
    feed_weight_channel: Optional[float] = Field(None, ge=0.0, le=10.0)


def _get_or_create(db: Session, user_id: int) -> UserSettings:
    """Return the ``UserSettings`` row for ``user_id``, creating it if absent.

    A newly created row is committed and refreshed before being returned.
    """
    s = db.query(UserSettings).filter_by(user_id=user_id).first()
    if not s:
        s = UserSettings(user_id=user_id)
        db.add(s)
        db.commit()
        db.refresh(s)
    return s


@router.get("", response_model=SettingsOut)
def get_settings(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the current user's settings, creating the row on first access."""
    return _get_or_create(db, current_user.id)


@router.patch("", response_model=SettingsOut)
def update_settings(
    body: SettingsPatch,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Apply a partial update to the current user's settings.

    Only fields that are not ``None`` in ``body`` are considered. Values that
    fail the membership checks below (unknown quality, unknown browser, no
    recognised region code) are silently ignored rather than rejected, so the
    response may show the previous value for those fields.

    Changes to ``max_concurrent_downloads``, ``cookies_browser`` and
    ``cookies_file`` are also forwarded to the ``ytdlp`` service.

    Returns the settings row after commit and refresh.
    """
    s = _get_or_create(db, current_user.id)

    # Fields with no extra validation or side effects.
    for name in _PLAIN_FIELDS:
        value = getattr(body, name)
        if value is not None:
            setattr(s, name, value)

    if body.preferred_quality is not None and body.preferred_quality in ytdlp.QUALITY_FORMATS:
        s.preferred_quality = body.preferred_quality

    if body.max_concurrent_downloads is not None:
        s.max_concurrent_downloads = body.max_concurrent_downloads
        ytdlp.set_max_concurrent(body.max_concurrent_downloads)

    if body.cookies_browser is not None and body.cookies_browser in VALID_BROWSERS:
        s.cookies_browser = body.cookies_browser
        ytdlp.set_cookies_browser(body.cookies_browser)

    if body.cookies_file is not None:
        # Strip once and use the same value for both the stored setting and
        # the ytdlp configuration so the two cannot diverge on whitespace.
        cookies_file = body.cookies_file.strip()
        s.cookies_file = cookies_file
        ytdlp.set_cookies_file(cookies_file)

    if body.discovery_regions is not None:
        # Validate: comma-separated list of known region codes. Unknown codes
        # are dropped; if nothing valid remains the stored value is kept.
        codes = [r.strip().upper() for r in body.discovery_regions.split(",") if r.strip()]
        valid = [c for c in codes if c in VALID_REGIONS]
        if valid:
            s.discovery_regions = ",".join(valid)

    db.commit()
    db.refresh(s)
    return s