a4a6c1e3d9
Move field-extraction inside the try/except in fetch_user so non-dict responses from providers with empty profile_data_path (Gitea, GitHub) raise OAuthError instead of an uncaught AttributeError. Apply html.escape() to provider name, label, and logo URL in the login page. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
255 lines
8.8 KiB
Python
255 lines
8.8 KiB
Python
"""OAuth2 provider support.

Config shape (in ~/.hb.yaml):

    oauth:
      my-gitea:                          # route slug → /login/oauth/my-gitea
        type: gitea                      # "gitea" | "github" | "nextcloud"
                                         # omit type to default to "gitea"
        url: https://git.example.com     # required for gitea and nextcloud
        client_id: <client-id>
        client_secret: <client-secret>
        label: "Work Gitea"              # optional display name on login button
        logo: https://example.com/logo.png   # optional logo URL

      github:
        type: github
        client_id: <client-id>
        client_secret: <client-secret>

      nextcloud:
        type: nextcloud
        url: https://cloud.example.com
        client_id: <client-id>
        client_secret: <client-secret>

Register the OAuth app with each provider and set the redirect URI to:

    https://<hbd-host>/login/oauth/<name>/callback

where <name> is the entry's key under ``oauth`` (its route slug, e.g.
``my-gitea`` above).
"""
|
|
|
|
import asyncio
import logging
import secrets
import time
import urllib.parse
from dataclasses import dataclass

import aiohttp
|
|
|
|
logger = logging.getLogger(__name__)

# Lifetime of an issued CSRF state token, in seconds (10 minutes).
STATE_TTL = 10 * 60

# Outstanding CSRF state tokens: token -> time.time() deadline after which
# the token is rejected. Lives in this process's memory only.
# NOTE(review): presumably fine for a single-process server; with several
# worker processes a state issued by one worker would not validate on
# another — confirm the deployment model.
_states: dict[str, float] = dict()
|
|
|
|
|
|
def make_state() -> str:
    """Issue a new CSRF state token for an OAuth2 authorization request.

    Expired tokens are purged first. Then a random 64-hex-character token is
    recorded in the in-memory store with a deadline of now + STATE_TTL.

    Returns:
        The newly issued token.
    """
    _purge_states()
    new_token = secrets.token_hex(32)
    deadline = time.time() + STATE_TTL
    _states[new_token] = deadline
    return new_token
|
|
|
|
|
|
def validate_state(state: str) -> bool:
    """Consume *state* and report whether it was a live CSRF token.

    The token is removed from the store whether or not it turns out to be
    valid, so each token can be redeemed at most once.

    Returns:
        True only if *state* was present in the store and its deadline has
        not yet passed; False for unknown or expired tokens.
    """
    deadline = _states.pop(state, None)
    return deadline is not None and time.time() < deadline
|
|
|
|
|
|
def _purge_states() -> None:
    """Drop every expired CSRF state token from the in-memory store.

    A token counts as expired once ``time.time()`` reaches its deadline.
    This is the complement of the ``now < deadline`` validity test in
    validate_state, so a token validate_state would reject is never kept.
    """
    now = time.time()
    # Collect first, then delete: a dict must not change size while it is
    # being iterated. The comprehension finishes before any deletion, so no
    # extra list() copy of the items is needed.
    expired = [token for token, deadline in _states.items() if deadline <= now]
    for token in expired:
        del _states[token]
|
|
|
|
|
|
class OAuthError(Exception):
    """Signal that some step of the OAuth2 login flow did not succeed.

    The exception message names the failing step (token exchange, profile
    fetch, ...) and carries any provider response text useful for logging.
    """
