Files
andreas a4a6c1e3d9 fix: extend fetch_user error guard; escape HTML in login page
Move field-extraction inside the try/except in fetch_user so non-dict
responses from providers with empty profile_data_path (Gitea, GitHub)
raise OAuthError instead of an uncaught AttributeError. Apply
html.escape() to provider name, label, and logo URL in the login page.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-09 08:57:25 -04:00

255 lines
8.8 KiB
Python

"""OAuth2 provider support.
Config shape (in ~/.hb.yaml):
oauth:
my-gitea: # route slug → /login/oauth/my-gitea
type: gitea # "gitea" | "github" | "nextcloud"
# omit type to default to "gitea"
url: https://git.example.com # required for gitea and nextcloud
client_id: <client-id>
client_secret: <client-secret>
label: "Work Gitea" # optional display name on login button
logo: https://example.com/logo.png # optional logo URL
github:
type: github
client_id: <client-id>
client_secret: <client-secret>
nextcloud:
type: nextcloud
url: https://cloud.example.com
client_id: <client-id>
client_secret: <client-secret>
Register the OAuth app with each provider and set the redirect URI to:
https://<hbd-host>/login/oauth/<name>/callback
"""
import asyncio
import logging
import secrets
import time
import urllib.parse
from dataclasses import dataclass

import aiohttp
logger = logging.getLogger(__name__)

# How long an issued CSRF state token stays valid, in seconds (10 minutes).
STATE_TTL = 10 * 60

# Outstanding state tokens: token -> absolute expiry as a time.time() value.
# Being a module-level dict, it lives only inside this process.
_states: dict[str, float] = {}
def make_state() -> str:
    """Issue a fresh one-time CSRF ``state`` value for an OAuth2 redirect.

    Stale tokens are pruned first; the new 64-hex-character token is then
    recorded with an expiry ``STATE_TTL`` seconds in the future.
    """
    _purge_states()
    fresh = secrets.token_hex(32)
    deadline = time.time() + STATE_TTL
    _states[fresh] = deadline
    return fresh
def validate_state(state: str) -> bool:
    """Check and consume a CSRF ``state`` value returned by the provider.

    The token is removed from the store whether or not it is valid, so each
    token can be used at most once. Returns True only when the token was
    issued by make_state() and has not yet expired.
    """
    deadline = _states.pop(state, None)
    return deadline is not None and time.time() < deadline
def _purge_states() -> None:
    """Evict every stored CSRF state token whose expiry time has passed."""
    cutoff = time.time()
    # Collect first, then delete: the dict must not change size while iterated.
    stale = [tok for tok, deadline in _states.items() if deadline < cutoff]
    for tok in stale:
        _states.pop(tok)
class OAuthError(Exception):
    """Signals that some step of the OAuth2 login flow did not succeed."""
# Built-in provider templates, keyed by the ``type`` value accepted in config.
# Per-type keys:
#   authorize_url_tmpl / token_url_tmpl / profile_url_tmpl
#       Endpoint URLs; ``{url}`` is filled with the configured base url
#       (trailing "/" stripped) by get_providers().
#   scope              Sent as the ``scope`` authorize parameter; omitted
#                      from the request when empty.
#   field_map          Our profile key -> field name in the provider's profile
#                      JSON; None means no field is mapped.
#   profile_data_path  Keys walked into the profile JSON before field_map is
#                      applied (empty = fields sit at the top level).
#   requires_url       Whether a config entry must supply ``url``.
#   default_label      Login-button label used when config gives no label.
PROVIDER_DEFS: dict[str, dict] = {
    "gitea": {
        "authorize_url_tmpl": "{url}/login/oauth/authorize",
        "token_url_tmpl": "{url}/login/oauth/access_token",
        "profile_url_tmpl": "{url}/api/v1/user",
        "scope": "user:email",
        "field_map": {"username": "login", "full_name": "full_name", "avatar": "avatar_url"},
        "profile_data_path": [],
        "requires_url": True,
        "default_label": "Gitea",
    },
    "github": {
        # Fixed public endpoints (no ``{url}`` placeholder), so no base url.
        "authorize_url_tmpl": "https://github.com/login/oauth/authorize",
        "token_url_tmpl": "https://github.com/login/oauth/access_token",
        "profile_url_tmpl": "https://api.github.com/user",
        "scope": "read:user",
        "field_map": {"username": "login", "full_name": "name", "avatar": "avatar_url"},
        "profile_data_path": [],
        "requires_url": False,
        "default_label": "GitHub",
    },
    "nextcloud": {
        "authorize_url_tmpl": "{url}/apps/oauth2/authorize",
        "token_url_tmpl": "{url}/apps/oauth2/api/v1/token",
        "profile_url_tmpl": "{url}/ocs/v2.php/cloud/user?format=json",
        "scope": "",
        # No avatar field is mapped, so avatar_url always comes back empty.
        "field_map": {"username": "id", "full_name": "display-name", "avatar": None},
        # Profile fields are read from response["ocs"]["data"].
        "profile_data_path": ["ocs", "data"],
        "requires_url": True,
        "default_label": "Nextcloud",
    },
}
@dataclass
class ResolvedProvider:
    """One configured OAuth2 provider with every URL and credential filled in.

    Built by get_providers() by merging a config entry with the matching
    PROVIDER_DEFS template.
    """

    # --- identity / presentation ---
    name: str            # config key; also the route slug
    type: str            # key into PROVIDER_DEFS ("gitea", "github", ...)
    label: str           # text shown on the login button
    logo: str            # logo URL, or "" when none is configured
    # --- endpoints ---
    authorize_url: str
    token_url: str
    profile_url: str
    scope: str           # "" means no scope parameter is sent
    # --- credentials ---
    client_id: str
    client_secret: str
    # --- profile parsing ---
    field_map: dict      # our key -> provider JSON field (None if unmapped)
    profile_data_path: list[str]  # keys descended before reading field_map
