"""OAuth2 provider support.

Config shape (in ~/.hb.yaml)::

    oauth:
      my-gitea:                          # route slug → /login/oauth/my-gitea
        type: gitea                      # "gitea" | "github" | "nextcloud"
                                         # omit type to default to "gitea"
        url: https://git.example.com     # required for gitea and nextcloud
        client_id: <client-id>
        client_secret: <client-secret>
        label: "Work Gitea"              # optional display name on login button
        logo: https://example.com/logo.png   # optional logo URL
      github:
        type: github
        client_id: <client-id>
        client_secret: <client-secret>
      nextcloud:
        type: nextcloud
        url: https://cloud.example.com
        client_id: <client-id>
        client_secret: <client-secret>

Register the OAuth app with each provider and set its redirect URI to::

    https://<host>/login/oauth/<slug>/callback

where ``<slug>`` is the provider's key under ``oauth:`` (e.g. ``my-gitea``).
"""

import asyncio
import logging
import secrets
import time
import urllib.parse
from dataclasses import dataclass

import aiohttp

logger = logging.getLogger(__name__)

STATE_TTL = 600  # 10 minutes

# state_token -> expiry timestamp (time.time() seconds).
# Module-level, so it lives in this process's memory only.
_states: dict[str, float] = {}


def make_state() -> str:
    """Generate a CSRF state token, store it with TTL, and return it.

    Expired tokens are purged first so the store does not accumulate
    abandoned login attempts.
    """
    _purge_states()
    token = secrets.token_hex(32)
    _states[token] = time.time() + STATE_TTL
    return token


def validate_state(state: str) -> bool:
    """Return True if *state* is known and unexpired; always removes it.

    Removing the token on every lookup makes each state single-use.
    """
    expiry = _states.pop(state, None)
    if expiry is None:
        return False
    return time.time() < expiry


def _purge_states() -> None:
    """Remove all expired CSRF state tokens from the in-memory store."""
    now = time.time()
    # Same expiry rule as validate_state: a token is dead once now >= expiry.
    expired = [k for k, exp in list(_states.items()) if exp <= now]
    for k in expired:
        del _states[k]


class OAuthError(Exception):
    """Raised when the OAuth2 flow fails for any reason."""


# Static description of each supported provider type:
#   *_url_tmpl        -- formatted with ``url=`` (the configured base URL, trailing
#                        "/" stripped) by get_providers
#   scope             -- scope requested at authorization; "" means no scope param
#   field_map         -- normalised key -> key in the provider's profile JSON
#                        (None = provider has no such field)
#   profile_data_path -- keys to descend into the profile response before reading
#                        the field_map keys
#   requires_url      -- entries without ``url`` are skipped when True
#   default_label     -- login-button label when the entry sets none
PROVIDER_DEFS: dict = {
    "gitea": {
        "authorize_url_tmpl": "{url}/login/oauth/authorize",
        "token_url_tmpl": "{url}/login/oauth/access_token",
        "profile_url_tmpl": "{url}/api/v1/user",
        "scope": "user:email",
        "field_map": {"username": "login", "full_name": "full_name", "avatar": "avatar_url"},
        "profile_data_path": [],
        "requires_url": True,
        "default_label": "Gitea",
    },
    "github": {
        "authorize_url_tmpl": "https://github.com/login/oauth/authorize",
        "token_url_tmpl": "https://github.com/login/oauth/access_token",
        "profile_url_tmpl": "https://api.github.com/user",
        "scope": "read:user",
        "field_map": {"username": "login", "full_name": "name", "avatar": "avatar_url"},
        "profile_data_path": [],
        "requires_url": False,
        "default_label": "GitHub",
    },
    "nextcloud": {
        "authorize_url_tmpl": "{url}/apps/oauth2/authorize",
        "token_url_tmpl": "{url}/apps/oauth2/api/v1/token",
        "profile_url_tmpl": "{url}/ocs/v2.php/cloud/user?format=json",
        "scope": "",
        "field_map": {"username": "id", "full_name": "display-name", "avatar": None},
        "profile_data_path": ["ocs", "data"],
        "requires_url": True,
        "default_label": "Nextcloud",
    },
}


@dataclass
class ResolvedProvider:
    """A fully resolved OAuth2 provider instance, ready to use."""

    name: str  # config key; also the route slug
    type: str  # key into PROVIDER_DEFS
    label: str  # login-button text
    logo: str  # logo URL, "" when not configured
    authorize_url: str
    token_url: str
    profile_url: str
    scope: str  # "" means no scope parameter is sent
    client_id: str
    client_secret: str
    field_map: dict  # per-instance copy of PROVIDER_DEFS[type]["field_map"]
    profile_data_path: list  # per-instance copy of the profile descent path


def _entry_str(entry: dict, key: str) -> str:
    """Return ``entry[key]`` as a string, or "" when it is absent or null.

    In YAML, ``key:`` with no value loads as ``None``, and bare values may
    load as numbers. Normalising here stops later string handling (e.g.
    ``.rstrip``) from crashing on such config.
    """
    value = entry.get(key)
    return "" if value is None else str(value)


