From 8207cd7b5fa04aa01cc2897b82a7d308d0751c90 Mon Sep 17 00:00:00 2001 From: Andreas Wrede Date: Sat, 9 May 2026 08:29:07 -0400 Subject: [PATCH] feat: add PROVIDER_DEFS, ResolvedProvider, get_providers() to oauth.py Co-Authored-By: Claude Sonnet 4.6 --- hbd/server/oauth.py | 98 +++++++++++++++++++++++++++++++ tests/test_oauth.py | 138 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 236 insertions(+) diff --git a/hbd/server/oauth.py b/hbd/server/oauth.py index 9435dcd..fbe619b 100644 --- a/hbd/server/oauth.py +++ b/hbd/server/oauth.py @@ -18,6 +18,7 @@ import logging import secrets import time import urllib.parse +from dataclasses import dataclass import aiohttp @@ -57,6 +58,103 @@ class OAuthError(Exception): """Raised when the OAuth2 flow fails for any reason.""" +PROVIDER_DEFS: dict = { + "gitea": { + "authorize_url_tmpl": "{url}/login/oauth/authorize", + "token_url_tmpl": "{url}/login/oauth/access_token", + "profile_url_tmpl": "{url}/api/v1/user", + "scope": "user:email", + "field_map": {"username": "login", "full_name": "full_name", "avatar": "avatar_url"}, + "profile_data_path": [], + "requires_url": True, + "default_label": "Gitea", + }, + "github": { + "authorize_url_tmpl": "https://github.com/login/oauth/authorize", + "token_url_tmpl": "https://github.com/login/oauth/access_token", + "profile_url_tmpl": "https://api.github.com/user", + "scope": "read:user", + "field_map": {"username": "login", "full_name": "name", "avatar": "avatar_url"}, + "profile_data_path": [], + "requires_url": False, + "default_label": "GitHub", + }, + "nextcloud": { + "authorize_url_tmpl": "{url}/apps/oauth2/authorize", + "token_url_tmpl": "{url}/apps/oauth2/api/v1/token", + "profile_url_tmpl": "{url}/ocs/v2.php/cloud/user?format=json", + "scope": "", + "field_map": {"username": "id", "full_name": "display-name", "avatar": None}, + "profile_data_path": ["ocs", "data"], + "requires_url": True, + "default_label": "Nextcloud", + }, +} + + +@dataclass +class ResolvedProvider: + """A fully resolved OAuth2 provider instance, ready to use.""" + name: str + type: str + label: str + logo: str + authorize_url: str + token_url: str + profile_url: str + scope: str + client_id: str + client_secret: str + field_map: dict + profile_data_path: list + + +def get_providers(config: dict) -> list[ResolvedProvider]: + """Return a ResolvedProvider for every valid entry in config['oauth']. + + Entries with missing required fields or unknown types are skipped with + a warning log. Order follows config declaration order. + """ + result = [] + oauth_cfg = config.get("oauth", {}) + if not isinstance(oauth_cfg, dict): + return result + for name, entry in oauth_cfg.items(): + if not isinstance(entry, dict): + continue + provider_type = entry.get("type", "gitea") + defn = PROVIDER_DEFS.get(provider_type) + if defn is None: + logger.warning("OAuth: unknown provider type %r for %r, skipping", provider_type, name) + continue + client_id = entry.get("client_id", "") + client_secret = entry.get("client_secret", "") + if not client_id or not client_secret: + logger.warning("OAuth: %r missing client_id or client_secret, skipping", name) + continue + url = entry.get("url", "").rstrip("/") + if defn["requires_url"] and not url: + logger.warning("OAuth: %r requires url but none configured, skipping", name) + continue + label = entry.get("label") or defn["default_label"] + logo = entry.get("logo", "") + result.append(ResolvedProvider( + name=name, + type=provider_type, + label=label, + logo=logo, + authorize_url=defn["authorize_url_tmpl"].format(url=url), + token_url=defn["token_url_tmpl"].format(url=url), + profile_url=defn["profile_url_tmpl"].format(url=url), + scope=defn["scope"], + client_id=client_id, + client_secret=client_secret, + field_map=defn["field_map"], + profile_data_path=defn["profile_data_path"], + )) + return result + + def _gitea_cfg(config: dict) -> dict: """Return the gitea sub-dict or {} if absent/incomplete.""" return config.get("oauth", {}).get("gitea", {}) diff --git a/tests/test_oauth.py b/tests/test_oauth.py index 2767d25..056e736 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -322,3 +322,141 @@ async def test_full_oauth_flow_chain(): ) assert user.username == "flowuser" assert user.check_password("anything") is False + + +# --------------------------------------------------------------------------- +# get_providers() +# --------------------------------------------------------------------------- + +CFG_GITHUB = { + "oauth": { + "github": {"type": "github", "client_id": "ghid", "client_secret": "ghs"}, + } +} + +CFG_NEXTCLOUD = { + "oauth": { + "nc": { + "type": "nextcloud", + "url": "https://nc.example.com", + "client_id": "ncid", + "client_secret": "ncs", + } + } +} + +CFG_MULTI = { + "oauth": { + "mygitea": { + "type": "gitea", + "url": "https://git.example.com", + "client_id": "cid", + "client_secret": "cs", + "label": "Work Gitea", + "logo": "https://example.com/logo.png", + }, + "github": {"type": "github", "client_id": "ghid", "client_secret": "ghs"}, + "nc": { + "type": "nextcloud", + "url": "https://nc.example.com", + "client_id": "ncid", + "client_secret": "ncs", + }, + } +} + + +def test_get_providers_backward_compat_no_type_field(): + """Old config without 'type' defaults to gitea.""" + providers = oauth.get_providers(CFG_ON) + assert len(providers) == 1 + p = providers[0] + assert p.name == "gitea" + assert p.type == "gitea" + assert p.label == "Gitea" + assert p.client_id == "cid" + assert p.authorize_url == "https://git.example.com/login/oauth/authorize" + assert p.token_url == "https://git.example.com/login/oauth/access_token" + assert p.profile_url == "https://git.example.com/api/v1/user" + assert p.scope == "user:email" + assert p.profile_data_path == [] + + +def test_get_providers_multiple(): + providers = oauth.get_providers(CFG_MULTI) + assert len(providers) == 3 + names = [p.name for p in providers] + assert "mygitea" in names + assert "github" in names + assert "nc" in names + + +def test_get_providers_custom_label_and_logo(): + providers = oauth.get_providers(CFG_MULTI) + gitea = next(p for p in providers if p.name == "mygitea") + assert gitea.label == "Work Gitea" + assert gitea.logo == "https://example.com/logo.png" + + +def test_get_providers_github_default_label(): + providers = oauth.get_providers(CFG_GITHUB) + assert providers[0].label == "GitHub" + assert providers[0].logo == "" + + +def test_get_providers_github_fixed_urls(): + providers = oauth.get_providers(CFG_GITHUB) + p = providers[0] + assert p.authorize_url == "https://github.com/login/oauth/authorize" + assert p.token_url == "https://github.com/login/oauth/access_token" + assert p.profile_url == "https://api.github.com/user" + assert p.scope == "read:user" + + +def test_get_providers_nextcloud_urls_and_path(): + providers = oauth.get_providers(CFG_NEXTCLOUD) + p = providers[0] + assert p.authorize_url == "https://nc.example.com/apps/oauth2/authorize" + assert p.token_url == "https://nc.example.com/apps/oauth2/api/v1/token" + assert p.profile_url == "https://nc.example.com/ocs/v2.php/cloud/user?format=json" + assert p.profile_data_path == ["ocs", "data"] + assert p.scope == "" + + +def test_get_providers_skips_missing_client_id(): + cfg = {"oauth": {"gitea": {"url": "https://git.example.com", "client_secret": "cs"}}} + assert oauth.get_providers(cfg) == [] + + +def test_get_providers_skips_missing_client_secret(): + cfg = {"oauth": {"gitea": {"url": "https://git.example.com", "client_id": "cid"}}} + assert oauth.get_providers(cfg) == [] + + +def test_get_providers_skips_missing_url_for_gitea(): + cfg = {"oauth": {"gitea": {"type": "gitea", "client_id": "cid", "client_secret": "cs"}}} + assert oauth.get_providers(cfg) == [] + + +def test_get_providers_skips_missing_url_for_nextcloud(): + cfg = {"oauth": {"nc": {"type": "nextcloud", "client_id": "cid", "client_secret": "cs"}}} + assert oauth.get_providers(cfg) == [] + + +def test_get_providers_github_no_url_required(): + providers = oauth.get_providers(CFG_GITHUB) + assert len(providers) == 1 + + +def test_get_providers_skips_unknown_type(caplog): + cfg = {"oauth": {"mystery": {"type": "saml", "client_id": "cid", "client_secret": "cs"}}} + import logging + with caplog.at_level(logging.WARNING, logger="hbd.server.oauth"): + result = oauth.get_providers(cfg) + assert result == [] + assert "saml" in caplog.text + + +def test_get_providers_empty_config(): + assert oauth.get_providers({}) == [] + assert oauth.get_providers(CFG_OFF) == []