def get_providers(config: dict) -> list[ResolvedProvider]:
    """Return a ResolvedProvider for every valid entry in config['oauth'].

    Entries that are not mappings, have an unknown ``type``, lack
    ``client_id``/``client_secret``, or lack a required ``url`` are skipped
    with a warning log. Keys written in YAML but left empty (which load as
    None) are treated exactly like omitted keys. Order follows config
    declaration order.

    Args:
        config: Parsed configuration mapping; only its ``"oauth"`` key is read.

    Returns:
        The resolved providers; an empty list when ``oauth`` is absent or is
        not a mapping.
    """
    result = []
    oauth_cfg = config.get("oauth", {})
    if not isinstance(oauth_cfg, dict):
        return result
    for name, entry in oauth_cfg.items():
        if not isinstance(entry, dict):
            logger.warning("OAuth: entry %r is not a mapping, skipping", name)
            continue
        # `or` so an explicit-but-empty `type:` also falls back to gitea.
        provider_type = entry.get("type") or "gitea"
        # Guard the lookup: an unhashable value (e.g. a YAML list) would
        # otherwise raise TypeError instead of being reported and skipped.
        defn = PROVIDER_DEFS.get(provider_type) if isinstance(provider_type, str) else None
        if defn is None:
            logger.warning("OAuth: unknown provider type %r for %r, skipping", provider_type, name)
            continue
        client_id = entry.get("client_id")
        client_secret = entry.get("client_secret")
        if not client_id or not client_secret:
            logger.warning("OAuth: %r missing client_id or client_secret, skipping", name)
            continue
        # Empty YAML values load as None; normalise to "" before string ops so
        # `url:` / `logo:` left blank behave like omitted keys instead of
        # raising AttributeError or leaking None into str fields.
        url = str(entry.get("url") or "").rstrip("/")
        if defn["requires_url"] and not url:
            logger.warning("OAuth: %r requires url but none configured, skipping", name)
            continue
        label = str(entry.get("label") or defn["default_label"])
        logo = str(entry.get("logo") or "")
        result.append(ResolvedProvider(
            name=name,
            type=provider_type,
            label=label,
            logo=logo,
            authorize_url=defn["authorize_url_tmpl"].format(url=url),
            token_url=defn["token_url_tmpl"].format(url=url),
            profile_url=defn["profile_url_tmpl"].format(url=url),
            scope=defn["scope"],
            # YAML may parse an all-digit id as int; the fields are declared str.
            client_id=str(client_id),
            client_secret=str(client_secret),
            # Copies, so callers can't mutate the shared PROVIDER_DEFS templates.
            field_map=dict(defn["field_map"]),
            profile_data_path=list(defn["profile_data_path"]),
        ))
    return result
def is_enabled(config: dict) -> bool:
    """Report whether OAuth login should be offered.

    True as soon as get_providers() resolves at least one usable provider.
    """
    providers = get_providers(config)
    return len(providers) > 0
def build_auth_url(provider: ResolvedProvider, state: str, redirect_uri: str) -> str:
    """Build the URL the browser is redirected to in order to start OAuth2.

    The query string carries client_id, redirect_uri, response_type=code and
    the CSRF *state*, in that order; ``scope`` is appended last, and only
    when the provider defines a non-empty one.
    """
    query = [
        ("client_id", provider.client_id),
        ("redirect_uri", redirect_uri),
        ("response_type", "code"),
        ("state", state),
    ]
    if provider.scope:
        query.append(("scope", provider.scope))
    return provider.authorize_url + "?" + urllib.parse.urlencode(query)
