479 lines
16 KiB
Python
479 lines
16 KiB
Python
import time as time_mod
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from urllib.parse import urlparse, parse_qs
|
|
|
|
import pytest
|
|
|
|
from hbd.server import oauth
|
|
from hbd.server import users as users_mod
|
|
from hbd.server.users import User
|
|
|
|
|
|
# Config with no "oauth" section at all.
CFG_OFF = {}

# Legacy single-provider config: a "gitea" entry with no "type" field.
CFG_ON = {
    "oauth": {
        "gitea": {
            "url": "https://git.example.com",
            "client_id": "cid",
            "client_secret": "csec",
        },
    },
}

# Gitea entry lacking client credentials; expected to count as "not enabled".
CFG_PARTIAL = {
    "oauth": {
        "gitea": {"url": "https://git.example.com"},
    },
}
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clear_oauth_states():
    """Run every test against an empty ``oauth._states`` store.

    The store is emptied before the test body runs and again afterwards,
    so state tokens minted by one test are never visible to another.
    """
    oauth._states.clear()  # start from a clean slate
    yield
    oauth._states.clear()  # don't leak this test's tokens
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_users_dict():
    """Restore ``users_mod.users`` to its pre-test object and contents.

    Tests rebind ``users_mod.users`` (see ``_reset_users``). Rebinding it back
    to a *copy* afterwards would permanently swap out the original dict
    object, so any code still holding a reference to that object would no
    longer see the module's dict. Instead, keep the original object, put its
    snapshotted contents back, and point the module attribute at it again.
    """
    original = users_mod.users
    snapshot = dict(original)
    yield
    # Undo any in-place changes the test may have made to the original dict.
    original.clear()
    original.update(snapshot)
    users_mod.users = original
|
|
|
|
|
|
def test_is_enabled_when_all_keys_present():
    """url + client_id + client_secret together switch OAuth on."""
    enabled = oauth.is_enabled(CFG_ON)
    assert enabled is True
|
|
|
|
|
|
def test_is_enabled_false_when_no_oauth_key():
    """A config without an "oauth" section leaves OAuth off."""
    enabled = oauth.is_enabled(CFG_OFF)
    assert enabled is False
|
|
|
|
|
|
def test_is_enabled_false_when_partial_config():
    """Missing client credentials keep OAuth off even when a url is set."""
    enabled = oauth.is_enabled(CFG_PARTIAL)
    assert enabled is False
|
|
|
|
|
|
def test_make_state_returns_unique_tokens():
    """Consecutive calls mint distinct 64-character tokens."""
    first, second = oauth.make_state(), oauth.make_state()
    assert first != second
    assert len(first) == 64  # 32 bytes, hex-encoded
|
|
|
|
|
|
def test_validate_state_valid():
    """A freshly minted state token validates."""
    token = oauth.make_state()
    is_valid = oauth.validate_state(token)
    assert is_valid is True
|
|
|
|
|
|
def test_validate_state_consumed_on_use():
    """A state token validates exactly once; replaying it is rejected."""
    state = oauth.make_state()
    # The first use must succeed, otherwise the replay check proves nothing.
    assert oauth.validate_state(state) is True
    assert oauth.validate_state(state) is False  # replay rejected
|
|
|
|
|
|
def test_validate_state_unknown():
    """A token that was never issued does not validate."""
    is_valid = oauth.validate_state("notastate")
    assert is_valid is False
|
|
|
|
|
|
def test_validate_state_expired(monkeypatch):
    """A token whose stored expiry lies in the past is rejected."""
    token = oauth.make_state()
    already_expired = time_mod.time() - 1000
    # Overwrite the token's stored expiry with a moment 1000 s ago.
    monkeypatch.setitem(oauth._states, token, already_expired)
    assert oauth.validate_state(token) is False
|
|
|
|
|
|
def _reset_users(entries=None):
    """Rebind ``users_mod.users`` to *entries*, or to a new empty dict if falsy."""
    users_mod.users = entries if entries else {}