|
|
|
|
|
|
# Static template for each supported provider type, keyed by an entry's
# ``type`` setting. Keys of every template:
#   *_url_tmpl         endpoint URLs; "{url}" is replaced by the entry's
#                      configured base URL (GitHub's have no placeholder)
#   scope              OAuth2 scope to request ("" means omit the parameter)
#   field_map          normalized field -> key in the profile JSON
#                      (None means the provider has no such field)
#   profile_data_path  keys to descend through before reading fields
#   requires_url       whether the config entry must set ``url``
#   default_label      login-button label when the entry sets none
PROVIDER_DEFS: dict = {
    "gitea": dict(
        authorize_url_tmpl="{url}/login/oauth/authorize",
        token_url_tmpl="{url}/login/oauth/access_token",
        profile_url_tmpl="{url}/api/v1/user",
        scope="user:email",
        field_map=dict(username="login", full_name="full_name", avatar="avatar_url"),
        profile_data_path=[],
        requires_url=True,
        default_label="Gitea",
    ),
    "github": dict(
        authorize_url_tmpl="https://github.com/login/oauth/authorize",
        token_url_tmpl="https://github.com/login/oauth/access_token",
        profile_url_tmpl="https://api.github.com/user",
        scope="read:user",
        field_map=dict(username="login", full_name="name", avatar="avatar_url"),
        profile_data_path=[],
        requires_url=False,
        default_label="GitHub",
    ),
    "nextcloud": dict(
        authorize_url_tmpl="{url}/apps/oauth2/authorize",
        token_url_tmpl="{url}/apps/oauth2/api/v1/token",
        profile_url_tmpl="{url}/ocs/v2.php/cloud/user?format=json",
        scope="",
        field_map=dict(username="id", full_name="display-name", avatar=None),
        profile_data_path=["ocs", "data"],
        requires_url=True,
        default_label="Nextcloud",
    ),
}
|
|
|
|
|
|
@dataclass
class ResolvedProvider:
    """A fully resolved OAuth2 provider instance, ready to use.

    Built by get_providers from one ``oauth`` config entry merged with the
    matching PROVIDER_DEFS template; the endpoint URLs are already expanded
    with the entry's base URL.
    """

    name: str  # config key under "oauth"; also the route slug (/login/oauth/<name>)
    type: str  # provider type key into PROVIDER_DEFS ("gitea", "github", "nextcloud")
    label: str  # login-button text: config "label" or the type's default_label
    logo: str  # optional logo URL from config; "" when the key is absent
    authorize_url: str  # authorization endpoint the browser is redirected to
    token_url: str  # endpoint used to exchange the authorization code for a token
    profile_url: str  # endpoint returning the authenticated user's profile
    scope: str  # OAuth2 scope to request; "" means no scope parameter is sent
    client_id: str
    client_secret: str
    field_map: dict  # normalized field -> key in the profile JSON (value may be None)
    profile_data_path: list  # keys to descend through in the profile JSON before mapping fields
|
|
|
|
|
|
def get_providers(config: dict) -> list[ResolvedProvider]:
    """Return a ResolvedProvider for every valid entry in config['oauth'].

    Entries with missing required fields or unknown types are skipped with
    a warning log. Order follows config declaration order.

    String settings are read leniently. A key that is absent or explicitly
    null (e.g. ``url:`` left empty in YAML) is treated as "". Other scalars
    are converted with str(). A sloppy entry is therefore skipped with a
    warning instead of raising from ``.rstrip()``. It also never yields a
    ``None`` label or logo.

    Args:
        config: Parsed application config; only its "oauth" mapping is read.

    Returns:
        The resolved providers, in the order they appear under "oauth".
    """
    result = []
    oauth_cfg = config.get("oauth", {})
    if not isinstance(oauth_cfg, dict):
        return result
    for name, entry in oauth_cfg.items():
        if not isinstance(entry, dict):
            continue
        provider_type = entry.get("type", "gitea")
        # Only strings can name a provider type; checking first also avoids a
        # TypeError from dict.get on an unhashable value such as a YAML list.
        defn = PROVIDER_DEFS.get(provider_type) if isinstance(provider_type, str) else None
        if defn is None:
            logger.warning("OAuth: unknown provider type %r for %r, skipping", provider_type, name)
            continue
        client_id = _config_str(entry, "client_id")
        client_secret = _config_str(entry, "client_secret")
        if not client_id or not client_secret:
            logger.warning("OAuth: %r missing client_id or client_secret, skipping", name)
            continue
        url = _config_str(entry, "url").rstrip("/")
        if defn["requires_url"] and not url:
            logger.warning("OAuth: %r requires url but none configured, skipping", name)
            continue
        label = _config_str(entry, "label") or defn["default_label"]
        logo = _config_str(entry, "logo")
        result.append(ResolvedProvider(
            name=name,
            type=provider_type,
            label=label,
            logo=logo,
            authorize_url=defn["authorize_url_tmpl"].format(url=url),
            token_url=defn["token_url_tmpl"].format(url=url),
            profile_url=defn["profile_url_tmpl"].format(url=url),
            scope=defn["scope"],
            client_id=client_id,
            client_secret=client_secret,
            # Copy the template containers so a provider instance can never
            # mutate the shared PROVIDER_DEFS entry.
            field_map=dict(defn["field_map"]),
            profile_data_path=list(defn["profile_data_path"]),
        ))
    return result


def _config_str(entry: dict, key: str) -> str:
    """Return ``entry[key]`` as a string; a missing or null value yields ""."""
    value = entry.get(key)
    return "" if value is None else str(value)
|
|
|
|
|
|
def is_enabled(config: dict) -> bool:
    """Tell whether OAuth login is available for *config*.

    True when get_providers resolves at least one usable provider. This
    re-runs the full resolution, including any warnings it logs for skipped
    entries.
    """
    providers = get_providers(config)
    return len(providers) > 0
