feat: generic build_auth_url/exchange_code/fetch_user for multi-provider OAuth2
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+56
-44
@@ -1,17 +1,30 @@
|
|||||||
"""Gitea OAuth2 support.
|
"""OAuth2 provider support.
|
||||||
|
|
||||||
Config shape (in ~/.hb.yaml):
|
Config shape (in ~/.hb.yaml):
|
||||||
|
|
||||||
oauth:
|
oauth:
|
||||||
gitea:
|
my-gitea: # route slug → /login/oauth/my-gitea
|
||||||
url: https://git.example.com
|
type: gitea # "gitea" | "github" | "nextcloud"
|
||||||
|
# omit type to default to "gitea"
|
||||||
|
url: https://git.example.com # required for gitea and nextcloud
|
||||||
|
client_id: <client-id>
|
||||||
|
client_secret: <client-secret>
|
||||||
|
label: "Work Gitea" # optional display name on login button
|
||||||
|
logo: https://example.com/logo.png # optional logo URL
|
||||||
|
|
||||||
|
github:
|
||||||
|
type: github
|
||||||
client_id: <client-id>
|
client_id: <client-id>
|
||||||
client_secret: <client-secret>
|
client_secret: <client-secret>
|
||||||
|
|
||||||
Register a Gitea OAuth2 application at:
|
nextcloud:
|
||||||
Gitea → Settings → Applications → OAuth2
|
type: nextcloud
|
||||||
Set the redirect URI to:
|
url: https://cloud.example.com
|
||||||
https://<hbd-host>/login/oauth/gitea/callback
|
client_id: <client-id>
|
||||||
|
client_secret: <client-secret>
|
||||||
|
|
||||||
|
Register the OAuth app with each provider and set the redirect URI to:
|
||||||
|
https://<hbd-host>/login/oauth/<name>/callback
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -155,44 +168,32 @@ def get_providers(config: dict) -> list[ResolvedProvider]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _gitea_cfg(config: dict) -> dict:
|
|
||||||
"""Return the gitea sub-dict or {} if absent/incomplete."""
|
|
||||||
return config.get("oauth", {}).get("gitea", {})
|
|
||||||
|
|
||||||
|
|
||||||
def is_enabled(config: dict) -> bool:
|
def is_enabled(config: dict) -> bool:
|
||||||
"""Return True when all three required Gitea OAuth keys are present."""
|
"""Return True when at least one OAuth provider is fully configured."""
|
||||||
g = _gitea_cfg(config)
|
return bool(get_providers(config))
|
||||||
return bool(g.get("url") and g.get("client_id") and g.get("client_secret"))
|
|
||||||
|
|
||||||
|
|
||||||
def authorization_url(config: dict, state: str, redirect_uri: str) -> str:
|
def build_auth_url(provider: ResolvedProvider, state: str, redirect_uri: str) -> str:
|
||||||
"""Return the Gitea OAuth2 authorization URL to redirect the browser to."""
|
"""Return the provider's OAuth2 authorization URL to redirect the browser to."""
|
||||||
g = _gitea_cfg(config)
|
params: dict = {
|
||||||
if not (g.get("url") and g.get("client_id") and g.get("client_secret")):
|
"client_id": provider.client_id,
|
||||||
raise OAuthError("Gitea OAuth2 is not configured")
|
|
||||||
params = urllib.parse.urlencode({
|
|
||||||
"client_id": g["client_id"],
|
|
||||||
"redirect_uri": redirect_uri,
|
"redirect_uri": redirect_uri,
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
"scope": "user:email",
|
|
||||||
"state": state,
|
"state": state,
|
||||||
})
|
}
|
||||||
return f"{g['url'].rstrip('/')}/login/oauth/authorize?{params}"
|
if provider.scope:
|
||||||
|
params["scope"] = provider.scope
|
||||||
|
return f"{provider.authorize_url}?{urllib.parse.urlencode(params)}"
|
||||||
|
|
||||||
|
|
||||||
async def exchange_code(config: dict, code: str, redirect_uri: str) -> str:
|
async def exchange_code(provider: ResolvedProvider, code: str, redirect_uri: str) -> str:
|
||||||
"""Exchange an authorization *code* for a Gitea access token.
|
"""Exchange an authorization *code* for an access token.
|
||||||
|
|
||||||
Returns the access token string. Raises OAuthError on any failure.
|
Returns the access token string. Raises OAuthError on any failure.
|
||||||
"""
|
"""
|
||||||
g = _gitea_cfg(config)
|
|
||||||
if not (g.get("url") and g.get("client_id") and g.get("client_secret")):
|
|
||||||
raise OAuthError("Gitea OAuth2 is not configured")
|
|
||||||
url = f"{g['url'].rstrip('/')}/login/oauth/access_token"
|
|
||||||
payload = {
|
payload = {
|
||||||
"client_id": g["client_id"],
|
"client_id": provider.client_id,
|
||||||
"client_secret": g["client_secret"],
|
"client_secret": provider.client_secret,
|
||||||
"code": code,
|
"code": code,
|
||||||
"grant_type": "authorization_code",
|
"grant_type": "authorization_code",
|
||||||
"redirect_uri": redirect_uri,
|
"redirect_uri": redirect_uri,
|
||||||
@@ -200,7 +201,11 @@ async def exchange_code(config: dict, code: str, redirect_uri: str) -> str:
|
|||||||
timeout = aiohttp.ClientTimeout(total=10)
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
async with session.post(url, json=payload, headers={"Accept": "application/json"}) as resp:
|
async with session.post(
|
||||||
|
provider.token_url,
|
||||||
|
json=payload,
|
||||||
|
headers={"Accept": "application/json"},
|
||||||
|
) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
text = await resp.text()
|
text = await resp.text()
|
||||||
raise OAuthError(f"Token exchange failed ({resp.status}): {text}")
|
raise OAuthError(f"Token exchange failed ({resp.status}): {text}")
|
||||||
@@ -213,28 +218,35 @@ async def exchange_code(config: dict, code: str, redirect_uri: str) -> str:
|
|||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
async def fetch_user(config: dict, token: str) -> dict:
|
async def fetch_user(provider: ResolvedProvider, token: str) -> dict:
|
||||||
"""Fetch the authenticated user's profile from Gitea.
|
"""Fetch the authenticated user's profile from the provider.
|
||||||
|
|
||||||
Returns a dict with keys: login, full_name, avatar_url.
|
Returns a dict with keys: login, full_name, avatar_url.
|
||||||
Raises OAuthError on any failure.
|
Raises OAuthError on any failure.
|
||||||
"""
|
"""
|
||||||
g = _gitea_cfg(config)
|
|
||||||
if not (g.get("url") and g.get("client_id") and g.get("client_secret")):
|
|
||||||
raise OAuthError("Gitea OAuth2 is not configured")
|
|
||||||
url = f"{g['url'].rstrip('/')}/api/v1/user"
|
|
||||||
timeout = aiohttp.ClientTimeout(total=10)
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
async with session.get(url, headers={"Authorization": f"token {token}"}) as resp:
|
async with session.get(
|
||||||
|
provider.profile_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
text = await resp.text()
|
text = await resp.text()
|
||||||
raise OAuthError(f"User fetch failed ({resp.status}): {text}")
|
raise OAuthError(f"User fetch failed ({resp.status}): {text}")
|
||||||
data = await resp.json()
|
data = await resp.json()
|
||||||
except aiohttp.ClientError as exc:
|
except aiohttp.ClientError as exc:
|
||||||
raise OAuthError(f"User fetch network error: {exc}") from exc
|
raise OAuthError(f"User fetch network error: {exc}") from exc
|
||||||
|
|
||||||
|
for key in provider.profile_data_path:
|
||||||
|
data = data.get(key, {})
|
||||||
|
|
||||||
|
avatar_field = provider.field_map.get("avatar")
|
||||||
return {
|
return {
|
||||||
"login": data.get("login", ""),
|
"login": data.get(provider.field_map["username"], ""),
|
||||||
"full_name": data.get("full_name", ""),
|
"full_name": data.get(provider.field_map["full_name"], ""),
|
||||||
"avatar_url": data.get("avatar_url", ""),
|
"avatar_url": data.get(avatar_field, "") if avatar_field else "",
|
||||||
}
|
}
|
||||||
|
|||||||
+269
-143
@@ -37,17 +37,6 @@ def reset_users_dict():
|
|||||||
users_mod.users = original
|
users_mod.users = original
|
||||||
|
|
||||||
|
|
||||||
def test_is_enabled_when_all_keys_present():
|
|
||||||
assert oauth.is_enabled(CFG_ON) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_enabled_false_when_no_oauth_key():
|
|
||||||
assert oauth.is_enabled(CFG_OFF) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_enabled_false_when_partial_config():
|
|
||||||
assert oauth.is_enabled(CFG_PARTIAL) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_make_state_returns_unique_tokens():
|
def test_make_state_returns_unique_tokens():
|
||||||
s1 = oauth.make_state()
|
s1 = oauth.make_state()
|
||||||
@@ -135,132 +124,6 @@ def test_provision_oauth_user_survives_config_reload():
|
|||||||
assert "oauthonly" in users_mod.users
|
assert "oauthonly" in users_mod.users
|
||||||
|
|
||||||
|
|
||||||
def test_authorization_url_shape():
|
|
||||||
state = "teststate"
|
|
||||||
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
|
||||||
url = oauth.authorization_url(CFG_ON, state, redirect_uri)
|
|
||||||
parsed = urlparse(url)
|
|
||||||
qs = parse_qs(parsed.query)
|
|
||||||
assert parsed.scheme == "https"
|
|
||||||
assert parsed.netloc == "git.example.com"
|
|
||||||
assert parsed.path == "/login/oauth/authorize"
|
|
||||||
assert qs["client_id"] == ["cid"]
|
|
||||||
assert qs["state"] == ["teststate"]
|
|
||||||
assert qs["redirect_uri"] == [redirect_uri]
|
|
||||||
assert qs["scope"] == ["user:email"]
|
|
||||||
assert qs["response_type"] == ["code"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_exchange_code_returns_token():
|
|
||||||
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.status = 200
|
|
||||||
mock_response.json = AsyncMock(return_value={"access_token": "tok123"})
|
|
||||||
|
|
||||||
mock_session = MagicMock()
|
|
||||||
mock_session.post = MagicMock(return_value=AsyncMock(
|
|
||||||
__aenter__=AsyncMock(return_value=mock_response),
|
|
||||||
__aexit__=AsyncMock(return_value=False),
|
|
||||||
))
|
|
||||||
|
|
||||||
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
|
||||||
__aenter__=AsyncMock(return_value=mock_session),
|
|
||||||
__aexit__=AsyncMock(return_value=False),
|
|
||||||
)):
|
|
||||||
token = await oauth.exchange_code(CFG_ON, "mycode", redirect_uri)
|
|
||||||
assert token == "tok123"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_exchange_code_raises_on_error_status():
|
|
||||||
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.status = 401
|
|
||||||
mock_response.text = AsyncMock(return_value="unauthorized")
|
|
||||||
|
|
||||||
mock_session = MagicMock()
|
|
||||||
mock_session.post = MagicMock(return_value=AsyncMock(
|
|
||||||
__aenter__=AsyncMock(return_value=mock_response),
|
|
||||||
__aexit__=AsyncMock(return_value=False),
|
|
||||||
))
|
|
||||||
|
|
||||||
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
|
||||||
__aenter__=AsyncMock(return_value=mock_session),
|
|
||||||
__aexit__=AsyncMock(return_value=False),
|
|
||||||
)):
|
|
||||||
with pytest.raises(oauth.OAuthError):
|
|
||||||
await oauth.exchange_code(CFG_ON, "badcode", redirect_uri)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_fetch_user_returns_profile():
|
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.status = 200
|
|
||||||
mock_response.json = AsyncMock(return_value={
|
|
||||||
"login": "alice",
|
|
||||||
"full_name": "Alice Smith",
|
|
||||||
"avatar_url": "https://git.example.com/avatars/alice.png",
|
|
||||||
})
|
|
||||||
|
|
||||||
mock_session = MagicMock()
|
|
||||||
mock_session.get = MagicMock(return_value=AsyncMock(
|
|
||||||
__aenter__=AsyncMock(return_value=mock_response),
|
|
||||||
__aexit__=AsyncMock(return_value=False),
|
|
||||||
))
|
|
||||||
|
|
||||||
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
|
||||||
__aenter__=AsyncMock(return_value=mock_session),
|
|
||||||
__aexit__=AsyncMock(return_value=False),
|
|
||||||
)):
|
|
||||||
profile = await oauth.fetch_user(CFG_ON, "tok123")
|
|
||||||
assert profile == {
|
|
||||||
"login": "alice",
|
|
||||||
"full_name": "Alice Smith",
|
|
||||||
"avatar_url": "https://git.example.com/avatars/alice.png",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_exchange_code_raises_when_no_access_token():
|
|
||||||
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.status = 200
|
|
||||||
mock_response.json = AsyncMock(return_value={"error": "bad_request"})
|
|
||||||
|
|
||||||
mock_session = MagicMock()
|
|
||||||
mock_session.post = MagicMock(return_value=AsyncMock(
|
|
||||||
__aenter__=AsyncMock(return_value=mock_response),
|
|
||||||
__aexit__=AsyncMock(return_value=False),
|
|
||||||
))
|
|
||||||
|
|
||||||
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
|
||||||
__aenter__=AsyncMock(return_value=mock_session),
|
|
||||||
__aexit__=AsyncMock(return_value=False),
|
|
||||||
)):
|
|
||||||
with pytest.raises(oauth.OAuthError):
|
|
||||||
await oauth.exchange_code(CFG_ON, "mycode", redirect_uri)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_fetch_user_raises_on_error_status():
|
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.status = 401
|
|
||||||
mock_response.text = AsyncMock(return_value="unauthorized")
|
|
||||||
|
|
||||||
mock_session = MagicMock()
|
|
||||||
mock_session.get = MagicMock(return_value=AsyncMock(
|
|
||||||
__aenter__=AsyncMock(return_value=mock_response),
|
|
||||||
__aexit__=AsyncMock(return_value=False),
|
|
||||||
))
|
|
||||||
|
|
||||||
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
|
||||||
__aenter__=AsyncMock(return_value=mock_session),
|
|
||||||
__aexit__=AsyncMock(return_value=False),
|
|
||||||
)):
|
|
||||||
with pytest.raises(oauth.OAuthError):
|
|
||||||
await oauth.fetch_user(CFG_ON, "tok123")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Integration-style tests: callback logic chain
|
# Integration-style tests: callback logic chain
|
||||||
@@ -277,13 +140,12 @@ async def test_callback_invalid_state_rejects():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_full_oauth_flow_chain():
|
async def test_full_oauth_flow_chain():
|
||||||
"""Integration-style test: state → exchange → fetch → provision chain."""
|
"""Integration-style test: state → exchange → fetch → provision chain."""
|
||||||
|
p = _gitea_provider()
|
||||||
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
||||||
|
|
||||||
# Step 1: create a state token
|
|
||||||
state = oauth.make_state()
|
state = oauth.make_state()
|
||||||
assert oauth.validate_state(state) is True # consumed; replay would return False
|
assert oauth.validate_state(state) is True
|
||||||
|
|
||||||
# Step 2: exchange code → token (mocked)
|
|
||||||
mock_token_response = AsyncMock()
|
mock_token_response = AsyncMock()
|
||||||
mock_token_response.status = 200
|
mock_token_response.status = 200
|
||||||
mock_token_response.json = AsyncMock(return_value={"access_token": "flow_token"})
|
mock_token_response.json = AsyncMock(return_value={"access_token": "flow_token"})
|
||||||
@@ -310,13 +172,12 @@ async def test_full_oauth_flow_chain():
|
|||||||
__aenter__=AsyncMock(return_value=mock_session),
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
__aexit__=AsyncMock(return_value=False),
|
__aexit__=AsyncMock(return_value=False),
|
||||||
)):
|
)):
|
||||||
token = await oauth.exchange_code(CFG_ON, "authcode", redirect_uri)
|
token = await oauth.exchange_code(p, "authcode", redirect_uri)
|
||||||
profile = await oauth.fetch_user(CFG_ON, token)
|
profile = await oauth.fetch_user(p, token)
|
||||||
|
|
||||||
assert token == "flow_token"
|
assert token == "flow_token"
|
||||||
assert profile["login"] == "flowuser"
|
assert profile["login"] == "flowuser"
|
||||||
|
|
||||||
# Step 3: provision user
|
|
||||||
_reset_users()
|
_reset_users()
|
||||||
user = users_mod.provision_oauth_user(
|
user = users_mod.provision_oauth_user(
|
||||||
profile["login"], profile["full_name"], profile["avatar_url"]
|
profile["login"], profile["full_name"], profile["avatar_url"]
|
||||||
@@ -473,3 +334,268 @@ def test_get_providers_skips_unknown_type(caplog):
|
|||||||
def test_get_providers_empty_config():
|
def test_get_providers_empty_config():
|
||||||
assert oauth.get_providers({}) == []
|
assert oauth.get_providers({}) == []
|
||||||
assert oauth.get_providers(CFG_OFF) == []
|
assert oauth.get_providers(CFG_OFF) == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_auth_url / exchange_code / fetch_user (generic, ResolvedProvider-based)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _gitea_provider() -> oauth.ResolvedProvider:
|
||||||
|
return oauth.get_providers(CFG_ON)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _github_provider() -> oauth.ResolvedProvider:
|
||||||
|
return oauth.get_providers(CFG_GITHUB)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _nextcloud_provider() -> oauth.ResolvedProvider:
|
||||||
|
return oauth.get_providers(CFG_NEXTCLOUD)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_auth_url_gitea():
|
||||||
|
p = _gitea_provider()
|
||||||
|
url = oauth.build_auth_url(p, "teststate", "https://hbd.example.com/login/oauth/gitea/callback")
|
||||||
|
parsed = urlparse(url)
|
||||||
|
qs = parse_qs(parsed.query)
|
||||||
|
assert parsed.netloc == "git.example.com"
|
||||||
|
assert parsed.path == "/login/oauth/authorize"
|
||||||
|
assert qs["client_id"] == ["cid"]
|
||||||
|
assert qs["state"] == ["teststate"]
|
||||||
|
assert qs["scope"] == ["user:email"]
|
||||||
|
assert qs["response_type"] == ["code"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_auth_url_github():
|
||||||
|
p = _github_provider()
|
||||||
|
url = oauth.build_auth_url(p, "st", "https://hbd.example.com/login/oauth/github/callback")
|
||||||
|
parsed = urlparse(url)
|
||||||
|
qs = parse_qs(parsed.query)
|
||||||
|
assert parsed.netloc == "github.com"
|
||||||
|
assert qs["scope"] == ["read:user"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_auth_url_nextcloud_no_scope_param():
|
||||||
|
"""Nextcloud scope is empty — the 'scope' key must be absent from the URL."""
|
||||||
|
p = _nextcloud_provider()
|
||||||
|
url = oauth.build_auth_url(p, "st", "https://hbd.example.com/login/oauth/nc/callback")
|
||||||
|
qs = parse_qs(urlparse(url).query)
|
||||||
|
assert "scope" not in qs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exchange_code_generic_returns_token():
|
||||||
|
p = _gitea_provider()
|
||||||
|
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={"access_token": "tok123"})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.post = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
token = await oauth.exchange_code(p, "mycode", redirect_uri)
|
||||||
|
assert token == "tok123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exchange_code_sends_accept_json():
|
||||||
|
"""Accept: application/json must be present for all providers (required by GitHub)."""
|
||||||
|
p = _github_provider()
|
||||||
|
captured_headers = {}
|
||||||
|
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={"access_token": "ghtoken"})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
|
||||||
|
def capture_post(url, **kwargs):
|
||||||
|
captured_headers.update(kwargs.get("headers", {}))
|
||||||
|
return AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session.post = capture_post
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
await oauth.exchange_code(p, "code", "https://hbd.example.com/login/oauth/github/callback")
|
||||||
|
|
||||||
|
assert captured_headers.get("Accept") == "application/json"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exchange_code_raises_on_error_status():
|
||||||
|
p = _gitea_provider()
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 401
|
||||||
|
mock_response.text = AsyncMock(return_value="unauthorized")
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.post = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
with pytest.raises(oauth.OAuthError):
|
||||||
|
await oauth.exchange_code(p, "badcode", "https://hbd.example.com/login/oauth/gitea/callback")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exchange_code_raises_when_no_access_token():
|
||||||
|
p = _gitea_provider()
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={"error": "bad_request"})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.post = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
with pytest.raises(oauth.OAuthError):
|
||||||
|
await oauth.exchange_code(p, "mycode", "https://hbd.example.com/login/oauth/gitea/callback")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_user_gitea_returns_profile():
|
||||||
|
p = _gitea_provider()
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={
|
||||||
|
"login": "alice",
|
||||||
|
"full_name": "Alice Smith",
|
||||||
|
"avatar_url": "https://git.example.com/avatars/alice.png",
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
profile = await oauth.fetch_user(p, "tok123")
|
||||||
|
|
||||||
|
assert profile == {
|
||||||
|
"login": "alice",
|
||||||
|
"full_name": "Alice Smith",
|
||||||
|
"avatar_url": "https://git.example.com/avatars/alice.png",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_user_github_maps_name_field():
|
||||||
|
p = _github_provider()
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={
|
||||||
|
"login": "bobgh",
|
||||||
|
"name": "Bob GitHub",
|
||||||
|
"avatar_url": "https://avatars.githubusercontent.com/u/1",
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
profile = await oauth.fetch_user(p, "ghtoken")
|
||||||
|
|
||||||
|
assert profile["login"] == "bobgh"
|
||||||
|
assert profile["full_name"] == "Bob GitHub"
|
||||||
|
assert profile["avatar_url"] == "https://avatars.githubusercontent.com/u/1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_user_nextcloud_nested_extraction():
|
||||||
|
"""Nextcloud profile is nested under ocs.data; avatar is absent."""
|
||||||
|
p = _nextcloud_provider()
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={
|
||||||
|
"ocs": {
|
||||||
|
"meta": {"status": "ok", "statuscode": 200},
|
||||||
|
"data": {
|
||||||
|
"id": "ncuser",
|
||||||
|
"display-name": "NC User",
|
||||||
|
"email": "nc@example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
profile = await oauth.fetch_user(p, "nctoken")
|
||||||
|
|
||||||
|
assert profile["login"] == "ncuser"
|
||||||
|
assert profile["full_name"] == "NC User"
|
||||||
|
assert profile["avatar_url"] == "" # Nextcloud has no avatar field
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_user_raises_on_error_status():
|
||||||
|
p = _gitea_provider()
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 401
|
||||||
|
mock_response.text = AsyncMock(return_value="unauthorized")
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
with pytest.raises(oauth.OAuthError):
|
||||||
|
await oauth.fetch_user(p, "badtoken")
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_enabled_with_valid_provider():
|
||||||
|
assert oauth.is_enabled(CFG_ON) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_enabled_false_when_no_providers():
|
||||||
|
assert oauth.is_enabled(CFG_OFF) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_enabled_false_partial_config():
|
||||||
|
assert oauth.is_enabled(CFG_PARTIAL) is False
|
||||||
|
|||||||
Reference in New Issue
Block a user