|
|
|
|
|
|
def test_provision_oauth_user_new():
    """An unknown login becomes a new, non-admin, passwordless user."""
    _reset_users()
    avatar_url = "https://example.com/avatar.png"

    created = users_mod.provision_oauth_user("gituser", "Git User", avatar_url)

    assert created.username == "gituser"
    assert created.full_name == "Git User"
    assert created.avatar == avatar_url
    assert created.admin is False
    assert created.password_hash == ""
    # The new user is also registered in the module-level table.
    assert "gituser" in users_mod.users
|
|
|
|
|
|
def test_provision_oauth_user_no_password_login():
    """OAuth-provisioned users cannot log in with a password."""
    _reset_users()
    created = users_mod.provision_oauth_user("gituser", "Git User", "")
    password_ok = created.check_password("anything")
    assert password_ok is False
|
|
|
|
|
|
def test_provision_oauth_user_existing_updates_profile():
    """Re-provisioning refreshes name/avatar but keeps locally managed fields."""
    hashed = "pbkdf2:sha256:1:salt:abc"
    alice = User(
        username="alice",
        full_name="Old Name",
        avatar="old.png",
        password_hash=hashed,
        admin=True,
        notification_channels=["chan1"],
    )
    _reset_users({"alice": alice})

    refreshed = users_mod.provision_oauth_user("alice", "New Name", "new.png")

    # Profile fields take the newly supplied values...
    assert refreshed.full_name == "New Name"
    assert refreshed.avatar == "new.png"
    # ...while admin flag, password hash and channels are preserved.
    assert refreshed.admin is True
    assert refreshed.password_hash == hashed
    assert refreshed.notification_channels == ["chan1"]
|
|
|
|
|
|
def test_provision_oauth_user_does_not_overwrite_with_empty():
    """Empty name/avatar values leave the stored profile untouched."""
    _reset_users({"bob": User(username="bob", full_name="Bob", avatar="bob.png")})

    bob = users_mod.provision_oauth_user("bob", "", "")

    assert bob.full_name == "Bob"
    assert bob.avatar == "bob.png"
|
|
|
|
|
|
def test_provision_oauth_user_survives_config_reload():
    """OAuth-only users are kept when ``load_users`` reloads the config."""
    login = "oauthonly"
    _reset_users()
    users_mod.provision_oauth_user(login, "OAuth Only", "https://example.com/a.png")
    assert login in users_mod.users

    # Reload from an empty config; the OAuth user must still be present.
    users_mod.load_users({})
    assert login in users_mod.users
