diff --git a/hbd/server/oauth.py b/hbd/server/oauth.py index 60f7d2d..026b921 100644 --- a/hbd/server/oauth.py +++ b/hbd/server/oauth.py @@ -1,17 +1,30 @@ -"""Gitea OAuth2 support. +"""OAuth2 provider support. Config shape (in ~/.hb.yaml): oauth: - gitea: - url: https://git.example.com + my-gitea: # route slug → /login/oauth/my-gitea + type: gitea # "gitea" | "github" | "nextcloud" + # omit type to default to "gitea" + url: https://git.example.com # required for gitea and nextcloud + client_id: + client_secret: + label: "Work Gitea" # optional display name on login button + logo: https://example.com/logo.png # optional logo URL + + github: + type: github client_id: client_secret: -Register a Gitea OAuth2 application at: - Gitea → Settings → Applications → OAuth2 -Set the redirect URI to: - https:///login/oauth/gitea/callback + nextcloud: + type: nextcloud + url: https://cloud.example.com + client_id: + client_secret: + +Register the OAuth app with each provider and set the redirect URI to: + https:///login/oauth//callback """ import logging @@ -155,44 +168,32 @@ def get_providers(config: dict) -> list[ResolvedProvider]: return result -def _gitea_cfg(config: dict) -> dict: - """Return the gitea sub-dict or {} if absent/incomplete.""" - return config.get("oauth", {}).get("gitea", {}) - - def is_enabled(config: dict) -> bool: - """Return True when all three required Gitea OAuth keys are present.""" - g = _gitea_cfg(config) - return bool(g.get("url") and g.get("client_id") and g.get("client_secret")) + """Return True when at least one OAuth provider is fully configured.""" + return bool(get_providers(config)) -def authorization_url(config: dict, state: str, redirect_uri: str) -> str: - """Return the Gitea OAuth2 authorization URL to redirect the browser to.""" - g = _gitea_cfg(config) - if not (g.get("url") and g.get("client_id") and g.get("client_secret")): - raise OAuthError("Gitea OAuth2 is not configured") - params = urllib.parse.urlencode({ - "client_id": g["client_id"], +def build_auth_url(provider: ResolvedProvider, state: str, redirect_uri: str) -> str: + """Return the provider's OAuth2 authorization URL to redirect the browser to.""" + params: dict = { + "client_id": provider.client_id, "redirect_uri": redirect_uri, "response_type": "code", - "scope": "user:email", "state": state, - }) - return f"{g['url'].rstrip('/')}/login/oauth/authorize?{params}" + } + if provider.scope: + params["scope"] = provider.scope + return f"{provider.authorize_url}?{urllib.parse.urlencode(params)}" -async def exchange_code(config: dict, code: str, redirect_uri: str) -> str: - """Exchange an authorization *code* for a Gitea access token. +async def exchange_code(provider: ResolvedProvider, code: str, redirect_uri: str) -> str: + """Exchange an authorization *code* for an access token. Returns the access token string. Raises OAuthError on any failure. """ - g = _gitea_cfg(config) - if not (g.get("url") and g.get("client_id") and g.get("client_secret")): - raise OAuthError("Gitea OAuth2 is not configured") - url = f"{g['url'].rstrip('/')}/login/oauth/access_token" payload = { - "client_id": g["client_id"], - "client_secret": g["client_secret"], + "client_id": provider.client_id, + "client_secret": provider.client_secret, "code": code, "grant_type": "authorization_code", "redirect_uri": redirect_uri, @@ -200,7 +201,11 @@ async def exchange_code(config: dict, code: str, redirect_uri: str) -> str: timeout = aiohttp.ClientTimeout(total=10) try: async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(url, json=payload, headers={"Accept": "application/json"}) as resp: + async with session.post( + provider.token_url, + json=payload, + headers={"Accept": "application/json"}, + ) as resp: if resp.status != 200: text = await resp.text() raise OAuthError(f"Token exchange failed ({resp.status}): {text}") @@ -213,28 +218,35 @@ async def exchange_code(config: dict, code: str, redirect_uri: str) -> str: return token -async def fetch_user(config: dict, token: str) -> dict: - """Fetch the authenticated user's profile from Gitea. +async def fetch_user(provider: ResolvedProvider, token: str) -> dict: + """Fetch the authenticated user's profile from the provider. Returns a dict with keys: login, full_name, avatar_url. Raises OAuthError on any failure. """ - g = _gitea_cfg(config) - if not (g.get("url") and g.get("client_id") and g.get("client_secret")): - raise OAuthError("Gitea OAuth2 is not configured") - url = f"{g['url'].rstrip('/')}/api/v1/user" timeout = aiohttp.ClientTimeout(total=10) try: async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(url, headers={"Authorization": f"token {token}"}) as resp: + async with session.get( + provider.profile_url, + headers={ + "Authorization": f"Bearer {token}", + "Accept": "application/json", + }, + ) as resp: if resp.status != 200: text = await resp.text() raise OAuthError(f"User fetch failed ({resp.status}): {text}") data = await resp.json() except aiohttp.ClientError as exc: raise OAuthError(f"User fetch network error: {exc}") from exc + + for key in provider.profile_data_path: + data = data.get(key, {}) + + avatar_field = provider.field_map.get("avatar") return { - "login": data.get("login", ""), - "full_name": data.get("full_name", ""), - "avatar_url": data.get("avatar_url", ""), + "login": data.get(provider.field_map["username"], ""), + "full_name": data.get(provider.field_map["full_name"], ""), + "avatar_url": data.get(avatar_field, "") if avatar_field else "", } diff --git a/tests/test_oauth.py b/tests/test_oauth.py index c9d0ee1..df479c7 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -37,17 +37,6 @@ def reset_users_dict(): users_mod.users = original -def test_is_enabled_when_all_keys_present(): - assert oauth.is_enabled(CFG_ON) is True - - -def test_is_enabled_false_when_no_oauth_key(): - assert oauth.is_enabled(CFG_OFF) is False - - -def test_is_enabled_false_when_partial_config(): - assert oauth.is_enabled(CFG_PARTIAL) is False - def test_make_state_returns_unique_tokens(): s1 = oauth.make_state() @@ -135,132 +124,6 @@ def test_provision_oauth_user_survives_config_reload(): assert "oauthonly" in users_mod.users -def test_authorization_url_shape(): - state = "teststate" - redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback" - url = oauth.authorization_url(CFG_ON, state, redirect_uri) - parsed = urlparse(url) - qs = parse_qs(parsed.query) - assert parsed.scheme == "https" - assert parsed.netloc == "git.example.com" - assert parsed.path == "/login/oauth/authorize" - assert qs["client_id"] == ["cid"] - assert qs["state"] == ["teststate"] - assert qs["redirect_uri"] == [redirect_uri] - assert qs["scope"] == ["user:email"] - assert qs["response_type"] == ["code"] - - -@pytest.mark.asyncio -async def test_exchange_code_returns_token(): - redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback" - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value={"access_token": "tok123"}) - - mock_session = MagicMock() - mock_session.post = MagicMock(return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_response), - __aexit__=AsyncMock(return_value=False), - )) - - with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_session), - __aexit__=AsyncMock(return_value=False), - )): - token = await oauth.exchange_code(CFG_ON, "mycode", redirect_uri) - assert token == "tok123" - - -@pytest.mark.asyncio -async def test_exchange_code_raises_on_error_status(): - redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback" - mock_response = AsyncMock() - mock_response.status = 401 - mock_response.text = AsyncMock(return_value="unauthorized") - - mock_session = MagicMock() - mock_session.post = MagicMock(return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_response), - __aexit__=AsyncMock(return_value=False), - )) - - with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_session), - __aexit__=AsyncMock(return_value=False), - )): - with pytest.raises(oauth.OAuthError): - await oauth.exchange_code(CFG_ON, "badcode", redirect_uri) - - -@pytest.mark.asyncio -async def test_fetch_user_returns_profile(): - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value={ - "login": "alice", - "full_name": "Alice Smith", - "avatar_url": "https://git.example.com/avatars/alice.png", - }) - - mock_session = MagicMock() - mock_session.get = MagicMock(return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_response), - __aexit__=AsyncMock(return_value=False), - )) - - with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_session), - __aexit__=AsyncMock(return_value=False), - )): - profile = await oauth.fetch_user(CFG_ON, "tok123") - assert profile == { - "login": "alice", - "full_name": "Alice Smith", - "avatar_url": "https://git.example.com/avatars/alice.png", - } - - -@pytest.mark.asyncio -async def test_exchange_code_raises_when_no_access_token(): - redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback" - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value={"error": "bad_request"}) - - mock_session = MagicMock() - mock_session.post = MagicMock(return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_response), - __aexit__=AsyncMock(return_value=False), - )) - - with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_session), - __aexit__=AsyncMock(return_value=False), - )): - with pytest.raises(oauth.OAuthError): - await oauth.exchange_code(CFG_ON, "mycode", redirect_uri) - - -@pytest.mark.asyncio -async def test_fetch_user_raises_on_error_status(): - mock_response = AsyncMock() - mock_response.status = 401 - mock_response.text = AsyncMock(return_value="unauthorized") - - mock_session = MagicMock() - mock_session.get = MagicMock(return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_response), - __aexit__=AsyncMock(return_value=False), - )) - - with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_session), - __aexit__=AsyncMock(return_value=False), - )): - with pytest.raises(oauth.OAuthError): - await oauth.fetch_user(CFG_ON, "tok123") - # --------------------------------------------------------------------------- # Integration-style tests: callback logic chain @@ -277,13 +140,12 @@ async def test_callback_invalid_state_rejects(): @pytest.mark.asyncio async def test_full_oauth_flow_chain(): """Integration-style test: state → exchange → fetch → provision chain.""" + p = _gitea_provider() redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback" - # Step 1: create a state token state = oauth.make_state() - assert oauth.validate_state(state) is True # consumed; replay would return False + assert oauth.validate_state(state) is True - # Step 2: exchange code → token (mocked) mock_token_response = AsyncMock() mock_token_response.status = 200 mock_token_response.json = AsyncMock(return_value={"access_token": "flow_token"}) @@ -310,13 +172,12 @@ async def test_full_oauth_flow_chain(): __aenter__=AsyncMock(return_value=mock_session), __aexit__=AsyncMock(return_value=False), )): - token = await oauth.exchange_code(CFG_ON, "authcode", redirect_uri) - profile = await oauth.fetch_user(CFG_ON, token) + token = await oauth.exchange_code(p, "authcode", redirect_uri) + profile = await oauth.fetch_user(p, token) assert token == "flow_token" assert profile["login"] == "flowuser" - # Step 3: provision user _reset_users() user = users_mod.provision_oauth_user( profile["login"], profile["full_name"], profile["avatar_url"] @@ -473,3 +334,268 @@ def test_get_providers_skips_unknown_type(caplog): def test_get_providers_empty_config(): assert oauth.get_providers({}) == [] assert oauth.get_providers(CFG_OFF) == [] + + +# --------------------------------------------------------------------------- +# build_auth_url / exchange_code / fetch_user (generic, ResolvedProvider-based) +# --------------------------------------------------------------------------- + +def _gitea_provider() -> oauth.ResolvedProvider: + return oauth.get_providers(CFG_ON)[0] + + +def _github_provider() -> oauth.ResolvedProvider: + return oauth.get_providers(CFG_GITHUB)[0] + + +def _nextcloud_provider() -> oauth.ResolvedProvider: + return oauth.get_providers(CFG_NEXTCLOUD)[0] + + +def test_build_auth_url_gitea(): + p = _gitea_provider() + url = oauth.build_auth_url(p, "teststate", "https://hbd.example.com/login/oauth/gitea/callback") + parsed = urlparse(url) + qs = parse_qs(parsed.query) + assert parsed.netloc == "git.example.com" + assert parsed.path == "/login/oauth/authorize" + assert qs["client_id"] == ["cid"] + assert qs["state"] == ["teststate"] + assert qs["scope"] == ["user:email"] + assert qs["response_type"] == ["code"] + + +def test_build_auth_url_github(): + p = _github_provider() + url = oauth.build_auth_url(p, "st", "https://hbd.example.com/login/oauth/github/callback") + parsed = urlparse(url) + qs = parse_qs(parsed.query) + assert parsed.netloc == "github.com" + assert qs["scope"] == ["read:user"] + + +def test_build_auth_url_nextcloud_no_scope_param(): + """Nextcloud scope is empty — the 'scope' key must be absent from the URL.""" + p = _nextcloud_provider() + url = oauth.build_auth_url(p, "st", "https://hbd.example.com/login/oauth/nc/callback") + qs = parse_qs(urlparse(url).query) + assert "scope" not in qs + + +@pytest.mark.asyncio +async def test_exchange_code_generic_returns_token(): + p = _gitea_provider() + redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback" + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"access_token": "tok123"}) + + mock_session = MagicMock() + mock_session.post = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_response), + __aexit__=AsyncMock(return_value=False), + )) + + with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_session), + __aexit__=AsyncMock(return_value=False), + )): + token = await oauth.exchange_code(p, "mycode", redirect_uri) + assert token == "tok123" + + +@pytest.mark.asyncio +async def test_exchange_code_sends_accept_json(): + """Accept: application/json must be present for all providers (required by GitHub).""" + p = _github_provider() + captured_headers = {} + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"access_token": "ghtoken"}) + + mock_session = MagicMock() + + def capture_post(url, **kwargs): + captured_headers.update(kwargs.get("headers", {})) + return AsyncMock( + __aenter__=AsyncMock(return_value=mock_response), + __aexit__=AsyncMock(return_value=False), + ) + + mock_session.post = capture_post + + with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_session), + __aexit__=AsyncMock(return_value=False), + )): + await oauth.exchange_code(p, "code", "https://hbd.example.com/login/oauth/github/callback") + + assert captured_headers.get("Accept") == "application/json" + + +@pytest.mark.asyncio +async def test_exchange_code_raises_on_error_status(): + p = _gitea_provider() + mock_response = AsyncMock() + mock_response.status = 401 + mock_response.text = AsyncMock(return_value="unauthorized") + + mock_session = MagicMock() + mock_session.post = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_response), + __aexit__=AsyncMock(return_value=False), + )) + + with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_session), + __aexit__=AsyncMock(return_value=False), + )): + with pytest.raises(oauth.OAuthError): + await oauth.exchange_code(p, "badcode", "https://hbd.example.com/login/oauth/gitea/callback") + + +@pytest.mark.asyncio +async def test_exchange_code_raises_when_no_access_token(): + p = _gitea_provider() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"error": "bad_request"}) + + mock_session = MagicMock() + mock_session.post = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_response), + __aexit__=AsyncMock(return_value=False), + )) + + with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_session), + __aexit__=AsyncMock(return_value=False), + )): + with pytest.raises(oauth.OAuthError): + await oauth.exchange_code(p, "mycode", "https://hbd.example.com/login/oauth/gitea/callback") + + +@pytest.mark.asyncio +async def test_fetch_user_gitea_returns_profile(): + p = _gitea_provider() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + "login": "alice", + "full_name": "Alice Smith", + "avatar_url": "https://git.example.com/avatars/alice.png", + }) + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_response), + __aexit__=AsyncMock(return_value=False), + )) + + with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_session), + __aexit__=AsyncMock(return_value=False), + )): + profile = await oauth.fetch_user(p, "tok123") + + assert profile == { + "login": "alice", + "full_name": "Alice Smith", + "avatar_url": "https://git.example.com/avatars/alice.png", + } + + +@pytest.mark.asyncio +async def test_fetch_user_github_maps_name_field(): + p = _github_provider() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + "login": "bobgh", + "name": "Bob GitHub", + "avatar_url": "https://avatars.githubusercontent.com/u/1", + }) + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_response), + __aexit__=AsyncMock(return_value=False), + )) + + with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_session), + __aexit__=AsyncMock(return_value=False), + )): + profile = await oauth.fetch_user(p, "ghtoken") + + assert profile["login"] == "bobgh" + assert profile["full_name"] == "Bob GitHub" + assert profile["avatar_url"] == "https://avatars.githubusercontent.com/u/1" + + +@pytest.mark.asyncio +async def test_fetch_user_nextcloud_nested_extraction(): + """Nextcloud profile is nested under ocs.data; avatar is absent.""" + p = _nextcloud_provider() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + "ocs": { + "meta": {"status": "ok", "statuscode": 200}, + "data": { + "id": "ncuser", + "display-name": "NC User", + "email": "nc@example.com", + }, + } + }) + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_response), + __aexit__=AsyncMock(return_value=False), + )) + + with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_session), + __aexit__=AsyncMock(return_value=False), + )): + profile = await oauth.fetch_user(p, "nctoken") + + assert profile["login"] == "ncuser" + assert profile["full_name"] == "NC User" + assert profile["avatar_url"] == "" # Nextcloud has no avatar field + + +@pytest.mark.asyncio +async def test_fetch_user_raises_on_error_status(): + p = _gitea_provider() + mock_response = AsyncMock() + mock_response.status = 401 + mock_response.text = AsyncMock(return_value="unauthorized") + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_response), + __aexit__=AsyncMock(return_value=False), + )) + + with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_session), + __aexit__=AsyncMock(return_value=False), + )): + with pytest.raises(oauth.OAuthError): + await oauth.fetch_user(p, "badtoken") + + +def test_is_enabled_with_valid_provider(): + assert oauth.is_enabled(CFG_ON) is True + + +def test_is_enabled_false_when_no_providers(): + assert oauth.is_enabled(CFG_OFF) is False + + +def test_is_enabled_false_partial_config(): + assert oauth.is_enabled(CFG_PARTIAL) is False