feat: add OAuth2 CSRF state management
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -26,6 +26,29 @@ STATE_TTL = 600 # 10 minutes
|
|||||||
_states: dict[str, float] = {}
|
_states: dict[str, float] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def make_state() -> str:
|
||||||
|
"""Generate a CSRF state token, store it with TTL, and return it."""
|
||||||
|
_purge_states()
|
||||||
|
token = secrets.token_hex(32)
|
||||||
|
_states[token] = time.time() + STATE_TTL
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def validate_state(state: str) -> bool:
|
||||||
|
"""Return True if *state* is known and unexpired; always removes it."""
|
||||||
|
expiry = _states.pop(state, None)
|
||||||
|
if expiry is None:
|
||||||
|
return False
|
||||||
|
return time.time() < expiry
|
||||||
|
|
||||||
|
|
||||||
|
def _purge_states() -> None:
|
||||||
|
now = time.time()
|
||||||
|
expired = [k for k, exp in list(_states.items()) if exp < now]
|
||||||
|
for k in expired:
|
||||||
|
del _states[k]
|
||||||
|
|
||||||
|
|
||||||
class OAuthError(Exception):
|
class OAuthError(Exception):
|
||||||
"""Raised when the OAuth2 flow fails for any reason."""
|
"""Raised when the OAuth2 flow fails for any reason."""
|
||||||
|
|
||||||
|
|||||||
@@ -24,3 +24,35 @@ def test_is_enabled_false_when_no_oauth_key():
|
|||||||
|
|
||||||
def test_is_enabled_false_when_partial_config():
|
def test_is_enabled_false_when_partial_config():
|
||||||
assert oauth.is_enabled(CFG_PARTIAL) is False
|
assert oauth.is_enabled(CFG_PARTIAL) is False
|
||||||
|
|
||||||
|
|
||||||
|
import time as time_mod
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_state_returns_unique_tokens():
|
||||||
|
s1 = oauth.make_state()
|
||||||
|
s2 = oauth.make_state()
|
||||||
|
assert s1 != s2
|
||||||
|
assert len(s1) == 64 # 32 bytes hex
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_valid():
|
||||||
|
state = oauth.make_state()
|
||||||
|
assert oauth.validate_state(state) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_consumed_on_use():
|
||||||
|
state = oauth.make_state()
|
||||||
|
oauth.validate_state(state)
|
||||||
|
assert oauth.validate_state(state) is False # replay rejected
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_unknown():
|
||||||
|
assert oauth.validate_state("notastate") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_expired(monkeypatch):
|
||||||
|
state = oauth.make_state()
|
||||||
|
# Wind expiry into the past
|
||||||
|
monkeypatch.setitem(oauth._states, state, time_mod.time() - 1)
|
||||||
|
assert oauth.validate_state(state) is False
|
||||||
|
|||||||
Reference in New Issue
Block a user