|
|
|
|
|
|
def test_authorization_url_shape():
    """The authorize URL targets the Gitea host and carries the OAuth params."""
    callback = "https://hbd.example.com/login/oauth/gitea/callback"
    url = oauth.authorization_url(CFG_ON, "teststate", callback)

    parts = urlparse(url)
    query = parse_qs(parts.query)

    assert (parts.scheme, parts.netloc, parts.path) == (
        "https",
        "git.example.com",
        "/login/oauth/authorize",
    )

    expected_query = {
        "client_id": ["cid"],
        "state": ["teststate"],
        "redirect_uri": [callback],
        "scope": ["user:email"],
        "response_type": ["code"],
    }
    for key, values in expected_query.items():
        assert query[key] == values, key
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_exchange_code_returns_token():
    """A 200 token response carrying ``access_token`` yields that token."""

    def _async_ctx(value):
        # Mock usable as an ``async with`` target that yields *value*.
        return AsyncMock(
            __aenter__=AsyncMock(return_value=value),
            __aexit__=AsyncMock(return_value=False),
        )

    callback = "https://hbd.example.com/login/oauth/gitea/callback"

    token_resp = AsyncMock()
    token_resp.status = 200
    token_resp.json = AsyncMock(return_value={"access_token": "tok123"})

    session = MagicMock()
    session.post = MagicMock(return_value=_async_ctx(token_resp))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=_async_ctx(session)):
        token = await oauth.exchange_code(CFG_ON, "mycode", callback)

    assert token == "tok123"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_exchange_code_raises_on_error_status():
    """A 401 from the token endpoint surfaces as ``OAuthError``."""

    def _async_ctx(value):
        # Mock usable as an ``async with`` target that yields *value*.
        return AsyncMock(
            __aenter__=AsyncMock(return_value=value),
            __aexit__=AsyncMock(return_value=False),
        )

    callback = "https://hbd.example.com/login/oauth/gitea/callback"

    rejected = AsyncMock()
    rejected.status = 401
    rejected.text = AsyncMock(return_value="unauthorized")

    session = MagicMock()
    session.post = MagicMock(return_value=_async_ctx(rejected))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=_async_ctx(session)), \
            pytest.raises(oauth.OAuthError):
        await oauth.exchange_code(CFG_ON, "badcode", callback)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fetch_user_returns_profile():
    """A 200 profile response is returned as the parsed JSON dict."""

    def _async_ctx(value):
        # Mock usable as an ``async with`` target that yields *value*.
        return AsyncMock(
            __aenter__=AsyncMock(return_value=value),
            __aexit__=AsyncMock(return_value=False),
        )

    expected = {
        "login": "alice",
        "full_name": "Alice Smith",
        "avatar_url": "https://git.example.com/avatars/alice.png",
    }

    profile_resp = AsyncMock()
    profile_resp.status = 200
    # Hand the code under test its own copy so ``expected`` stays pristine.
    profile_resp.json = AsyncMock(return_value=dict(expected))

    session = MagicMock()
    session.get = MagicMock(return_value=_async_ctx(profile_resp))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=_async_ctx(session)):
        profile = await oauth.fetch_user(CFG_ON, "tok123")

    assert profile == expected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_exchange_code_raises_when_no_access_token():
    """A 200 token response lacking ``access_token`` raises ``OAuthError``."""

    def _async_ctx(value):
        # Mock usable as an ``async with`` target that yields *value*.
        return AsyncMock(
            __aenter__=AsyncMock(return_value=value),
            __aexit__=AsyncMock(return_value=False),
        )

    callback = "https://hbd.example.com/login/oauth/gitea/callback"

    tokenless = AsyncMock()
    tokenless.status = 200
    tokenless.json = AsyncMock(return_value={"error": "bad_request"})

    session = MagicMock()
    session.post = MagicMock(return_value=_async_ctx(tokenless))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=_async_ctx(session)), \
            pytest.raises(oauth.OAuthError):
        await oauth.exchange_code(CFG_ON, "mycode", callback)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fetch_user_raises_on_error_status():
    """A 401 from the profile endpoint surfaces as ``OAuthError``."""

    def _async_ctx(value):
        # Mock usable as an ``async with`` target that yields *value*.
        return AsyncMock(
            __aenter__=AsyncMock(return_value=value),
            __aexit__=AsyncMock(return_value=False),
        )

    rejected = AsyncMock()
    rejected.status = 401
    rejected.text = AsyncMock(return_value="unauthorized")

    session = MagicMock()
    session.get = MagicMock(return_value=_async_ctx(rejected))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=_async_ctx(session)), \
            pytest.raises(oauth.OAuthError):
        await oauth.fetch_user(CFG_ON, "tok123")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration-style tests: callback logic chain
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_callback_invalid_state_rejects():
    """Verify validate_state returns False for unknown state tokens.

    Nothing here is awaited, so this is a plain synchronous test rather
    than an ``asyncio``-marked coroutine that spins up an event loop.
    """
    fake_state = "this-is-not-a-real-state"
    assert oauth.validate_state(fake_state) is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_full_oauth_flow_chain():
    """Integration-style test: state → exchange → fetch → provision chain."""
    redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"

    # Step 1: create a state token; it validates once and is then consumed.
    state = oauth.make_state()
    assert oauth.validate_state(state) is True
    assert oauth.validate_state(state) is False  # replay rejected

    # Step 2: exchange code → token → profile (HTTP layer mocked)
    mock_token_response = AsyncMock()
    mock_token_response.status = 200
    mock_token_response.json = AsyncMock(return_value={"access_token": "flow_token"})

    mock_user_response = AsyncMock()
    mock_user_response.status = 200
    mock_user_response.json = AsyncMock(return_value={
        "login": "flowuser",
        "full_name": "Flow User",
        "avatar_url": "https://git.example.com/avatars/flow.png",
    })

    mock_session = MagicMock()
    mock_session.post = MagicMock(return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_token_response),
        __aexit__=AsyncMock(return_value=False),
    ))
    mock_session.get = MagicMock(return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_user_response),
        __aexit__=AsyncMock(return_value=False),
    ))

    with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_session),
        __aexit__=AsyncMock(return_value=False),
    )):
        token = await oauth.exchange_code(CFG_ON, "authcode", redirect_uri)
        profile = await oauth.fetch_user(CFG_ON, token)

    assert token == "flow_token"
    assert profile["login"] == "flowuser"

    # Step 3: provision the user from the fetched profile.
    _reset_users()
    user = users_mod.provision_oauth_user(
        profile["login"], profile["full_name"], profile["avatar_url"]
    )
    assert user.username == "flowuser"
    # The profile data must actually flow through to the provisioned user.
    assert user.full_name == "Flow User"
    assert user.avatar == "https://git.example.com/avatars/flow.png"
    assert user.check_password("anything") is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# get_providers()
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# GitHub provider: fixed endpoints, so no "url" key is needed.
CFG_GITHUB = {
    "oauth": {
        "github": {
            "type": "github",
            "client_id": "ghid",
            "client_secret": "ghs",
        },
    },
}

