# Multi-Provider OAuth2 Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Replace the single hardcoded Gitea OAuth2 integration with a generic multi-provider system supporting Gitea, GitHub, and Nextcloud, in which every configured provider appears as a button on the login page.

**Architecture:** Add a `PROVIDER_DEFS` registry and a `ResolvedProvider` dataclass to `oauth.py`; a new `get_providers()` function resolves raw config entries into typed provider objects with fully computed endpoint URLs. The two hardcoded Gitea routes in `http.py` become generic routes (`/login/oauth/{name}` and `/login/oauth/{name}/callback`) that look up the provider by its slug. The login page loops over all resolved providers to render one button each.
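
Because the route slug is the config key rather than the provider type, the same type can be configured more than once (for example a work and a home Gitea). A minimal sketch of that mapping, assuming a trimmed-down hypothetical `REGISTRY` that holds only authorize-URL templates (the real `PROVIDER_DEFS` comes in Task 1), a hypothetical `resolve()` helper, and placeholder hostnames:

```python
# Hypothetical, trimmed-down registry: only the authorize URL template per type.
REGISTRY = {
    "gitea": "{url}/login/oauth/authorize",
    "github": "https://github.com/login/oauth/authorize",
}

BASE_URL = "https://hbd.example.com"  # placeholder for config["base_url"]

config = {
    "oauth": {
        "work": {"type": "gitea", "url": "https://git.work.example/"},
        "home": {"url": "https://git.home.example"},  # no type: defaults to gitea
        "github": {"type": "github"},
    }
}


def resolve(config: dict) -> dict[str, tuple[str, str]]:
    """Map each slug to (authorize_url, callback_url)."""
    resolved = {}
    for slug, entry in config.get("oauth", {}).items():
        template = REGISTRY[entry.get("type", "gitea")]
        base = (entry.get("url") or "").rstrip("/")  # tolerate trailing slashes
        resolved[slug] = (
            template.format(url=base),  # templates without {url} are unchanged
            f"{BASE_URL}/login/oauth/{slug}/callback",
        )
    return resolved


for slug, (authorize, callback) in resolve(config).items():
    print(slug, authorize, callback)
```

Each slug gets its own callback URL, and that URL is what gets registered with the corresponding provider.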

**Tech Stack:** Python 3.11+, aiohttp (already in use), dataclasses (stdlib), pytest + pytest-asyncio for tests.

---

## File Map

| File | Change |
|---|---|
| `hbd/server/oauth.py` | Add `PROVIDER_DEFS`, `ResolvedProvider`, `get_providers()`; replace `authorization_url()` with `build_auth_url()`; change `exchange_code()` / `fetch_user()` to accept a `ResolvedProvider` instead of the config dict; reimplement `is_enabled()` on top of `get_providers()` |
| `hbd/server/http.py` | Replace hardcoded Gitea routes with generic `{name}` routes; update login page to loop over providers; rename CSS classes `gitea-btn` → `oauth-btn`, `gitea-logo` → `oauth-logo` |
| `tests/test_oauth.py` | Add tests for `get_providers()` and all provider types; port existing tests to new signatures |

---

## Task 1: Provider registry and `get_providers()` in `oauth.py`

**Files:**
- Modify: `hbd/server/oauth.py`
- Test: `tests/test_oauth.py`

- [ ] **Step 1: Write failing tests for `get_providers()`**

Add these tests to `tests/test_oauth.py`, after the existing `test_is_enabled_*` tests. They reuse the `CFG_ON` and `CFG_OFF` fixtures already defined in that file; the backward-compatibility test assumes `CFG_ON` is a Gitea entry named `gitea` with no `type` key, `url: https://git.example.com`, and `client_id: cid`.

```python
# ---------------------------------------------------------------------------
# get_providers()
# ---------------------------------------------------------------------------

CFG_GITHUB = {
    "oauth": {
        "github": {"type": "github", "client_id": "ghid", "client_secret": "ghs"},
    }
}

CFG_NEXTCLOUD = {
    "oauth": {
        "nc": {
            "type": "nextcloud",
            "url": "https://nc.example.com",
            "client_id": "ncid",
            "client_secret": "ncs",
        }
    }
}

CFG_MULTI = {
    "oauth": {
        "mygitea": {
            "type": "gitea",
            "url": "https://git.example.com",
            "client_id": "cid",
            "client_secret": "cs",
            "label": "Work Gitea",
            "logo": "https://example.com/logo.png",
        },
        "github": {"type": "github", "client_id": "ghid", "client_secret": "ghs"},
        "nc": {
            "type": "nextcloud",
            "url": "https://nc.example.com",
            "client_id": "ncid",
            "client_secret": "ncs",
        },
    }
}


def test_get_providers_backward_compat_no_type_field():
    """Old config without 'type' defaults to gitea."""
    providers = oauth.get_providers(CFG_ON)
    assert len(providers) == 1
    p = providers[0]
    assert p.name == "gitea"
    assert p.type == "gitea"
    assert p.label == "Gitea"
    assert p.client_id == "cid"
    assert p.authorize_url == "https://git.example.com/login/oauth/authorize"
    assert p.token_url == "https://git.example.com/login/oauth/access_token"
    assert p.profile_url == "https://git.example.com/api/v1/user"
    assert p.scope == "user:email"
    assert p.profile_data_path == []


def test_get_providers_multiple():
    providers = oauth.get_providers(CFG_MULTI)
    assert len(providers) == 3
    names = [p.name for p in providers]
    assert "mygitea" in names
    assert "github" in names
    assert "nc" in names


def test_get_providers_custom_label_and_logo():
    providers = oauth.get_providers(CFG_MULTI)
    gitea = next(p for p in providers if p.name == "mygitea")
    assert gitea.label == "Work Gitea"
    assert gitea.logo == "https://example.com/logo.png"


def test_get_providers_github_default_label():
    providers = oauth.get_providers(CFG_GITHUB)
    assert providers[0].label == "GitHub"
    assert providers[0].logo == ""


def test_get_providers_github_fixed_urls():
    providers = oauth.get_providers(CFG_GITHUB)
    p = providers[0]
    assert p.authorize_url == "https://github.com/login/oauth/authorize"
    assert p.token_url == "https://github.com/login/oauth/access_token"
    assert p.profile_url == "https://api.github.com/user"
    assert p.scope == "read:user"


def test_get_providers_nextcloud_urls_and_path():
    providers = oauth.get_providers(CFG_NEXTCLOUD)
    p = providers[0]
    assert p.authorize_url == "https://nc.example.com/apps/oauth2/authorize"
    assert p.token_url == "https://nc.example.com/apps/oauth2/api/v1/token"
    assert p.profile_url == "https://nc.example.com/ocs/v2.php/cloud/user?format=json"
    assert p.profile_data_path == ["ocs", "data"]
    assert p.scope == ""


def test_get_providers_skips_missing_client_id():
    cfg = {"oauth": {"gitea": {"url": "https://git.example.com", "client_secret": "cs"}}}
    assert oauth.get_providers(cfg) == []


def test_get_providers_skips_missing_client_secret():
    cfg = {"oauth": {"gitea": {"url": "https://git.example.com", "client_id": "cid"}}}
    assert oauth.get_providers(cfg) == []


def test_get_providers_skips_missing_url_for_gitea():
    cfg = {"oauth": {"gitea": {"type": "gitea", "client_id": "cid", "client_secret": "cs"}}}
    assert oauth.get_providers(cfg) == []


def test_get_providers_skips_missing_url_for_nextcloud():
    cfg = {"oauth": {"nc": {"type": "nextcloud", "client_id": "cid", "client_secret": "cs"}}}
    assert oauth.get_providers(cfg) == []


def test_get_providers_github_no_url_required():
    providers = oauth.get_providers(CFG_GITHUB)
    assert len(providers) == 1


def test_get_providers_skips_unknown_type(caplog):
    import logging

    cfg = {"oauth": {"mystery": {"type": "saml", "client_id": "cid", "client_secret": "cs"}}}
    with caplog.at_level(logging.WARNING, logger="hbd.server.oauth"):
        result = oauth.get_providers(cfg)
    assert result == []
    assert "saml" in caplog.text


def test_get_providers_empty_config():
    assert oauth.get_providers({}) == []
    assert oauth.get_providers(CFG_OFF) == []
```

- [ ] **Step 2: Run the new tests to confirm they fail**

```bash
cd /home/andreas/git/heartbeat
python -m pytest tests/test_oauth.py::test_get_providers_backward_compat_no_type_field -v
```

Expected: `FAILED` with `AttributeError: module 'hbd.server.oauth' has no attribute 'get_providers'`
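- [ ] **Step 3: Add `PROVIDER_DEFS`, `ResolvedProvider`, and `get_providers()` to `oauth.py`**

Add the following after the `OAuthError` class definition (after line 57), before `_gitea_cfg`. The `dataclass` import belongs with the module's other imports at the top of the file, and `get_providers()` logs through the module-level `logger` that the `hbd.server.oauth` logger name in the tests implies already exists (add `logger = logging.getLogger(__name__)` if it does not).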

```python
from dataclasses import dataclass  # move to the module's import block


PROVIDER_DEFS: dict = {
    "gitea": {
        "authorize_url_tmpl": "{url}/login/oauth/authorize",
        "token_url_tmpl": "{url}/login/oauth/access_token",
        "profile_url_tmpl": "{url}/api/v1/user",
        "scope": "user:email",
        "field_map": {"username": "login", "full_name": "full_name", "avatar": "avatar_url"},
        "profile_data_path": [],
        "requires_url": True,
        "default_label": "Gitea",
    },
    "github": {
        "authorize_url_tmpl": "https://github.com/login/oauth/authorize",
        "token_url_tmpl": "https://github.com/login/oauth/access_token",
        "profile_url_tmpl": "https://api.github.com/user",
        "scope": "read:user",
        "field_map": {"username": "login", "full_name": "name", "avatar": "avatar_url"},
        "profile_data_path": [],
        "requires_url": False,
        "default_label": "GitHub",
    },
    "nextcloud": {
        "authorize_url_tmpl": "{url}/apps/oauth2/authorize",
        "token_url_tmpl": "{url}/apps/oauth2/api/v1/token",
        "profile_url_tmpl": "{url}/ocs/v2.php/cloud/user?format=json",
        "scope": "",
        "field_map": {"username": "id", "full_name": "display-name", "avatar": None},
        "profile_data_path": ["ocs", "data"],
        "requires_url": True,
        "default_label": "Nextcloud",
    },
}


@dataclass
class ResolvedProvider:
    """A fully resolved OAuth2 provider instance, ready to use."""

    name: str  # route slug (dict key from config)
    type: str  # "gitea" | "github" | "nextcloud"
    label: str  # display name for login button
    logo: str  # URL or ""
    authorize_url: str  # fully computed authorization endpoint
    token_url: str  # fully computed token endpoint
    profile_url: str  # fully computed user profile endpoint
    scope: str  # OAuth scope string (may be "")
    client_id: str
    client_secret: str
    field_map: dict  # {"username": str, "full_name": str, "avatar": str | None}
    profile_data_path: list  # keys to navigate before field_map (e.g. ["ocs", "data"])


def get_providers(config: dict) -> list[ResolvedProvider]:
    """Return a ResolvedProvider for every valid entry in config['oauth'].

    Entries with missing required fields or unknown types are skipped with
    a warning log. Order follows config declaration order.
    """
    result = []
    oauth_cfg = config.get("oauth", {})
    if not isinstance(oauth_cfg, dict):
        return result
    for name, entry in oauth_cfg.items():
        if not isinstance(entry, dict):
            continue
        # `or` also covers keys present in YAML but left empty (parsed as None).
        provider_type = entry.get("type") or "gitea"
        defn = PROVIDER_DEFS.get(provider_type)
        if defn is None:
            logger.warning("OAuth: unknown provider type %r for %r, skipping", provider_type, name)
            continue
        client_id = entry.get("client_id") or ""
        client_secret = entry.get("client_secret") or ""
        if not client_id or not client_secret:
            logger.warning("OAuth: %r missing client_id or client_secret, skipping", name)
            continue
        url = (entry.get("url") or "").rstrip("/")
        if defn["requires_url"] and not url:
            logger.warning("OAuth: %r requires url but none configured, skipping", name)
            continue
        label = entry.get("label") or defn["default_label"]
        logo = entry.get("logo") or ""
        result.append(ResolvedProvider(
            name=name,
            type=provider_type,
            label=label,
            logo=logo,
            authorize_url=defn["authorize_url_tmpl"].format(url=url),
            token_url=defn["token_url_tmpl"].format(url=url),
            profile_url=defn["profile_url_tmpl"].format(url=url),
            scope=defn["scope"],
            client_id=client_id,
            client_secret=client_secret,
            # Copy the shared registry containers so no caller can mutate PROVIDER_DEFS.
            field_map=dict(defn["field_map"]),
            profile_data_path=list(defn["profile_data_path"]),
        ))
    return result
```

- [ ] **Step 4: Run all new `get_providers` tests**

```bash
python -m pytest tests/test_oauth.py -k "get_providers" -v
```

Expected: all `test_get_providers_*` tests PASS.

- [ ] **Step 5: Commit**

```bash
git add hbd/server/oauth.py tests/test_oauth.py
git commit -m "feat: add PROVIDER_DEFS, ResolvedProvider, get_providers() to oauth.py"
```

---

## Task 2: Generic `build_auth_url`, `exchange_code`, `fetch_user`, `is_enabled`

**Files:**
- Modify: `hbd/server/oauth.py`
- Test: `tests/test_oauth.py`

- [ ] **Step 1: Write failing tests for the updated function signatures**

Add the following after the `get_providers` tests in `tests/test_oauth.py`. The tests use `pytest`, `urlparse`/`parse_qs` (from `urllib.parse`), and `AsyncMock`/`MagicMock`/`patch` (from `unittest.mock`), which are presumably already imported for the existing tests; add any that are missing. They also reuse the `CFG_ON`, `CFG_OFF`, and `CFG_PARTIAL` fixtures.

```python
# ---------------------------------------------------------------------------
# build_auth_url / exchange_code / fetch_user (generic, ResolvedProvider-based)
# ---------------------------------------------------------------------------

def _gitea_provider() -> oauth.ResolvedProvider:
    return oauth.get_providers(CFG_ON)[0]


def _github_provider() -> oauth.ResolvedProvider:
    return oauth.get_providers(CFG_GITHUB)[0]


def _nextcloud_provider() -> oauth.ResolvedProvider:
    return oauth.get_providers(CFG_NEXTCLOUD)[0]


def test_build_auth_url_gitea():
    p = _gitea_provider()
    url = oauth.build_auth_url(p, "teststate", "https://hbd.example.com/login/oauth/gitea/callback")
    parsed = urlparse(url)
    qs = parse_qs(parsed.query)
    assert parsed.netloc == "git.example.com"
    assert parsed.path == "/login/oauth/authorize"
    assert qs["client_id"] == ["cid"]
    assert qs["state"] == ["teststate"]
    assert qs["scope"] == ["user:email"]
    assert qs["response_type"] == ["code"]


def test_build_auth_url_github():
    p = _github_provider()
    url = oauth.build_auth_url(p, "st", "https://hbd.example.com/login/oauth/github/callback")
    parsed = urlparse(url)
    qs = parse_qs(parsed.query)
    assert parsed.netloc == "github.com"
    assert qs["scope"] == ["read:user"]


def test_build_auth_url_nextcloud_no_scope_param():
    """Nextcloud scope is empty, so the 'scope' key must be absent from the URL."""
    p = _nextcloud_provider()
    url = oauth.build_auth_url(p, "st", "https://hbd.example.com/login/oauth/nc/callback")
    qs = parse_qs(urlparse(url).query)
    assert "scope" not in qs


@pytest.mark.asyncio
async def test_exchange_code_generic_returns_token():
    p = _gitea_provider()
    redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
    mock_response = AsyncMock()
    mock_response.status = 200
    mock_response.json = AsyncMock(return_value={"access_token": "tok123"})

    mock_session = MagicMock()
    mock_session.post = MagicMock(return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_response),
        __aexit__=AsyncMock(return_value=False),
    ))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_session),
        __aexit__=AsyncMock(return_value=False),
    )):
        token = await oauth.exchange_code(p, "mycode", redirect_uri)
    assert token == "tok123"


@pytest.mark.asyncio
async def test_exchange_code_sends_accept_json():
    """Accept: application/json must be present for all providers (required by GitHub)."""
    p = _github_provider()
    captured_headers = {}

    mock_response = AsyncMock()
    mock_response.status = 200
    mock_response.json = AsyncMock(return_value={"access_token": "ghtoken"})

    mock_session = MagicMock()

    def capture_post(url, **kwargs):
        captured_headers.update(kwargs.get("headers", {}))
        return AsyncMock(
            __aenter__=AsyncMock(return_value=mock_response),
            __aexit__=AsyncMock(return_value=False),
        )

    mock_session.post = capture_post

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_session),
        __aexit__=AsyncMock(return_value=False),
    )):
        await oauth.exchange_code(p, "code", "https://hbd.example.com/login/oauth/github/callback")

    assert captured_headers.get("Accept") == "application/json"


@pytest.mark.asyncio
async def test_exchange_code_raises_on_error_status():
    p = _gitea_provider()
    mock_response = AsyncMock()
    mock_response.status = 401
    mock_response.text = AsyncMock(return_value="unauthorized")

    mock_session = MagicMock()
    mock_session.post = MagicMock(return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_response),
        __aexit__=AsyncMock(return_value=False),
    ))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_session),
        __aexit__=AsyncMock(return_value=False),
    )):
        with pytest.raises(oauth.OAuthError):
            await oauth.exchange_code(p, "badcode", "https://hbd.example.com/login/oauth/gitea/callback")


@pytest.mark.asyncio
async def test_exchange_code_raises_when_no_access_token():
    p = _gitea_provider()
    mock_response = AsyncMock()
    mock_response.status = 200
    mock_response.json = AsyncMock(return_value={"error": "bad_request"})

    mock_session = MagicMock()
    mock_session.post = MagicMock(return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_response),
        __aexit__=AsyncMock(return_value=False),
    ))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_session),
        __aexit__=AsyncMock(return_value=False),
    )):
        with pytest.raises(oauth.OAuthError):
            await oauth.exchange_code(p, "mycode", "https://hbd.example.com/login/oauth/gitea/callback")


@pytest.mark.asyncio
async def test_fetch_user_gitea_returns_profile():
    p = _gitea_provider()
    mock_response = AsyncMock()
    mock_response.status = 200
    mock_response.json = AsyncMock(return_value={
        "login": "alice",
        "full_name": "Alice Smith",
        "avatar_url": "https://git.example.com/avatars/alice.png",
    })

    mock_session = MagicMock()
    mock_session.get = MagicMock(return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_response),
        __aexit__=AsyncMock(return_value=False),
    ))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_session),
        __aexit__=AsyncMock(return_value=False),
    )):
        profile = await oauth.fetch_user(p, "tok123")

    assert profile == {
        "login": "alice",
        "full_name": "Alice Smith",
        "avatar_url": "https://git.example.com/avatars/alice.png",
    }


@pytest.mark.asyncio
async def test_fetch_user_github_maps_name_field():
    p = _github_provider()
    mock_response = AsyncMock()
    mock_response.status = 200
    mock_response.json = AsyncMock(return_value={
        "login": "bobgh",
        "name": "Bob GitHub",
        "avatar_url": "https://avatars.githubusercontent.com/u/1",
    })

    mock_session = MagicMock()
    mock_session.get = MagicMock(return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_response),
        __aexit__=AsyncMock(return_value=False),
    ))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_session),
        __aexit__=AsyncMock(return_value=False),
    )):
        profile = await oauth.fetch_user(p, "ghtoken")

    assert profile["login"] == "bobgh"
    assert profile["full_name"] == "Bob GitHub"
    assert profile["avatar_url"] == "https://avatars.githubusercontent.com/u/1"


@pytest.mark.asyncio
async def test_fetch_user_nextcloud_nested_extraction():
    """Nextcloud profile is nested under ocs.data; avatar is absent."""
    p = _nextcloud_provider()
    mock_response = AsyncMock()
    mock_response.status = 200
    mock_response.json = AsyncMock(return_value={
        "ocs": {
            "meta": {"status": "ok", "statuscode": 200},
            "data": {
                "id": "ncuser",
                "display-name": "NC User",
                "email": "nc@example.com",
            },
        }
    })

    mock_session = MagicMock()
    mock_session.get = MagicMock(return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_response),
        __aexit__=AsyncMock(return_value=False),
    ))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_session),
        __aexit__=AsyncMock(return_value=False),
    )):
        profile = await oauth.fetch_user(p, "nctoken")

    assert profile["login"] == "ncuser"
    assert profile["full_name"] == "NC User"
    assert profile["avatar_url"] == ""  # Nextcloud has no avatar field


@pytest.mark.asyncio
async def test_fetch_user_raises_on_error_status():
    p = _gitea_provider()
    mock_response = AsyncMock()
    mock_response.status = 401
    mock_response.text = AsyncMock(return_value="unauthorized")

    mock_session = MagicMock()
    mock_session.get = MagicMock(return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_response),
        __aexit__=AsyncMock(return_value=False),
    ))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_session),
        __aexit__=AsyncMock(return_value=False),
    )):
        with pytest.raises(oauth.OAuthError):
            await oauth.fetch_user(p, "badtoken")


def test_is_enabled_with_valid_provider():
    assert oauth.is_enabled(CFG_ON) is True


def test_is_enabled_false_when_no_providers():
    assert oauth.is_enabled(CFG_OFF) is False


def test_is_enabled_false_partial_config():
    assert oauth.is_enabled(CFG_PARTIAL) is False
```
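
- [ ] **Step 2: Run the new tests to confirm they fail**

```bash
python -m pytest tests/test_oauth.py -k "build_auth_url or exchange_code_generic or exchange_code_sends or fetch_user" -v
```

Expected: the new tests FAIL. The `build_auth_url` tests fail with `AttributeError: module 'hbd.server.oauth' has no attribute 'build_auth_url'`; the new `exchange_code` and `fetch_user` tests fail because the old functions still expect the config dict rather than a `ResolvedProvider`. Old tests matched by the `-k` filter (such as `test_fetch_user_returns_profile`) may still pass; they are removed in Step 4.

- [ ] **Step 3: Replace `_gitea_cfg`, `is_enabled`, `authorization_url`, `exchange_code`, `fetch_user` in `oauth.py`**

Delete these functions entirely (lines ~60–142 of the file before Task 1; the Task 1 insertion has since pushed them further down):

- `_gitea_cfg()`
- `is_enabled()`
- `authorization_url()`
- `exchange_code()`
- `fetch_user()`

Replace them with the code below. It relies on `urllib.parse` and `aiohttp`, which the old Gitea implementation presumably already imports; add either if missing.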

```python
def is_enabled(config: dict) -> bool:
    """Return True when at least one OAuth provider is fully configured."""
    return bool(get_providers(config))


def build_auth_url(provider: ResolvedProvider, state: str, redirect_uri: str) -> str:
    """Return the provider's OAuth2 authorization URL to redirect the browser to."""
    params: dict = {
        "client_id": provider.client_id,
        "redirect_uri": redirect_uri,
        "response_type": "code",
        "state": state,
    }
    # Only send "scope" when the provider defines one (Nextcloud does not).
    if provider.scope:
        params["scope"] = provider.scope
    return f"{provider.authorize_url}?{urllib.parse.urlencode(params)}"


async def exchange_code(provider: ResolvedProvider, code: str, redirect_uri: str) -> str:
    """Exchange an authorization *code* for an access token.

    Returns the access token string. Raises OAuthError on any failure.
    """
    payload = {
        "client_id": provider.client_id,
        "client_secret": provider.client_secret,
        "code": code,
        "grant_type": "authorization_code",
        "redirect_uri": redirect_uri,
    }
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(
                provider.token_url,
                # RFC 6749 section 4.1.3 specifies a form-encoded token request body.
                data=payload,
                # Without this, GitHub answers with a form-encoded body instead of JSON.
                headers={"Accept": "application/json"},
            ) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise OAuthError(f"Token exchange failed ({resp.status}): {text}")
                data = await resp.json()
                token = data.get("access_token")
                if not token:
                    raise OAuthError(f"No access_token in response: {data}")
    except aiohttp.ClientError as exc:
        raise OAuthError(f"Token exchange network error: {exc}") from exc
    return token


async def fetch_user(provider: ResolvedProvider, token: str) -> dict:
    """Fetch the authenticated user's profile from the provider.

    Returns a dict with keys: login, full_name, avatar_url.
    Raises OAuthError on any failure, including a profile without a username.
    """
    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/json",
    }
    if provider.type == "nextcloud":
        # Nextcloud's OCS API documentation asks clients to send this header.
        headers["OCS-APIRequest"] = "true"
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(provider.profile_url, headers=headers) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise OAuthError(f"User fetch failed ({resp.status}): {text}")
                data = await resp.json()
    except aiohttp.ClientError as exc:
        raise OAuthError(f"User fetch network error: {exc}") from exc

    # Navigate nested path (e.g. ["ocs", "data"] for Nextcloud)
    for key in provider.profile_data_path:
        data = data.get(key) or {}

    avatar_field = provider.field_map.get("avatar")
    # `or ""` turns JSON nulls (GitHub sends "name": null when unset) into "".
    login = data.get(provider.field_map["username"]) or ""
    if not login:
        # Never provision an account with an empty username.
        raise OAuthError(f"Profile from {provider.name!r} has no username")
    return {
        "login": login,
        "full_name": data.get(provider.field_map["full_name"]) or "",
        "avatar_url": (data.get(avatar_field) or "") if avatar_field else "",
    }
```
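
The field mapping at the end of `fetch_user` is the only provider-specific processing of the response, and it needs no network access. One option, sketched here with a hypothetical helper name `normalize_profile`, is to pull it into a pure function so the Nextcloud nesting and GitHub's null `name` can be tested without mocking aiohttp:

```python
def normalize_profile(data: dict, path: list, field_map: dict) -> dict:
    """Walk *path* into *data*, then map provider fields onto login/full_name/avatar_url."""
    for key in path:
        # A missing key or a non-dict value along the way yields an empty profile.
        data = data.get(key) if isinstance(data, dict) else None
    if not isinstance(data, dict):
        data = {}
    avatar_field = field_map.get("avatar")
    return {
        "login": data.get(field_map["username"]) or "",
        "full_name": data.get(field_map["full_name"]) or "",
        # Providers without an avatar field (Nextcloud) map "avatar" to None.
        "avatar_url": (data.get(avatar_field) or "") if avatar_field else "",
    }


# Field maps copied from the PROVIDER_DEFS entries in Task 1.
NEXTCLOUD_MAP = {"username": "id", "full_name": "display-name", "avatar": None}
GITHUB_MAP = {"username": "login", "full_name": "name", "avatar": "avatar_url"}

nextcloud_body = {
    "ocs": {
        "meta": {"status": "ok", "statuscode": 200},
        "data": {"id": "ncuser", "display-name": "NC User"},
    }
}
print(normalize_profile(nextcloud_body, ["ocs", "data"], NEXTCLOUD_MAP))
# {'login': 'ncuser', 'full_name': 'NC User', 'avatar_url': ''}

# GitHub sends "name": null for accounts without a display name.
github_body = {"login": "bobgh", "name": None, "avatar_url": "https://avatars.githubusercontent.com/u/1"}
print(normalize_profile(github_body, [], GITHUB_MAP))
```

`fetch_user` would then keep only the HTTP call, the empty-username check, and a single `normalize_profile(...)` call.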

Also update the module docstring at the top of `oauth.py` to describe the new multi-provider config shape:

```python
"""OAuth2 provider support.

Config shape (in ~/.hb.yaml):

    oauth:
      my-gitea:                             # route slug → /login/oauth/my-gitea
        type: gitea                         # "gitea" | "github" | "nextcloud"
                                            # omit type to default to "gitea"
        url: https://git.example.com        # required for gitea and nextcloud
        client_id: <client-id>
        client_secret: <client-secret>
        label: "Work Gitea"                 # optional display name on login button
        logo: https://example.com/logo.png  # optional logo URL

      github:
        type: github
        client_id: <client-id>
        client_secret: <client-secret>

      nextcloud:
        type: nextcloud
        url: https://cloud.example.com
        client_id: <client-id>
        client_secret: <client-secret>

Register the OAuth app with each provider and set the redirect URI to:

    https://<hbd-host>/login/oauth/<name>/callback
"""
```
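
- [ ] **Step 4: Remove the old tests that used the old signatures**

In `tests/test_oauth.py`, delete these now-obsolete test functions. The first six exercised the old `authorization_url(config, ...)` / `exchange_code(config, ...)` / `fetch_user(config, ...)` signatures; the three `test_is_enabled_*` tests are superseded by the new ones added in Step 1:

- `test_authorization_url_shape`
- `test_exchange_code_returns_token`
- `test_exchange_code_raises_on_error_status`
- `test_fetch_user_returns_profile`
- `test_exchange_code_raises_when_no_access_token`
- `test_fetch_user_raises_on_error_status`
- `test_is_enabled_when_all_keys_present`
- `test_is_enabled_false_when_no_oauth_key`
- `test_is_enabled_false_when_partial_config`

Three of these names (`test_exchange_code_raises_on_error_status`, `test_exchange_code_raises_when_no_access_token`, `test_fetch_user_raises_on_error_status`) are reused by the new tests from Step 1. Delete only the old definitions, the ones that pass `config` instead of a `ResolvedProvider`. Python keeps only the last definition of a name in a module, so leaving both would silently shadow one of them.

Also update `test_full_oauth_flow_chain` to use the new API: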

```python
@pytest.mark.asyncio
async def test_full_oauth_flow_chain():
    """Integration-style test: state → exchange → fetch → provision chain."""
    p = _gitea_provider()
    redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"

    # Step 1: create a state token
    state = oauth.make_state()
    assert oauth.validate_state(state) is True  # consumed; replay would return False

    # Step 2: exchange code → token (mocked)
    mock_token_response = AsyncMock()
    mock_token_response.status = 200
    mock_token_response.json = AsyncMock(return_value={"access_token": "flow_token"})

    mock_user_response = AsyncMock()
    mock_user_response.status = 200
    mock_user_response.json = AsyncMock(return_value={
        "login": "flowuser",
        "full_name": "Flow User",
        "avatar_url": "https://git.example.com/avatars/flow.png",
    })

    mock_session = MagicMock()
    mock_session.post = MagicMock(return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_token_response),
        __aexit__=AsyncMock(return_value=False),
    ))
    mock_session.get = MagicMock(return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_user_response),
        __aexit__=AsyncMock(return_value=False),
    ))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_session),
        __aexit__=AsyncMock(return_value=False),
    )):
        token = await oauth.exchange_code(p, "authcode", redirect_uri)
        profile = await oauth.fetch_user(p, token)

    assert token == "flow_token"
    assert profile["login"] == "flowuser"

    # Step 3: provision user
    _reset_users()
    user = users_mod.provision_oauth_user(
        profile["login"], profile["full_name"], profile["avatar_url"]
    )
    assert user.username == "flowuser"
    assert user.check_password("anything") is False
```

- [ ] **Step 5: Run the full OAuth test module**

```bash
python -m pytest tests/test_oauth.py -v
```

Expected: all tests PASS, no failures.

- [ ] **Step 6: Commit**

```bash
git add hbd/server/oauth.py tests/test_oauth.py
git commit -m "feat: generic build_auth_url/exchange_code/fetch_user for multi-provider OAuth2"
```

---

## Task 3: Update `http.py` with generic routes and a multi-provider login page

**Files:**
- Modify: `hbd/server/http.py`

- [ ] **Step 1: Replace `_oauth_redirect_uri`, `oauth_gitea_redirect`, `oauth_gitea_callback` with generic handlers**

Find this block in `http.py` (around line 921):

```python
def _oauth_redirect_uri(request) -> str:
    base = config.get("base_url", "").rstrip("/") or str(request.url.origin())
    return f"{base}/login/oauth/gitea/callback"

async def oauth_gitea_redirect(request):
    """GET /login/oauth/gitea — kick off the Gitea OAuth2 flow."""
    if not oauth_mod.is_enabled(config):
        return web.Response(status=404, text="OAuth not configured")
    state = oauth_mod.make_state()
    raise web.HTTPFound(oauth_mod.authorization_url(config, state, _oauth_redirect_uri(request)))

async def oauth_gitea_callback(request):
    """GET /login/oauth/gitea/callback — handle Gitea's redirect back."""
    if not oauth_mod.is_enabled(config):
        return web.Response(status=404, text="OAuth not configured")
    code = request.rel_url.query.get("code", "")
    state = request.rel_url.query.get("state", "")
    if not code or not state:
        return web.Response(status=400, text="Missing code or state")
    if not oauth_mod.validate_state(state):
        logger.warning("OAuth: invalid or expired state token from %s", request.remote)
        raise web.HTTPFound("/login?error=1")
    try:
        token = await oauth_mod.exchange_code(config, code, _oauth_redirect_uri(request))
        profile = await oauth_mod.fetch_user(config, token)
    except oauth_mod.OAuthError as exc:
        logger.warning("OAuth error: %s", exc)
        raise web.HTTPFound("/login?error=1")
    user = users_mod.provision_oauth_user(
        profile["login"],
        profile["full_name"],
        profile["avatar_url"],
    )
    session_token = users_mod.create_session(user.username)
    resp = web.HTTPFound("/")
    resp.set_cookie(
        SESSION_COOKIE,
        session_token,
        max_age=users_mod.SESSION_TTL,
        httponly=True,
        samesite="Lax",
    )
    raise resp
```

Replace it with:
```python
|
||
def _oauth_redirect_uri(request, provider_name: str) -> str:
|
||
base = config.get("base_url", "").rstrip("/") or str(request.url.origin())
|
||
return f"{base}/login/oauth/{provider_name}/callback"
|
||
|
||
def _get_oauth_provider(name: str):
|
||
"""Return the ResolvedProvider for *name*, or None if not found."""
|
||
return next(
|
||
(p for p in oauth_mod.get_providers(config) if p.name == name),
|
||
None,
|
||
)
|
||
|
||
async def oauth_redirect(request):
|
||
"""GET /login/oauth/{name} — kick off the OAuth2 flow for the named provider."""
|
||
name = request.match_info["name"]
|
||
provider = _get_oauth_provider(name)
|
||
if provider is None:
|
||
return web.Response(status=404, text="OAuth provider not found")
|
||
state = oauth_mod.make_state()
|
||
raise web.HTTPFound(
|
||
oauth_mod.build_auth_url(provider, state, _oauth_redirect_uri(request, name))
|
||
)
|
||
|
||
async def oauth_callback(request):
|
||
"""GET /login/oauth/{name}/callback — handle the provider's redirect back."""
|
||
name = request.match_info["name"]
|
||
provider = _get_oauth_provider(name)
|
||
if provider is None:
|
||
return web.Response(status=404, text="OAuth provider not found")
|
||
code = request.rel_url.query.get("code", "")
|
||
state = request.rel_url.query.get("state", "")
|
||
if not code or not state:
|
||
return web.Response(status=400, text="Missing code or state")
|
||
if not oauth_mod.validate_state(state):
|
||
logger.warning("OAuth: invalid or expired state token from %s", request.remote)
|
||
raise web.HTTPFound("/login?error=1")
|
||
try:
|
||
token = await oauth_mod.exchange_code(provider, code, _oauth_redirect_uri(request, name))
|
||
profile = await oauth_mod.fetch_user(provider, token)
|
||
except oauth_mod.OAuthError as exc:
|
||
logger.warning("OAuth error: %s", exc)
|
||
raise web.HTTPFound("/login?error=1")
|
||
user = users_mod.provision_oauth_user(
|
||
profile["login"],
|
||
profile["full_name"],
|
||
profile["avatar_url"],
|
||
)
|
||
session_token = users_mod.create_session(user.username)
|
||
resp = web.HTTPFound("/")
|
||
resp.set_cookie(
|
||
SESSION_COOKIE,
|
||
session_token,
|
||
max_age=users_mod.SESSION_TTL,
|
||
httponly=True,
|
||
samesite="Lax",
|
||
)
|
||
raise resp
|
||
```
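
The two helpers above depend only on configuration, so their behavior can be checked without a running server. The sketch below reproduces their logic outside `http.py` under stated assumptions: `Provider` is a hypothetical stand-in for `oauth_mod.ResolvedProvider` that carries only the `name` field the lookup needs, `redirect_uri` and `find_provider` are hypothetical names, and the request origin is passed in as a string rather than read from `request.url.origin()`.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass(frozen=True)
class Provider:
    """Hypothetical stand-in for oauth_mod.ResolvedProvider; only `name` is needed here."""

    name: str


def redirect_uri(base_url: str, request_origin: str, provider_name: str) -> str:
    # Same rule as _oauth_redirect_uri: use the configured base_url without a
    # trailing slash, or fall back to the origin the request arrived on.
    base = base_url.rstrip("/") or request_origin
    return f"{base}/login/oauth/{provider_name}/callback"


def find_provider(providers: Iterable[Provider], name: str) -> Optional[Provider]:
    # Same rule as _get_oauth_provider: first provider whose slug matches, else None.
    # The route handlers turn None into a 404.
    return next((p for p in providers if p.name == name), None)


if __name__ == "__main__":
    configured = [Provider("gitea"), Provider("github"), Provider("nextcloud")]
    print(redirect_uri("https://hbd.example.org/", "http://127.0.0.1:8080", "github"))
    # https://hbd.example.org/login/oauth/github/callback
    print(redirect_uri("", "http://127.0.0.1:8080", "gitea"))
    # http://127.0.0.1:8080/login/oauth/gitea/callback
    print(find_provider(configured, "gitlab"))
    # None
```

Because the callback path now includes the provider slug, each OAuth application registered on Gitea, GitHub, or Nextcloud typically needs its own `.../login/oauth/<name>/callback` redirect URL. If the existing Gitea entry keeps the config key `gitea`, its callback URL is unchanged from the old hardcoded `/login/oauth/gitea/callback` route.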

- [ ] **Step 2: Update the route registration**

Find this in `app.add_routes([...])`:

```python
web.get("/login/oauth/gitea", oauth_gitea_redirect),
web.get("/login/oauth/gitea/callback", oauth_gitea_callback),
```

Replace with:

```python
web.get("/login/oauth/{name}", oauth_redirect),
web.get("/login/oauth/{name}/callback", oauth_callback),
```
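
Registration order does not matter for these two routes. Assuming aiohttp's default pattern for an unconstrained placeholder (`[^{}/]+`, a single non-empty path segment), `/login/oauth/{name}` can never match a `/callback` URL. The `re` sketch below imitates that matching outside aiohttp. It is not aiohttp's dispatcher, and `ROUTES` and `resolve` are hypothetical names. It also shows that any slug matches the route, which is why the handlers, not the router, return the 404 for unknown providers.

```python
import re
from typing import Optional

# Assumption: aiohttp's default pattern for an unconstrained `{name}`
# placeholder is "[^{}/]+", i.e. exactly one non-empty path segment.
SEGMENT = r"[^{}/]+"

# The two generic routes from this step; a path must match a route in full.
ROUTES = [
    ("oauth_redirect", re.compile(rf"/login/oauth/(?P<name>{SEGMENT})")),
    ("oauth_callback", re.compile(rf"/login/oauth/(?P<name>{SEGMENT})/callback")),
]


def resolve(path: str) -> Optional[tuple[str, dict[str, str]]]:
    """Return (handler name, match_info) for the first route matching *path*, else None."""
    for handler, pattern in ROUTES:
        match = pattern.fullmatch(path)
        if match:
            return handler, match.groupdict()
    return None


if __name__ == "__main__":
    print(resolve("/login/oauth/github"))           # ('oauth_redirect', {'name': 'github'})
    print(resolve("/login/oauth/github/callback"))  # ('oauth_callback', {'name': 'github'})
    print(resolve("/login/oauth/gitlab"))           # ('oauth_redirect', {'name': 'gitlab'})
```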

- [ ] **Step 3: Update the login page to loop over providers**

Find this block in `login_page` (around line 628):

```python
gitea_button = ""
if oauth_mod.is_enabled(config):
    logo_url = config.get("oauth", {}).get("gitea", {}).get("logo", "")
    logo_img = f'<img src="{logo_url}" alt="" class="gitea-logo">' if logo_url else ""
    gitea_button = f"""
    <div class="divider">or</div>
    <a href="/login/oauth/gitea" class="gitea-btn">
      {logo_img}Sign in with Gitea
    </a>"""
```

Replace with:

```python
oauth_buttons = ""
_providers = oauth_mod.get_providers(config)
if _providers:
    buttons_html = ""
    for _p in _providers:
        _logo = f'<img src="{_p.logo}" alt="" class="oauth-logo">' if _p.logo else ""
        buttons_html += f"""
        <a href="/login/oauth/{_p.name}" class="oauth-btn">
          {_logo}{_p.label}
        </a>"""
    oauth_buttons = f"""
    <div class="divider">or</div>{buttons_html}"""
```
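
Moving the loop into a pure function makes two properties easy to test: no markup at all when nothing is configured, and exactly one `or` divider however many buttons follow. This is a sketch under stated assumptions. `Provider` is a hypothetical stand-in for `ResolvedProvider` that keeps only `name`, `label`, and `logo`. `render_oauth_buttons` is a hypothetical helper, since the plan keeps this logic inline in `login_page`. The sketch also swaps the repeated `+=` for `"".join`, which builds the same string.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class Provider:
    """Hypothetical stand-in for oauth_mod.ResolvedProvider (name, label, logo only)."""

    name: str
    label: str
    logo: str = ""


def render_oauth_buttons(providers: Sequence[Provider]) -> str:
    # No providers configured: no divider and no buttons.
    if not providers:
        return ""
    buttons = []
    for p in providers:
        # The <img> tag is only emitted when a logo URL is configured.
        logo = f'<img src="{p.logo}" alt="" class="oauth-logo">' if p.logo else ""
        buttons.append(
            f'\n<a href="/login/oauth/{p.name}" class="oauth-btn">\n  {logo}{p.label}\n</a>'
        )
    # Exactly one "or" divider, however many buttons follow it.
    return '\n<div class="divider">or</div>' + "".join(buttons)


if __name__ == "__main__":
    print(render_oauth_buttons([
        Provider("gitea", "Sign in with Gitea", "/static/gitea.svg"),
        Provider("github", "Sign in with GitHub"),
    ]))
```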

- [ ] **Step 4: Update CSS classes and HTML reference in the login page template**

In the `html = f"""..."""` block, make these replacements:

1. Replace `{gitea_button}` with `{oauth_buttons}` (in the `</form>{gitea_button}` line).

2. Replace the three CSS rules:

   ```css
   .gitea-btn {{ display: flex; align-items: center; justify-content: center;
                 gap: .5em; width: 100%; padding: .6em; background: #16191d;
                 color: #fff; border-radius: 4px; font-size: 1em; text-align: center;
                 text-decoration: none; box-sizing: border-box; }}
   .gitea-btn:hover {{ background: #4e7d1e; }}
   .gitea-logo {{ height: 1.2em; width: auto; vertical-align: middle; }}
   ```

   with:

   ```css
   .oauth-btn {{ display: flex; align-items: center; justify-content: center;
                 gap: .5em; width: 100%; padding: .6em; background: #16191d;
                 color: #fff; border-radius: 4px; font-size: 1em; text-align: center;
                 text-decoration: none; box-sizing: border-box; margin-top: .5em; }}
   .oauth-btn:hover {{ background: #444; }}
   .oauth-logo {{ height: 1.2em; width: auto; vertical-align: middle; }}
   ```
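
The doubled braces in both CSS blocks are intentional. The template is an f-string, so `{{` and `}}` render as a literal `{` and `}`, while a single pair would be parsed as a replacement field. A minimal illustration follows. The `accent` variable is hypothetical and only shows an interpolated value next to escaped braces.

```python
# In an f-string, doubled braces are literal; single braces interpolate.
accent = "#444"  # hypothetical value, only to show a real interpolation
css = f".oauth-btn:hover {{ background: {accent}; }}"
print(css)  # .oauth-btn:hover { background: #444; }
```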

- [ ] **Step 5: Run the full test suite**

```bash
python -m pytest tests/ -v
```

Expected: all tests PASS.

- [ ] **Step 6: Commit**

```bash
git add hbd/server/http.py
git commit -m "feat: multi-provider OAuth2 login page and generic routes"
```

---

## Self-Review Checklist (already done — kept for reference)

| Spec requirement | Task covering it |
|---|---|
| Dict-of-named-instances config with `type` field | Task 1 (`get_providers`) |
| `type` defaults to `"gitea"` when absent (backward compat) | Task 1 (`get_providers`, test) |
| `label` optional, falls back to provider default | Task 1 |
| `logo` optional, defaults to `""` | Task 1 |
| Gitea, GitHub, Nextcloud provider defs | Task 1 (`PROVIDER_DEFS`) |
| GitHub needs no `url` | Task 1 (test: `test_get_providers_github_no_url_required`) |
| Nextcloud `profile_data_path = ["ocs", "data"]` | Task 1 (test) |
| `Accept: application/json` on all token requests | Task 2 (test: `test_exchange_code_sends_accept_json`) |
| Nextcloud nested profile extraction | Task 2 (test: `test_fetch_user_nextcloud_nested_extraction`) |
| Nextcloud avatar absent → `""` | Task 2 (test) |
| `is_enabled` updated | Task 2 |
| Generic `{name}` routes | Task 3 |
| Login page loops over providers | Task 3 |
| One `or` divider regardless of provider count | Task 3 |
| Unknown provider name → 404 | Task 3 (`_get_oauth_provider` returns None) |
| Invalid/missing config entries skipped with warning | Task 1 (test: `test_get_providers_skips_unknown_type`) |