diff --git a/hbd/server/oauth.py b/hbd/server/oauth.py index 517ff84..7fe37d1 100644 --- a/hbd/server/oauth.py +++ b/hbd/server/oauth.py @@ -26,6 +26,29 @@ STATE_TTL = 600 # 10 minutes _states: dict[str, float] = {} +def make_state() -> str: + """Generate a CSRF state token, store it with TTL, and return it.""" + _purge_states() + token = secrets.token_hex(32) + _states[token] = time.time() + STATE_TTL + return token + + +def validate_state(state: str) -> bool: + """Return True if *state* is known and unexpired; always removes it.""" + expiry = _states.pop(state, None) + if expiry is None: + return False + return time.time() < expiry + + +def _purge_states() -> None: + now = time.time() + expired = [k for k, exp in list(_states.items()) if exp < now] + for k in expired: + del _states[k] + + class OAuthError(Exception): """Raised when the OAuth2 flow fails for any reason.""" diff --git a/tests/test_oauth.py b/tests/test_oauth.py index f472b59..9893a6e 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -24,3 +24,35 @@ def test_is_enabled_false_when_no_oauth_key(): def test_is_enabled_false_when_partial_config(): assert oauth.is_enabled(CFG_PARTIAL) is False + + +import time as time_mod + + +def test_make_state_returns_unique_tokens(): + s1 = oauth.make_state() + s2 = oauth.make_state() + assert s1 != s2 + assert len(s1) == 64 # 32 bytes hex + + +def test_validate_state_valid(): + state = oauth.make_state() + assert oauth.validate_state(state) is True + + +def test_validate_state_consumed_on_use(): + state = oauth.make_state() + oauth.validate_state(state) + assert oauth.validate_state(state) is False # replay rejected + + +def test_validate_state_unknown(): + assert oauth.validate_state("notastate") is False + + +def test_validate_state_expired(monkeypatch): + state = oauth.make_state() + # Wind expiry into the past + monkeypatch.setitem(oauth._states, state, time_mod.time() - 1) + assert oauth.validate_state(state) is False