# Nextcloud provider, configured under a custom provider name ("nc").
CFG_NEXTCLOUD = {
    "oauth": {
        "nc": {
            "type": "nextcloud",
            "url": "https://nc.example.com",
            "client_id": "ncid",
            "client_secret": "ncs",
        },
    },
}

# Three providers side by side; the Gitea entry overrides label and logo.
CFG_MULTI = {
    "oauth": {
        "mygitea": {
            "type": "gitea",
            "url": "https://git.example.com",
            "client_id": "cid",
            "client_secret": "cs",
            "label": "Work Gitea",
            "logo": "https://example.com/logo.png",
        },
        "github": {
            "type": "github",
            "client_id": "ghid",
            "client_secret": "ghs",
        },
        "nc": {
            "type": "nextcloud",
            "url": "https://nc.example.com",
            "client_id": "ncid",
            "client_secret": "ncs",
        },
    },
}
|
|
|
|
|
|
def test_get_providers_backward_compat_no_type_field():
    """A legacy config entry without 'type' is treated as Gitea."""
    providers = oauth.get_providers(CFG_ON)
    assert len(providers) == 1
    (provider,) = providers

    base = "https://git.example.com"
    expected = {
        "name": "gitea",
        "type": "gitea",
        "label": "Gitea",
        "client_id": "cid",
        "authorize_url": base + "/login/oauth/authorize",
        "token_url": base + "/login/oauth/access_token",
        "profile_url": base + "/api/v1/user",
        "scope": "user:email",
        "profile_data_path": [],
    }
    for attr, value in expected.items():
        assert getattr(provider, attr) == value, attr
|
|
|
|
|
|
def test_get_providers_multiple():
    """Every configured provider entry yields one provider object."""
    providers = oauth.get_providers(CFG_MULTI)
    assert len(providers) == 3
    configured = {p.name for p in providers}
    assert {"mygitea", "github", "nc"} <= configured
|
|
|
|
|
|
def test_get_providers_custom_label_and_logo():
    """Explicit 'label' and 'logo' keys override the defaults."""
    by_name = {p.name: p for p in oauth.get_providers(CFG_MULTI)}
    work = by_name["mygitea"]
    assert work.label == "Work Gitea"
    assert work.logo == "https://example.com/logo.png"
|
|
|
|
|
|
def test_get_providers_github_default_label():
    """GitHub gets the label "GitHub" and no logo unless configured."""
    github = oauth.get_providers(CFG_GITHUB)[0]
    assert github.label == "GitHub"
    assert github.logo == ""