async def exchange_code(provider: ResolvedProvider, code: str, redirect_uri: str) -> str:
    """Exchange an authorization *code* for an access token.

    The token request body is form-encoded, as RFC 6749 §4.1.3 specifies for
    the token endpoint, and ``Accept: application/json`` asks providers such
    as GitHub to answer in JSON rather than their default format.

    Args:
        provider: Provider whose token endpoint and credentials are used.
        code: Authorization code received on the OAuth callback.
        redirect_uri: The same redirect URI used when building the authorize
            URL (RFC 6749 requires the two to match).

    Returns:
        The access token string.

    Raises:
        OAuthError: on a non-200 status, a network error or timeout, a body
            that is not a JSON object, or a missing/empty ``access_token``.
    """
    payload = {
        "client_id": provider.client_id,
        "client_secret": provider.client_secret,
        "code": code,
        "grant_type": "authorization_code",
        "redirect_uri": redirect_uri,
    }
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(
                provider.token_url,
                data=payload,  # dict -> application/x-www-form-urlencoded
                headers={"Accept": "application/json"},
            ) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise OAuthError(f"Token exchange failed ({resp.status}): {text}")
                try:
                    data = await resp.json()
                except ValueError as exc:  # json.JSONDecodeError
                    raise OAuthError(f"Token exchange returned invalid JSON: {exc}") from exc
    except aiohttp.ClientError as exc:
        raise OAuthError(f"Token exchange network error: {exc}") from exc
    except asyncio.TimeoutError as exc:
        # Exceeding ClientTimeout(total=...) raises asyncio.TimeoutError,
        # which is not an aiohttp.ClientError.
        raise OAuthError("Token exchange timed out") from exc
    # A JSON array/scalar body would otherwise fail on .get() below.
    if not isinstance(data, dict):
        raise OAuthError(f"Unexpected token response structure: {data!r}")
    token = data.get("access_token")
    if not isinstance(token, str) or not token:
        raise OAuthError(f"No access_token in response: {data}")
    return token
async def fetch_user(provider: ResolvedProvider, token: str) -> dict:
    """Fetch the authenticated user's profile from the provider.

    Args:
        provider: Provider whose profile endpoint is queried.
        token: Access token from exchange_code(), sent as a Bearer token.

    Returns:
        A dict with string values under ``login``, ``full_name`` and
        ``avatar_url``. ``login`` is always non-empty; the other two are
        ``""`` when the provider omits the field or sends null.

    Raises:
        OAuthError: on a non-200 status, a network error or timeout, a body
            that is not valid JSON, an unexpected response structure, or a
            profile without a username.
    """
    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/json",
    }
    if provider.type == "nextcloud":
        # Nextcloud's OCS API documentation asks clients to send this header
        # on OCS requests; the profile URL is an OCS endpoint.
        headers["OCS-APIRequest"] = "true"
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(provider.profile_url, headers=headers) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise OAuthError(f"User fetch failed ({resp.status}): {text}")
                try:
                    data = await resp.json()
                except ValueError as exc:  # json.JSONDecodeError
                    raise OAuthError(f"User fetch returned invalid JSON: {exc}") from exc
    except aiohttp.ClientError as exc:
        raise OAuthError(f"User fetch network error: {exc}") from exc
    except asyncio.TimeoutError as exc:
        # Exceeding ClientTimeout(total=...) raises asyncio.TimeoutError,
        # which is not an aiohttp.ClientError.
        raise OAuthError("User fetch timed out") from exc
    return _extract_profile(provider, data)


def _extract_profile(provider: ResolvedProvider, data) -> dict:
    """Map a decoded profile response onto {login, full_name, avatar_url}."""
    # Descend through wrapper objects (e.g. ["ocs", "data"] for Nextcloud).
    # Explicit type checks instead of catching AttributeError, so unrelated
    # bugs are not misreported as a bad response.
    for key in provider.profile_data_path:
        data = data.get(key) if isinstance(data, dict) else None
    if not isinstance(data, dict):
        raise OAuthError(f"Unexpected profile response structure from {provider.type}")

    profile = {}
    for out_key, field in (("login", "username"), ("full_name", "full_name"), ("avatar_url", "avatar")):
        source = provider.field_map.get(field)
        value = data.get(source) if source else None
        # Providers may send null (GitHub's "name" is nullable) or non-string
        # values; always hand callers a str.
        profile[out_key] = "" if value is None else str(value)

    # An empty login must never be treated as a successful authentication.
    if not profile["login"]:
        raise OAuthError(f"No username in profile response from {provider.type}")
    return profile