|
|
|
|
|
|
def build_auth_url(provider: ResolvedProvider, state: str, redirect_uri: str) -> str:
    """Build the authorization-endpoint URL that starts the OAuth2 flow.

    The query string carries client_id, redirect_uri, response_type=code and
    the CSRF *state*, in that order. ``scope`` is appended last, and only
    when the provider defines a non-empty one.

    Returns:
        ``provider.authorize_url`` followed by "?" and the URL-encoded query.
    """
    query = [
        ("client_id", provider.client_id),
        ("redirect_uri", redirect_uri),
        ("response_type", "code"),
        ("state", state),
    ]
    if provider.scope:
        query.append(("scope", provider.scope))
    return provider.authorize_url + "?" + urllib.parse.urlencode(query)
|
|
|
|
|
|
async def exchange_code(provider: ResolvedProvider, code: str, redirect_uri: str) -> str:
    """Exchange an authorization *code* for an access token.

    Returns the access token string. Raises OAuthError on any failure.

    The request body is form-encoded, as RFC 6749 §4.1.3 specifies for the
    token endpoint. ``Accept: application/json`` asks the provider to reply
    in JSON (GitHub otherwise answers form-encoded).

    Args:
        provider: Supplies token_url and the client credentials.
        code: Authorization code received on the OAuth callback.
        redirect_uri: Per RFC 6749 this must match the redirect_uri used in
            the authorization request.

    Raises:
        OAuthError: On any of the following:
            - a non-200 status;
            - a timeout or other network error;
            - a body that is not valid JSON or not a JSON object;
            - a response without ``access_token``.
    """
    payload = {
        "client_id": provider.client_id,
        "client_secret": provider.client_secret,
        "code": code,
        "grant_type": "authorization_code",
        "redirect_uri": redirect_uri,
    }
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(
                provider.token_url,
                data=payload,  # application/x-www-form-urlencoded
                headers={"Accept": "application/json"},
            ) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise OAuthError(f"Token exchange failed ({resp.status}): {text}")
                data = await resp.json()
    except asyncio.TimeoutError as exc:
        # Expiry of ClientTimeout(total=...) surfaces as asyncio.TimeoutError,
        # which is not an aiohttp.ClientError.
        raise OAuthError("Token exchange timed out") from exc
    except aiohttp.ClientError as exc:
        raise OAuthError(f"Token exchange network error: {exc}") from exc
    except ValueError as exc:
        # json.JSONDecodeError (malformed JSON) or an undecodable body.
        raise OAuthError(f"Token exchange returned an unparseable response: {exc}") from exc

    # Valid JSON that is not an object (list, string, null) has no .get().
    token = data.get("access_token") if isinstance(data, dict) else None
    if not token:
        raise OAuthError(f"No access_token in response: {data}")
    return token
|
|
|
|
|
|
async def fetch_user(provider: ResolvedProvider, token: str) -> dict:
    """Fetch the authenticated user's profile from the provider.

    Returns a dict with keys: login, full_name, avatar_url.
    Raises OAuthError on any failure.

    ``login`` is always a non-empty string. ``full_name`` and ``avatar_url``
    are strings, "" when the provider omits the field or sends null.

    Raises:
        OAuthError: On any of the following:
            - a non-200 status;
            - a timeout or other network error;
            - an unparseable body;
            - an unexpected JSON structure;
            - a missing or empty username.
    """
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(
                provider.profile_url,
                headers={
                    "Authorization": f"Bearer {token}",
                    "Accept": "application/json",
                },
            ) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise OAuthError(f"User fetch failed ({resp.status}): {text}")
                data = await resp.json()
    except asyncio.TimeoutError as exc:
        # Expiry of ClientTimeout(total=...) surfaces as asyncio.TimeoutError,
        # which is not an aiohttp.ClientError.
        raise OAuthError("User fetch timed out") from exc
    except aiohttp.ClientError as exc:
        raise OAuthError(f"User fetch network error: {exc}") from exc
    except ValueError as exc:
        # json.JSONDecodeError (malformed JSON) or an undecodable body.
        raise OAuthError(f"User fetch returned an unparseable response: {exc}") from exc

    return _extract_profile(provider, data)


def _extract_profile(provider: ResolvedProvider, data: object) -> dict:
    """Normalize a decoded profile response into login/full_name/avatar_url.

    Descends through ``provider.profile_data_path``, then reads fields via
    ``provider.field_map``. Raises OAuthError if the value reached is not a
    JSON object, or if the username field is missing or empty.
    """
    for key in provider.profile_data_path:
        data = data.get(key) if isinstance(data, dict) else None
    if not isinstance(data, dict):
        raise OAuthError(f"Unexpected profile response structure from {provider.type}")

    fields = provider.field_map
    login = data.get(fields["username"])
    if not login or not isinstance(login, str):
        # Never hand the caller an empty identity as if it were a real login.
        raise OAuthError(f"No username in profile response from {provider.type}")

    avatar_field = fields.get("avatar")
    # `or ""` rather than a .get() default: providers may send JSON null for
    # unset fields (GitHub documents the user "name" as nullable).
    return {
        "login": login,
        "full_name": data.get(fields["full_name"]) or "",
        "avatar_url": (data.get(avatar_field) or "") if avatar_field else "",
    }
|