|
|
|
|
|
|
def test_get_providers_github_fixed_urls():
    """GitHub endpoints are hard-wired to github.com / api.github.com."""
    github = oauth.get_providers(CFG_GITHUB)[0]
    assert (github.authorize_url, github.token_url, github.profile_url) == (
        "https://github.com/login/oauth/authorize",
        "https://github.com/login/oauth/access_token",
        "https://api.github.com/user",
    )
    assert github.scope == "read:user"
|
|
|
|
|
|
def test_get_providers_nextcloud_urls_and_path():
    """Nextcloud endpoints derive from 'url'; profile data sits under ocs.data."""
    nextcloud = oauth.get_providers(CFG_NEXTCLOUD)[0]
    root = "https://nc.example.com"
    assert nextcloud.authorize_url == f"{root}/apps/oauth2/authorize"
    assert nextcloud.token_url == f"{root}/apps/oauth2/api/v1/token"
    assert nextcloud.profile_url == f"{root}/ocs/v2.php/cloud/user?format=json"
    assert nextcloud.profile_data_path == ["ocs", "data"]
    assert nextcloud.scope == ""
|
|
|
|
|
|
def test_get_providers_skips_missing_client_id(caplog):
    """An entry without client_id is dropped with a warning."""
    import logging

    config = {"oauth": {"gitea": {"url": "https://git.example.com", "client_secret": "cs"}}}
    with caplog.at_level(logging.WARNING, logger="hbd.server.oauth"):
        providers = oauth.get_providers(config)
    assert providers == []
    assert caplog.text  # something was logged at WARNING or above
|
|
|
|
|
|
def test_get_providers_skips_missing_client_secret(caplog):
    """An entry without client_secret is dropped with a warning."""
    import logging

    config = {"oauth": {"gitea": {"url": "https://git.example.com", "client_id": "cid"}}}
    with caplog.at_level(logging.WARNING, logger="hbd.server.oauth"):
        providers = oauth.get_providers(config)
    assert providers == []
    assert caplog.text  # something was logged at WARNING or above
|
|
|
|
|
|
def test_get_providers_skips_missing_url_for_gitea(caplog):
    """A Gitea entry without 'url' is dropped with a warning."""
    import logging

    config = {"oauth": {"gitea": {"type": "gitea", "client_id": "cid", "client_secret": "cs"}}}
    with caplog.at_level(logging.WARNING, logger="hbd.server.oauth"):
        providers = oauth.get_providers(config)
    assert providers == []
    assert caplog.text  # something was logged at WARNING or above
|
|
|
|
|
|
def test_get_providers_skips_missing_url_for_nextcloud(caplog):
    """A Nextcloud entry without 'url' is dropped with a warning."""
    import logging

    config = {"oauth": {"nc": {"type": "nextcloud", "client_id": "cid", "client_secret": "cs"}}}
    with caplog.at_level(logging.WARNING, logger="hbd.server.oauth"):
        providers = oauth.get_providers(config)
    assert providers == []
    assert caplog.text  # something was logged at WARNING or above
|
|
|
|
|
|
def test_get_providers_github_no_url_required():
    """A GitHub entry is accepted even though it has no 'url' key."""
    providers = oauth.get_providers(CFG_GITHUB)
    assert len(providers) == 1
|
|
|
|
|
|
def test_get_providers_skips_unknown_type(caplog):
    """An unsupported provider type is dropped and named in the warning."""
    import logging

    config = {"oauth": {"mystery": {"type": "saml", "client_id": "cid", "client_secret": "cs"}}}
    with caplog.at_level(logging.WARNING, logger="hbd.server.oauth"):
        providers = oauth.get_providers(config)
    assert providers == []
    assert "saml" in caplog.text
|
|
|
|
|
|
def test_get_providers_empty_config():
    """No providers come back when nothing is configured.

    Covers both a config with no "oauth" section (``{}``, which is also what
    ``CFG_OFF`` is) and one whose "oauth" section exists but is empty.
    """
    assert oauth.get_providers({}) == []
    assert oauth.get_providers(CFG_OFF) == []
    assert oauth.get_providers({"oauth": {}}) == []
|