def get_providers(config: dict) -> list[ResolvedProvider]:
    """Return a ResolvedProvider for every valid entry in config['oauth'].

    Entries that are not mappings, have missing required fields, or have
    unknown types are skipped with a warning log. Order follows config
    declaration order. A missing or non-mapping ``oauth`` section yields
    an empty list.
    """
    result: list[ResolvedProvider] = []
    oauth_cfg = config.get("oauth", {})
    if not isinstance(oauth_cfg, dict):
        return result

    for name, entry in oauth_cfg.items():
        if not isinstance(entry, dict):
            logger.warning("OAuth: %r is not a mapping, skipping", name)
            continue

        # A present-but-empty ``type:`` is treated the same as an omitted one.
        provider_type = entry.get("type") or "gitea"
        # The isinstance guard avoids a TypeError for unhashable YAML values.
        defn = PROVIDER_DEFS.get(provider_type) if isinstance(provider_type, str) else None
        if defn is None:
            logger.warning("OAuth: unknown provider type %r for %r, skipping", provider_type, name)
            continue

        client_id = _entry_str(entry, "client_id")
        client_secret = _entry_str(entry, "client_secret")
        if not client_id or not client_secret:
            logger.warning("OAuth: %r missing client_id or client_secret, skipping", name)
            continue

        url = _entry_str(entry, "url").rstrip("/")
        if defn["requires_url"] and not url:
            logger.warning("OAuth: %r requires url but none configured, skipping", name)
            continue

        result.append(ResolvedProvider(
            name=name,
            type=provider_type,
            label=_entry_str(entry, "label") or defn["default_label"],
            logo=_entry_str(entry, "logo"),
            authorize_url=defn["authorize_url_tmpl"].format(url=url),
            token_url=defn["token_url_tmpl"].format(url=url),
            profile_url=defn["profile_url_tmpl"].format(url=url),
            scope=defn["scope"],
            client_id=client_id,
            client_secret=client_secret,
            # Copies, so mutating one provider never alters PROVIDER_DEFS.
            field_map=dict(defn["field_map"]),
            profile_data_path=list(defn["profile_data_path"]),
        ))
    return result


def is_enabled(config: dict) -> bool:
    """Return True when at least one OAuth provider is fully configured."""
    return bool(get_providers(config))


def build_auth_url(provider: ResolvedProvider, state: str, redirect_uri: str) -> str:
    """Return the provider's OAuth2 authorization URL to redirect the browser to.

    The ``scope`` parameter is included only when the provider defines one.
    """
    params: dict = {
        "client_id": provider.client_id,
        "redirect_uri": redirect_uri,
        "response_type": "code",
        "state": state,
    }
    if provider.scope:
        params["scope"] = provider.scope
    return f"{provider.authorize_url}?{urllib.parse.urlencode(params)}"


async def exchange_code(provider: ResolvedProvider, code: str, redirect_uri: str) -> str:
    """Exchange an authorization *code* for an access token.

    Returns the access token string.
    Raises OAuthError on any failure: non-200 status, network error, timeout,
    a body that is not a JSON object, or a response without ``access_token``.
    """
    payload = {
        "client_id": provider.client_id,
        "client_secret": provider.client_secret,
        "code": code,
        "grant_type": "authorization_code",
        "redirect_uri": redirect_uri,
    }
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            # RFC 6749 §4.1.3: token request parameters are sent
            # application/x-www-form-urlencoded, which every provider accepts.
            async with session.post(
                provider.token_url,
                data=payload,
                headers={"Accept": "application/json"},
            ) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise OAuthError(f"Token exchange failed ({resp.status}): {text}")
                data = await resp.json()
    except aiohttp.ClientError as exc:
        raise OAuthError(f"Token exchange network error: {exc}") from exc
    except asyncio.TimeoutError as exc:
        # ClientTimeout(total=...) expiry is not an aiohttp.ClientError.
        raise OAuthError("Token exchange timed out") from exc
    except ValueError as exc:
        # Malformed JSON body (json.JSONDecodeError is a ValueError).
        raise OAuthError(f"Token exchange returned invalid JSON: {exc}") from exc

    # Some providers (e.g. GitHub) report errors as a 200 with an "error" body.
    token = data.get("access_token") if isinstance(data, dict) else None
    if not token:
        raise OAuthError(f"No access_token in response: {data}")
    return token


async def fetch_user(provider: ResolvedProvider, token: str) -> dict:
    """Fetch the authenticated user's profile from the provider.

    Returns a dict with keys: login, full_name, avatar_url. Each value is ""
    when the provider omits the field or sends null.
    Raises OAuthError on any failure, including timeouts and unexpected
    response shapes.
    """
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(
                provider.profile_url,
                headers={
                    "Authorization": f"Bearer {token}",
                    "Accept": "application/json",
                },
            ) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise OAuthError(f"User fetch failed ({resp.status}): {text}")
                data = await resp.json()
    except aiohttp.ClientError as exc:
        raise OAuthError(f"User fetch network error: {exc}") from exc
    except asyncio.TimeoutError as exc:
        raise OAuthError("User fetch timed out") from exc
    except ValueError as exc:
        raise OAuthError(f"User fetch returned invalid JSON: {exc}") from exc

    # Descend to the object holding the user fields (Nextcloud wraps it in
    # ["ocs"]["data"]); stop as soon as the shape is not a mapping.
    for key in provider.profile_data_path:
        if not isinstance(data, dict):
            break
        data = data.get(key, {})
    if not isinstance(data, dict):
        raise OAuthError(f"Unexpected profile response structure from {provider.type}")

    fmap = provider.field_map
    avatar_field = fmap.get("avatar")
    # ``or ""`` also maps JSON null (e.g. an unset GitHub "name") to "".
    return {
        "login": data.get(fmap["username"]) or "",
        "full_name": data.get(fmap["full_name"]) or "",
        "avatar_url": (data.get(avatar_field) or "") if avatar_field else "",
    }