diff --git a/hbd/server/oauth.py b/hbd/server/oauth.py index 7fe37d1..99427b1 100644 --- a/hbd/server/oauth.py +++ b/hbd/server/oauth.py @@ -43,6 +43,7 @@ def validate_state(state: str) -> bool: def _purge_states() -> None: + """Remove all expired CSRF state tokens from the in-memory store.""" now = time.time() expired = [k for k, exp in list(_states.items()) if exp < now] for k in expired: diff --git a/tests/test_oauth.py b/tests/test_oauth.py index 9893a6e..92787e2 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -1,3 +1,7 @@ +import time as time_mod + +import pytest + from hbd.server import oauth @@ -14,6 +18,13 @@ CFG_ON = { CFG_PARTIAL = {"oauth": {"gitea": {"url": "https://git.example.com"}}} +@pytest.fixture(autouse=True) +def clear_oauth_states(): + oauth._states.clear() + yield + oauth._states.clear() + + def test_is_enabled_when_all_keys_present(): assert oauth.is_enabled(CFG_ON) is True @@ -26,9 +37,6 @@ def test_is_enabled_false_when_partial_config(): assert oauth.is_enabled(CFG_PARTIAL) is False -import time as time_mod - - def test_make_state_returns_unique_tokens(): s1 = oauth.make_state() s2 = oauth.make_state() @@ -54,5 +62,5 @@ def test_validate_state_unknown(): def test_validate_state_expired(monkeypatch): state = oauth.make_state() # Wind expiry into the past - monkeypatch.setitem(oauth._states, state, time_mod.time() - 1) + monkeypatch.setitem(oauth._states, state, time_mod.time() - 1000) assert oauth.validate_state(state) is False