Compare commits
49 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 58c2b9d996 | |||
| 2e8bcb630d | |||
| 338711181b | |||
| 43487f17e7 | |||
| 40205bf5c7 | |||
| b95f1a5bb7 | |||
| 12f7eb722b | |||
| 217bba1b76 | |||
| 967e05ed74 | |||
| c20245b0ab | |||
| b9db0c552e | |||
| 05045bafa2 | |||
| 39f1b5de30 | |||
| b06de6fdd3 | |||
| 940d0af35e | |||
| d6d31aa2e3 | |||
| 76edfe7577 | |||
| d190029728 | |||
| b8307e7a9d | |||
| a2fdf091f5 | |||
| 1914e6f28e | |||
| 82cbce9615 | |||
| dbb779b013 | |||
| ca908ee967 | |||
| 73c697b6c5 | |||
| 3e2357380b | |||
| cc4a103bae | |||
| 53fb10fdf5 | |||
| 2df2ad18c9 | |||
| b81a0d2a6c | |||
| 1a19088cfe | |||
| 172f6e950f | |||
| 4349ae217a | |||
| b3aa7b585f | |||
| 88a3c09b51 | |||
| 0504402a8a | |||
| ca58c18802 | |||
| 1ddc4b8132 | |||
| 5e1720ed32 | |||
| 77f127fe60 | |||
| 54fbd8d73d | |||
| 7ab17e26e2 | |||
| 28f5fa951c | |||
| 37f1c58969 | |||
| f006077a71 | |||
| d9fc8d632f | |||
| f640574e4f | |||
| 9a19424279 | |||
| ca8ba84e65 |
@@ -507,6 +507,9 @@ hbc --boot your-server.example.com
|
|||||||
|
|
||||||
# Verbose output
|
# Verbose output
|
||||||
hbc -v your-server.example.com
|
hbc -v your-server.example.com
|
||||||
|
|
||||||
|
# Send 'boot' and 'shutdown' messages on start and exit
|
||||||
|
hbc -b your-server.example.com
|
||||||
```
|
```
|
||||||
|
|
||||||
You can also run it via the module entrypoint:
|
You can also run it via the module entrypoint:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ This guide explains how to create custom plugins for the Heartbeat monitoring sy
|
|||||||
- [Plugin Types](#plugin-types)
|
- [Plugin Types](#plugin-types)
|
||||||
- [Creating a Plugin](#creating-a-plugin)
|
- [Creating a Plugin](#creating-a-plugin)
|
||||||
- [Plugin Lifecycle](#plugin-lifecycle)
|
- [Plugin Lifecycle](#plugin-lifecycle)
|
||||||
|
- [Server-initiated InfoPlugin refresh](#server-initiated-infoplugin-refresh)
|
||||||
- [Configuration](#configuration)
|
- [Configuration](#configuration)
|
||||||
- [Best Practices](#best-practices)
|
- [Best Practices](#best-practices)
|
||||||
- [Examples](#examples)
|
- [Examples](#examples)
|
||||||
@@ -250,6 +251,28 @@ Understanding the plugin lifecycle helps you implement plugins correctly:
|
|||||||
└─> Plugin releases resources, closes connections
|
└─> Plugin releases resources, closes connections
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Server-initiated InfoPlugin refresh
|
||||||
|
|
||||||
|
When a heartbeat packet arrives from a host the server has no plugin data for (e.g. after a server restart), the server sets `request_update = 1` in the ACK reply. The client detects this flag and immediately re-runs all InfoPlugins — clearing their cached results first — then resends the data as PLG messages.
|
||||||
|
|
||||||
|
This means InfoPlugin data will always reach the server as soon as possible without requiring a client restart. No action is needed from plugin authors: the framework handles cache invalidation and re-collection automatically.
|
||||||
|
|
||||||
|
The lifecycle for this case looks like:
|
||||||
|
|
||||||
|
```
|
||||||
|
Server restarts, host reconnects
|
||||||
|
└─> hbd receives HTB with no existing plugin_data for host
|
||||||
|
└─> hbd sets request_update=1 in ACK
|
||||||
|
|
||||||
|
Client receives ACK
|
||||||
|
└─> Detects request_update flag
|
||||||
|
└─> Clears _cache on every registered InfoPlugin
|
||||||
|
└─> Calls collect() on each InfoPlugin
|
||||||
|
└─> Sends fresh PLG messages to server
|
||||||
|
```
|
||||||
|
|
||||||
|
If you write an `InfoPlugin` with side effects in `_collect_info()` (opening connections, writing files, etc.), be aware it may be called more than once per client session when this mechanism triggers.
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
### Plugin-Specific Configuration
|
### Plugin-Specific Configuration
|
||||||
|
|||||||
@@ -256,6 +256,56 @@ disk_monitor:
|
|||||||
operator: "<"
|
operator: "<"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### ZFS Monitor
|
||||||
|
|
||||||
|
ZFS pool health is checked automatically for every pool. A pool in any state
|
||||||
|
other than `ONLINE` (e.g. `DEGRADED`, `SUSPENDED`, `FAULTED`, `UNAVAIL`) raises
|
||||||
|
a **CRITICAL** alert by default — no configuration required.
|
||||||
|
|
||||||
|
The default threshold is equivalent to:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
zfs_monitor:
|
||||||
|
pools:
|
||||||
|
'*':
|
||||||
|
status:
|
||||||
|
warning: 1
|
||||||
|
critical: 2
|
||||||
|
operator: ">"
|
||||||
|
hysteresis: 0.0
|
||||||
|
display: "ZFS pool {pool_name} is {health}"
|
||||||
|
```
|
||||||
|
|
||||||
|
`'*'` matches every pool on the host. The notification message includes the pool
|
||||||
|
name and its current health string, e.g. `ZFS pool tank is DEGRADED`.
|
||||||
|
|
||||||
|
**Override for specific pools** — named pool entries take priority over `'*'`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
zfs_monitor:
|
||||||
|
pools:
|
||||||
|
# Suppress health alerts for a scratch pool (not mission-critical)
|
||||||
|
scratch:
|
||||||
|
status:
|
||||||
|
enabled: false
|
||||||
|
|
||||||
|
# Capacity threshold for a specific pool
|
||||||
|
tank:
|
||||||
|
capacity:
|
||||||
|
warning: 75.0
|
||||||
|
critical: 90.0
|
||||||
|
operator: ">"
|
||||||
|
hysteresis: 0.05
|
||||||
|
```
|
||||||
|
|
||||||
|
**Alert state paths** follow the pattern `zfs_monitor.<pool_name>.status`,
|
||||||
|
so acknowledgements and silences target individual pools:
|
||||||
|
|
||||||
|
```
|
||||||
|
zfs_monitor.tank.status
|
||||||
|
zfs_monitor.backup.status
|
||||||
|
```
|
||||||
|
|
||||||
### Network Monitor
|
### Network Monitor
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -46,6 +46,24 @@ default_owner: andreas # owns hosts with no explicit owner
|
|||||||
# falls back to the first admin user if omitted
|
# falls back to the first admin user if omitted
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Client-declared host ownership
|
||||||
|
|
||||||
|
A host can declare its own owner directly in the hbc or hbc_mini client configuration. This is useful for hosts that are not listed in the server config, or during initial setup before a server-side config entry has been created.
|
||||||
|
|
||||||
|
**`~/.hbc.yaml`** (hbc):
|
||||||
|
```yaml
|
||||||
|
owner: andreas
|
||||||
|
```
|
||||||
|
|
||||||
|
**`~/.hbc.json`** (hbc_mini):
|
||||||
|
```json
|
||||||
|
{ "owner": "andreas" }
|
||||||
|
```
|
||||||
|
|
||||||
|
When set, the value is included in the `os_info` plugin data sent to the server. The server applies it as `host.owner` the first time `os_info` arrives, provided no owner has been configured server-side for that host. Server-configured ownership always takes precedence.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
### Assigning roles to hosts
|
### Assigning roles to hosts
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -0,0 +1,781 @@
|
|||||||
|
# Gitea OAuth2 Authentication Implementation Plan
|
||||||
|
|
||||||
|
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||||
|
|
||||||
|
**Goal:** Add Gitea as an OAuth2 login provider that coexists with password auth, auto-provisioning new users on first login.
|
||||||
|
|
||||||
|
**Architecture:** A new `oauth.py` module owns all Gitea-specific logic (CSRF state, URL building, token exchange, user-info fetch). `users.py` gains one function to upsert an OAuth-sourced user. `http.py` gets two new route handlers and a small login-page change. No new dependencies — `aiohttp.ClientSession` is already used in the codebase.
|
||||||
|
|
||||||
|
**Tech Stack:** Python 3.12, aiohttp 3.x, pytest, pytest-asyncio
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## File Map
|
||||||
|
|
||||||
|
| Action | Path | Responsibility |
|
||||||
|
|--------|------|----------------|
|
||||||
|
| Modify | `hbd/server/config.py` | Add `"oauth": {}` default |
|
||||||
|
| Create | `hbd/server/oauth.py` | CSRF state, URL builder, token exchange, user-info fetch |
|
||||||
|
| Modify | `hbd/server/users.py` | Add `provision_oauth_user()` |
|
||||||
|
| Modify | `hbd/server/http.py` | Import oauth, two new routes, login page button |
|
||||||
|
| Create | `tests/test_oauth.py` | All new unit tests |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 1: Add config default and `is_enabled()`
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `hbd/server/config.py:34` (after the `"users"` line)
|
||||||
|
- Create: `hbd/server/oauth.py`
|
||||||
|
- Create: `tests/test_oauth.py`
|
||||||
|
|
||||||
|
- [ ] **Step 1: Write the failing test**
|
||||||
|
|
||||||
|
Create `tests/test_oauth.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import pytest
|
||||||
|
from hbd.server import oauth
|
||||||
|
|
||||||
|
|
||||||
|
CFG_OFF = {}
|
||||||
|
CFG_ON = {
|
||||||
|
"oauth": {
|
||||||
|
"gitea": {
|
||||||
|
"url": "https://git.example.com",
|
||||||
|
"client_id": "cid",
|
||||||
|
"client_secret": "csec",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CFG_PARTIAL = {"oauth": {"gitea": {"url": "https://git.example.com"}}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_enabled_when_all_keys_present():
|
||||||
|
assert oauth.is_enabled(CFG_ON) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_enabled_false_when_no_oauth_key():
|
||||||
|
assert oauth.is_enabled(CFG_OFF) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_enabled_false_when_partial_config():
|
||||||
|
assert oauth.is_enabled(CFG_PARTIAL) is False
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Run to confirm failure**
|
||||||
|
|
||||||
|
```
|
||||||
|
pytest tests/test_oauth.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: `ModuleNotFoundError: No module named 'hbd.server.oauth'`
|
||||||
|
|
||||||
|
- [ ] **Step 3: Add config default**
|
||||||
|
|
||||||
|
In `hbd/server/config.py`, add after the `"default_owner"` line (currently line 35):
|
||||||
|
|
||||||
|
```python
|
||||||
|
# OAuth2 providers
|
||||||
|
"oauth": {}, # oauth.gitea.{url,client_id,client_secret}
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 4: Create `hbd/server/oauth.py` with `is_enabled`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""Gitea OAuth2 support.
|
||||||
|
|
||||||
|
Config shape (in ~/.hb.yaml):
|
||||||
|
|
||||||
|
oauth:
|
||||||
|
gitea:
|
||||||
|
url: https://git.example.com
|
||||||
|
client_id: <client-id>
|
||||||
|
client_secret: <client-secret>
|
||||||
|
|
||||||
|
Register a Gitea OAuth2 application at:
|
||||||
|
Gitea → Settings → Applications → OAuth2
|
||||||
|
Set the redirect URI to:
|
||||||
|
https://<hbd-host>/login/oauth/gitea/callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
STATE_TTL = 600 # 10 minutes
|
||||||
|
|
||||||
|
# state_token -> expiry timestamp
|
||||||
|
_states: dict[str, float] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthError(Exception):
|
||||||
|
"""Raised when the OAuth2 flow fails for any reason."""
|
||||||
|
|
||||||
|
|
||||||
|
def _gitea_cfg(config: dict) -> dict:
|
||||||
|
"""Return the gitea sub-dict or {} if absent/incomplete."""
|
||||||
|
return config.get("oauth", {}).get("gitea", {})
|
||||||
|
|
||||||
|
|
||||||
|
def is_enabled(config: dict) -> bool:
|
||||||
|
"""Return True when all three required Gitea OAuth keys are present."""
|
||||||
|
g = _gitea_cfg(config)
|
||||||
|
return bool(g.get("url") and g.get("client_id") and g.get("client_secret"))
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 5: Run to confirm tests pass**
|
||||||
|
|
||||||
|
```
|
||||||
|
pytest tests/test_oauth.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: 3 passed
|
||||||
|
|
||||||
|
- [ ] **Step 6: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add hbd/server/config.py hbd/server/oauth.py tests/test_oauth.py
|
||||||
|
git commit -m "feat: add oauth module skeleton and is_enabled()"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 2: CSRF state management
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `hbd/server/oauth.py` (add `make_state`, `validate_state`)
|
||||||
|
- Modify: `tests/test_oauth.py` (add state tests)
|
||||||
|
|
||||||
|
- [ ] **Step 1: Write the failing tests**
|
||||||
|
|
||||||
|
Append to `tests/test_oauth.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import time as time_mod
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_state_returns_unique_tokens():
|
||||||
|
s1 = oauth.make_state()
|
||||||
|
s2 = oauth.make_state()
|
||||||
|
assert s1 != s2
|
||||||
|
assert len(s1) == 64 # 32 bytes hex
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_valid():
|
||||||
|
state = oauth.make_state()
|
||||||
|
assert oauth.validate_state(state) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_consumed_on_use():
|
||||||
|
state = oauth.make_state()
|
||||||
|
oauth.validate_state(state)
|
||||||
|
assert oauth.validate_state(state) is False # replay rejected
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_unknown():
|
||||||
|
assert oauth.validate_state("notastate") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_expired(monkeypatch):
|
||||||
|
state = oauth.make_state()
|
||||||
|
# Wind expiry into the past
|
||||||
|
monkeypatch.setitem(oauth._states, state, time_mod.time() - 1)
|
||||||
|
assert oauth.validate_state(state) is False
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Run to confirm failure**
|
||||||
|
|
||||||
|
```
|
||||||
|
pytest tests/test_oauth.py -v -k "state"
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: `AttributeError: module 'hbd.server.oauth' has no attribute 'make_state'`
|
||||||
|
|
||||||
|
- [ ] **Step 3: Implement state functions**
|
||||||
|
|
||||||
|
Add to `hbd/server/oauth.py` after the `_states` dict definition:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def make_state() -> str:
|
||||||
|
"""Generate a CSRF state token, store it with TTL, and return it."""
|
||||||
|
_purge_states()
|
||||||
|
token = secrets.token_hex(32)
|
||||||
|
_states[token] = time.time() + STATE_TTL
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def validate_state(state: str) -> bool:
|
||||||
|
"""Return True if *state* is known and unexpired; always removes it."""
|
||||||
|
expiry = _states.pop(state, None)
|
||||||
|
if expiry is None:
|
||||||
|
return False
|
||||||
|
return time.time() < expiry
|
||||||
|
|
||||||
|
|
||||||
|
def _purge_states() -> None:
|
||||||
|
now = time.time()
|
||||||
|
expired = [k for k, exp in list(_states.items()) if exp < now]
|
||||||
|
for k in expired:
|
||||||
|
del _states[k]
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 4: Run to confirm tests pass**
|
||||||
|
|
||||||
|
```
|
||||||
|
pytest tests/test_oauth.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: 8 passed
|
||||||
|
|
||||||
|
- [ ] **Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add hbd/server/oauth.py tests/test_oauth.py
|
||||||
|
git commit -m "feat: add OAuth2 CSRF state management"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 3: `provision_oauth_user` in users.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `hbd/server/users.py` (add `provision_oauth_user`)
|
||||||
|
- Modify: `tests/test_oauth.py` (add provisioning tests)
|
||||||
|
|
||||||
|
- [ ] **Step 1: Write the failing tests**
|
||||||
|
|
||||||
|
Append to `tests/test_oauth.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from hbd.server import users as users_mod
|
||||||
|
from hbd.server.users import User
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_users(entries=None):
|
||||||
|
users_mod.users = entries or {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_oauth_user_new():
|
||||||
|
_reset_users()
|
||||||
|
user = users_mod.provision_oauth_user("gituser", "Git User", "https://example.com/avatar.png")
|
||||||
|
assert user.username == "gituser"
|
||||||
|
assert user.full_name == "Git User"
|
||||||
|
assert user.avatar == "https://example.com/avatar.png"
|
||||||
|
assert user.admin is False
|
||||||
|
assert user.password_hash == ""
|
||||||
|
assert "gituser" in users_mod.users
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_oauth_user_no_password_login():
|
||||||
|
_reset_users()
|
||||||
|
user = users_mod.provision_oauth_user("gituser", "Git User", "")
|
||||||
|
assert user.check_password("anything") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_oauth_user_existing_updates_profile():
|
||||||
|
existing = User(
|
||||||
|
username="alice",
|
||||||
|
full_name="Old Name",
|
||||||
|
avatar="old.png",
|
||||||
|
password_hash="pbkdf2:sha256:1:salt:abc",
|
||||||
|
admin=True,
|
||||||
|
notification_channels=["chan1"],
|
||||||
|
)
|
||||||
|
_reset_users({"alice": existing})
|
||||||
|
user = users_mod.provision_oauth_user("alice", "New Name", "new.png")
|
||||||
|
assert user.full_name == "New Name"
|
||||||
|
assert user.avatar == "new.png"
|
||||||
|
# Preserved
|
||||||
|
assert user.admin is True
|
||||||
|
assert user.password_hash == "pbkdf2:sha256:1:salt:abc"
|
||||||
|
assert user.notification_channels == ["chan1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_oauth_user_does_not_overwrite_with_empty():
|
||||||
|
existing = User(username="bob", full_name="Bob", avatar="bob.png")
|
||||||
|
_reset_users({"bob": existing})
|
||||||
|
user = users_mod.provision_oauth_user("bob", "", "")
|
||||||
|
assert user.full_name == "Bob"
|
||||||
|
assert user.avatar == "bob.png"
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Run to confirm failure**
|
||||||
|
|
||||||
|
```
|
||||||
|
pytest tests/test_oauth.py -v -k "provision"
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: `AttributeError: module 'hbd.server.users' has no attribute 'provision_oauth_user'`
|
||||||
|
|
||||||
|
- [ ] **Step 3: Implement `provision_oauth_user`**
|
||||||
|
|
||||||
|
Add to `hbd/server/users.py` after the `authenticate()` function (after line 187):
|
||||||
|
|
||||||
|
```python
|
||||||
|
def provision_oauth_user(username: str, full_name: str, avatar: str) -> "User":
|
||||||
|
"""Create or update a user sourced from an OAuth2 provider.
|
||||||
|
|
||||||
|
New users are inserted with no password_hash — they can only authenticate
|
||||||
|
via OAuth. Existing users (e.g. defined in config with a password) have
|
||||||
|
their display name and avatar refreshed; all other attributes are preserved.
|
||||||
|
"""
|
||||||
|
user = users.get(username)
|
||||||
|
if user is None:
|
||||||
|
user = User(username=username, full_name=full_name, avatar=avatar)
|
||||||
|
users[username] = user
|
||||||
|
logger.info("Provisioned OAuth user %r", username)
|
||||||
|
else:
|
||||||
|
if full_name:
|
||||||
|
user.full_name = full_name
|
||||||
|
if avatar:
|
||||||
|
user.avatar = avatar
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 4: Run to confirm tests pass**
|
||||||
|
|
||||||
|
```
|
||||||
|
pytest tests/test_oauth.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: 12 passed
|
||||||
|
|
||||||
|
- [ ] **Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add hbd/server/users.py tests/test_oauth.py
|
||||||
|
git commit -m "feat: add provision_oauth_user() to users module"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 4: URL builder, token exchange, and user-info fetch
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `hbd/server/oauth.py` (add `authorization_url`, `exchange_code`, `fetch_user`)
|
||||||
|
- Modify: `tests/test_oauth.py` (add async tests with mocked HTTP)
|
||||||
|
|
||||||
|
- [ ] **Step 1: Write the failing tests**
|
||||||
|
|
||||||
|
Append to `tests/test_oauth.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
|
||||||
|
|
||||||
|
def test_authorization_url_shape():
|
||||||
|
state = "teststate"
|
||||||
|
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
||||||
|
url = oauth.authorization_url(CFG_ON, state, redirect_uri)
|
||||||
|
parsed = urlparse(url)
|
||||||
|
qs = parse_qs(parsed.query)
|
||||||
|
assert parsed.scheme == "https"
|
||||||
|
assert parsed.netloc == "git.example.com"
|
||||||
|
assert parsed.path == "/login/oauth/authorize"
|
||||||
|
assert qs["client_id"] == ["cid"]
|
||||||
|
assert qs["state"] == ["teststate"]
|
||||||
|
assert qs["redirect_uri"] == [redirect_uri]
|
||||||
|
assert qs["scope"] == ["user:email"]
|
||||||
|
assert qs["response_type"] == ["code"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exchange_code_returns_token():
|
||||||
|
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={"access_token": "tok123"})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.post = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
token = await oauth.exchange_code(CFG_ON, "mycode", redirect_uri)
|
||||||
|
assert token == "tok123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exchange_code_raises_on_error_status():
|
||||||
|
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 401
|
||||||
|
mock_response.text = AsyncMock(return_value="unauthorized")
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.post = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
with pytest.raises(oauth.OAuthError):
|
||||||
|
await oauth.exchange_code(CFG_ON, "badcode", redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_user_returns_profile():
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={
|
||||||
|
"login": "alice",
|
||||||
|
"full_name": "Alice Smith",
|
||||||
|
"avatar_url": "https://git.example.com/avatars/alice.png",
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
profile = await oauth.fetch_user(CFG_ON, "tok123")
|
||||||
|
assert profile == {
|
||||||
|
"login": "alice",
|
||||||
|
"full_name": "Alice Smith",
|
||||||
|
"avatar_url": "https://git.example.com/avatars/alice.png",
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Run to confirm failure**
|
||||||
|
|
||||||
|
```
|
||||||
|
pytest tests/test_oauth.py -v -k "url or exchange or fetch"
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: `AttributeError: module 'hbd.server.oauth' has no attribute 'authorization_url'`
|
||||||
|
|
||||||
|
- [ ] **Step 3: Implement the three functions**
|
||||||
|
|
||||||
|
Add to `hbd/server/oauth.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
|
||||||
|
def authorization_url(config: dict, state: str, redirect_uri: str) -> str:
|
||||||
|
"""Return the Gitea OAuth2 authorization URL to redirect the browser to."""
|
||||||
|
g = _gitea_cfg(config)
|
||||||
|
params = urllib.parse.urlencode({
|
||||||
|
"client_id": g["client_id"],
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": "user:email",
|
||||||
|
"state": state,
|
||||||
|
})
|
||||||
|
return f"{g['url'].rstrip('/')}/login/oauth/authorize?{params}"
|
||||||
|
|
||||||
|
|
||||||
|
async def exchange_code(config: dict, code: str, redirect_uri: str) -> str:
|
||||||
|
"""Exchange an authorization *code* for a Gitea access token.
|
||||||
|
|
||||||
|
Returns the access token string. Raises OAuthError on any failure.
|
||||||
|
"""
|
||||||
|
g = _gitea_cfg(config)
|
||||||
|
url = f"{g['url'].rstrip('/')}/login/oauth/access_token"
|
||||||
|
payload = {
|
||||||
|
"client_id": g["client_id"],
|
||||||
|
"client_secret": g["client_secret"],
|
||||||
|
"code": code,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
}
|
||||||
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
async with session.post(url, json=payload, headers={"Accept": "application/json"}) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
text = await resp.text()
|
||||||
|
raise OAuthError(f"Token exchange failed ({resp.status}): {text}")
|
||||||
|
data = await resp.json()
|
||||||
|
except aiohttp.ClientError as exc:
|
||||||
|
raise OAuthError(f"Token exchange network error: {exc}") from exc
|
||||||
|
token = data.get("access_token")
|
||||||
|
if not token:
|
||||||
|
raise OAuthError(f"No access_token in response: {data}")
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_user(config: dict, token: str) -> dict:
|
||||||
|
"""Fetch the authenticated user's profile from Gitea.
|
||||||
|
|
||||||
|
Returns a dict with keys: login, full_name, avatar_url.
|
||||||
|
Raises OAuthError on any failure.
|
||||||
|
"""
|
||||||
|
g = _gitea_cfg(config)
|
||||||
|
url = f"{g['url'].rstrip('/')}/api/v1/user"
|
||||||
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
async with session.get(url, headers={"Authorization": f"token {token}"}) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
text = await resp.text()
|
||||||
|
raise OAuthError(f"User fetch failed ({resp.status}): {text}")
|
||||||
|
data = await resp.json()
|
||||||
|
except aiohttp.ClientError as exc:
|
||||||
|
raise OAuthError(f"User fetch network error: {exc}") from exc
|
||||||
|
return {
|
||||||
|
"login": data.get("login", ""),
|
||||||
|
"full_name": data.get("full_name", ""),
|
||||||
|
"avatar_url": data.get("avatar_url", ""),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Also add `import urllib.parse` at the top of `oauth.py` (alongside the existing imports).
|
||||||
|
|
||||||
|
- [ ] **Step 4: Run to confirm tests pass**
|
||||||
|
|
||||||
|
```
|
||||||
|
pytest tests/test_oauth.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: 17 passed
|
||||||
|
|
||||||
|
- [ ] **Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add hbd/server/oauth.py tests/test_oauth.py
|
||||||
|
git commit -m "feat: add authorization_url, exchange_code, fetch_user to oauth module"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 5: HTTP routes — redirect and callback
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `hbd/server/http.py`
|
||||||
|
|
||||||
|
`http.py` defines all handlers inside `async def start(...)`. The two new handlers go in the same block, just before the `app = web.Application()` line (~line 900). The import goes at the top of the file.
|
||||||
|
|
||||||
|
- [ ] **Step 1: Add the import**
|
||||||
|
|
||||||
|
In `hbd/server/http.py`, add after the existing local imports (after `from . import users as users_mod`):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from . import oauth as oauth_mod
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Add the two route handlers**
|
||||||
|
|
||||||
|
In `hbd/server/http.py`, add the two handlers immediately before the `app = web.Application()` line:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def oauth_gitea_redirect(request):
|
||||||
|
"""GET /login/oauth/gitea — kick off the Gitea OAuth2 flow."""
|
||||||
|
if not oauth_mod.is_enabled(config):
|
||||||
|
return web.Response(status=404, text="OAuth not configured")
|
||||||
|
state = oauth_mod.make_state()
|
||||||
|
redirect_uri = f"{request.url.origin()}/login/oauth/gitea/callback"
|
||||||
|
raise web.HTTPFound(oauth_mod.authorization_url(config, state, redirect_uri))
|
||||||
|
|
||||||
|
async def oauth_gitea_callback(request):
|
||||||
|
"""GET /login/oauth/gitea/callback — handle Gitea's redirect back."""
|
||||||
|
if not oauth_mod.is_enabled(config):
|
||||||
|
return web.Response(status=404, text="OAuth not configured")
|
||||||
|
code = request.rel_url.query.get("code", "")
|
||||||
|
state = request.rel_url.query.get("state", "")
|
||||||
|
if not code or not state:
|
||||||
|
return web.Response(status=400, text="Missing code or state")
|
||||||
|
if not oauth_mod.validate_state(state):
|
||||||
|
raise web.HTTPFound("/login?error=1")
|
||||||
|
redirect_uri = f"{request.url.origin()}/login/oauth/gitea/callback"
|
||||||
|
try:
|
||||||
|
token = await oauth_mod.exchange_code(config, code, redirect_uri)
|
||||||
|
profile = await oauth_mod.fetch_user(config, token)
|
||||||
|
except oauth_mod.OAuthError as exc:
|
||||||
|
logger.warning("OAuth error: %s", exc)
|
||||||
|
raise web.HTTPFound("/login?error=1")
|
||||||
|
user = users_mod.provision_oauth_user(
|
||||||
|
profile["login"],
|
||||||
|
profile["full_name"],
|
||||||
|
profile["avatar_url"],
|
||||||
|
)
|
||||||
|
session_token = users_mod.create_session(user.username)
|
||||||
|
resp = web.HTTPFound("/")
|
||||||
|
resp.set_cookie(
|
||||||
|
SESSION_COOKIE,
|
||||||
|
session_token,
|
||||||
|
max_age=users_mod.SESSION_TTL,
|
||||||
|
httponly=True,
|
||||||
|
samesite="Lax",
|
||||||
|
)
|
||||||
|
raise resp
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 3: Register the routes**
|
||||||
|
|
||||||
|
In `hbd/server/http.py`, add to the route list after the existing auth routes (after `web.post("/api/0/auth/logout", api_logout)`):
|
||||||
|
|
||||||
|
```python
|
||||||
|
web.get("/login/oauth/gitea", oauth_gitea_redirect),
|
||||||
|
web.get("/login/oauth/gitea/callback", oauth_gitea_callback),
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 4: Manual smoke test**
|
||||||
|
|
||||||
|
Start the server locally with OAuth configured in `~/.hb.yaml`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
oauth:
|
||||||
|
gitea:
|
||||||
|
url: https://your-gitea-instance.example.com
|
||||||
|
client_id: your-client-id
|
||||||
|
client_secret: your-client-secret
|
||||||
|
```
|
||||||
|
|
||||||
|
Visit `http://localhost:50004/login/oauth/gitea` — confirm you are redirected to Gitea's authorization page.
|
||||||
|
|
||||||
|
- [ ] **Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add hbd/server/http.py
|
||||||
|
git commit -m "feat: add Gitea OAuth2 redirect and callback routes"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 6: Login page — "Sign in with Gitea" button
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `hbd/server/http.py` (update `login_page` handler, ~line 625)
|
||||||
|
|
||||||
|
- [ ] **Step 1: Replace the login page HTML**
|
||||||
|
|
||||||
|
In `hbd/server/http.py`, find the `html = f"""` block inside `login_page` and replace it with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
gitea_button = ""
|
||||||
|
if oauth_mod.is_enabled(config):
|
||||||
|
gitea_url = _gitea_cfg_url(config)
|
||||||
|
gitea_button = f"""
|
||||||
|
<div class="divider">or</div>
|
||||||
|
<a href="/login/oauth/gitea" class="gitea-btn">
|
||||||
|
Sign in with Gitea
|
||||||
|
</a>"""
|
||||||
|
|
||||||
|
html = f"""<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<title>Heartbeat — Login</title>
|
||||||
|
<style>
|
||||||
|
body {{ font-family: sans-serif; background: #f5f5f5; display: flex;
|
||||||
|
justify-content: center; align-items: center; height: 100vh; margin: 0; }}
|
||||||
|
.box {{ background: #fff; padding: 2em 2.5em; border-radius: 8px;
|
||||||
|
box-shadow: 0 2px 12px rgba(0,0,0,.15); min-width: 300px; }}
|
||||||
|
h2 {{ margin: 0 0 1.2em; color: #333; font-size: 1.4em; }}
|
||||||
|
label {{ display: block; margin-bottom: .3em; font-size: .9em; color: #555; }}
|
||||||
|
input {{ width: 100%; padding: .5em .7em; border: 1px solid #ccc;
|
||||||
|
border-radius: 4px; font-size: 1em; box-sizing: border-box; }}
|
||||||
|
button {{ margin-top: 1.2em; width: 100%; padding: .6em; background: #0066cc;
|
||||||
|
color: #fff; border: none; border-radius: 4px; font-size: 1em; cursor: pointer; }}
|
||||||
|
button:hover {{ background: #0055aa; }}
|
||||||
|
.error {{ color: #c00; font-size: .9em; margin-bottom: .8em; }}
|
||||||
|
.field {{ margin-bottom: .9em; }}
|
||||||
|
.divider {{ text-align: center; margin: 1.2em 0 .8em; color: #999;
|
||||||
|
font-size: .85em; border-top: 1px solid #eee; padding-top: .8em; }}
|
||||||
|
.gitea-btn {{ display: block; width: 100%; padding: .6em; background: #609926;
|
||||||
|
color: #fff; border-radius: 4px; font-size: 1em; text-align: center;
|
||||||
|
text-decoration: none; box-sizing: border-box; }}
|
||||||
|
.gitea-btn:hover {{ background: #4e7d1e; }}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="box">
|
||||||
|
<h2>Heartbeat</h2>
|
||||||
|
{'<p class="error">Invalid username, password, or OAuth error.</p>' if error else ''}
|
||||||
|
<form method="post">
|
||||||
|
<div class="field"><label>Username</label><input name="username" autofocus></div>
|
||||||
|
<div class="field"><label>Password</label><input name="password" type="password"></div>
|
||||||
|
<button type="submit">Sign in</button>
|
||||||
|
</form>{gitea_button}
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>"""
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Add the `_gitea_cfg_url` helper**
|
||||||
|
|
||||||
|
Add this small helper in `hbd/server/http.py` just before the `login_page` handler (around line 600) so the template can read the Gitea display URL without importing internal oauth details:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _gitea_cfg_url(config: dict) -> str:
|
||||||
|
return config.get("oauth", {}).get("gitea", {}).get("url", "")
|
||||||
|
```
|
||||||
|
|
||||||
|
Also update the `login_page` handler's `error` logic to show the error when the `?error=1` query param is present (set by the callback on OAuth failure):
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def login_page(request):
|
||||||
|
"""GET /login — show login form; POST /login — process and redirect."""
|
||||||
|
if not users_mod.users_enabled():
|
||||||
|
raise web.HTTPFound("/")
|
||||||
|
|
||||||
|
error = ""
|
||||||
|
if request.method == "POST":
|
||||||
|
form = await request.post()
|
||||||
|
username = form.get("username", "")
|
||||||
|
password = form.get("password", "")
|
||||||
|
user = users_mod.authenticate(username, password)
|
||||||
|
if user:
|
||||||
|
token = users_mod.create_session(username)
|
||||||
|
redirect_to = request.rel_url.query.get("next", "/")
|
||||||
|
resp = web.HTTPFound(redirect_to)
|
||||||
|
resp.set_cookie(
|
||||||
|
SESSION_COOKIE,
|
||||||
|
token,
|
||||||
|
max_age=users_mod.SESSION_TTL,
|
||||||
|
httponly=True,
|
||||||
|
samesite="Lax",
|
||||||
|
)
|
||||||
|
raise resp
|
||||||
|
error = "Invalid username or password."
|
||||||
|
elif request.rel_url.query.get("error"):
|
||||||
|
error = "Sign-in failed. Please try again."
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 3: Manual verification**
|
||||||
|
|
||||||
|
Start the server with OAuth configured. Visit `/login`. Confirm:
|
||||||
|
- The "Sign in with Gitea" button appears (green, below a divider)
|
||||||
|
- Clicking it redirects to Gitea
|
||||||
|
- After authorising on Gitea, you are redirected back and land on `/` with a valid session cookie
|
||||||
|
|
||||||
|
Without OAuth configured, confirm the button does not appear.
|
||||||
|
|
||||||
|
- [ ] **Step 4: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add hbd/server/http.py
|
||||||
|
git commit -m "feat: add Sign in with Gitea button to login page"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Self-Review Notes
|
||||||
|
|
||||||
|
- All 5 spec requirements covered: coexist ✓, auto-provision ✓, regular user ✓, any Gitea user ✓, config-driven ✓
|
||||||
|
- `exchange_code` signature in Task 4 matches usage in Task 5 (`config, code, redirect_uri`) ✓
|
||||||
|
- `fetch_user` returns `{login, full_name, avatar_url}` — matched in callback handler ✓
|
||||||
|
- `validate_state` removes state on use (replay protection) ✓
|
||||||
|
- `provision_oauth_user` skips empty strings so existing avatar/name aren't erased ✓
|
||||||
|
- `_gitea_cfg_url` is a plain `def`, not `async` — safe to call in template prep ✓
|
||||||
@@ -0,0 +1,184 @@
|
|||||||
|
# Gitea OAuth2 Authentication — Design Spec
|
||||||
|
|
||||||
|
Date: 2026-05-08
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Add Gitea as an OAuth2 login provider alongside the existing username/password
|
||||||
|
authentication. Any user on the configured Gitea instance can sign in; their
|
||||||
|
local account is auto-provisioned on first login as a regular (non-admin) user.
|
||||||
|
Password login continues to work unchanged.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Config
|
||||||
|
|
||||||
|
A new optional `oauth.gitea` block in `~/.hb.yaml`. OAuth is disabled when the
|
||||||
|
block is absent or any of the three required keys is missing.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
oauth:
|
||||||
|
gitea:
|
||||||
|
url: https://git.example.com # Gitea base URL, no trailing slash
|
||||||
|
client_id: <gitea-app-client-id>
|
||||||
|
client_secret: <gitea-app-client-secret>
|
||||||
|
```
|
||||||
|
|
||||||
|
**Gitea setup:** Create an OAuth2 application in Gitea under
|
||||||
|
*Settings → Applications → OAuth2*. Set the redirect URI to
|
||||||
|
`https://<hbd-host>/login/oauth/gitea/callback`.
|
||||||
|
|
||||||
|
`config.py` default:
|
||||||
|
|
||||||
|
```python
|
||||||
|
"oauth": {},
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## New module: `hbd/server/oauth.py`
|
||||||
|
|
||||||
|
Owns all OAuth2 logic. No new dependencies — uses `aiohttp.ClientSession`
|
||||||
|
already present in the codebase.
|
||||||
|
|
||||||
|
### CSRF state store
|
||||||
|
|
||||||
|
```python
|
||||||
|
# state -> expires (float)
|
||||||
|
_states: dict[str, float] = {}
|
||||||
|
STATE_TTL = 600 # 10 minutes
|
||||||
|
```
|
||||||
|
|
||||||
|
`_states` is an in-memory dict. Entries are created on redirect and deleted on
|
||||||
|
use or expiry. A purge runs on every new state generation.
|
||||||
|
|
||||||
|
### Public API
|
||||||
|
|
||||||
|
| Function | Description |
|
||||||
|
|---|---|
|
||||||
|
| `is_enabled(config)` | Returns `True` when url, client_id, and client_secret are all set |
|
||||||
|
| `make_state()` | Generates a random state token, stores it with TTL, returns it |
|
||||||
|
| `validate_state(state)` | Returns `True` and removes the state if valid and unexpired |
|
||||||
|
| `authorization_url(config, state, redirect_uri)` | Builds the Gitea `/login/oauth/authorize` redirect URL with `client_id`, `redirect_uri`, `scope=user:email`, `state` |
|
||||||
|
| `exchange_code(config, code, redirect_uri)` async | POSTs to Gitea `/login/oauth/access_token` with code and redirect_uri, returns the access token string or raises `OAuthError` |
|
||||||
|
| `fetch_user(config, token)` async | GETs Gitea `/api/v1/user` with Bearer token, returns `{"login", "full_name", "avatar_url"}` or raises `OAuthError` |
|
||||||
|
|
||||||
|
### Error handling
|
||||||
|
|
||||||
|
`OAuthError(message)` is a module-level exception. The callback route catches it
|
||||||
|
and renders the login page with an error message — identical to an invalid
|
||||||
|
password error in UX terms.
|
||||||
|
|
||||||
|
Network timeouts use a 10-second `aiohttp` timeout. Any non-2xx response from
|
||||||
|
Gitea raises `OAuthError`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Change: `hbd/server/users.py`
|
||||||
|
|
||||||
|
One new function added to the public API:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def provision_oauth_user(username: str, full_name: str, avatar: str) -> User:
|
||||||
|
```
|
||||||
|
|
||||||
|
- If the username does not exist in the live `users` dict, creates a `User`
|
||||||
|
with no `password_hash` (so password login is impossible for this account)
|
||||||
|
and inserts it.
|
||||||
|
- If the username already exists (e.g. was defined in config with a password),
|
||||||
|
updates `full_name` and `avatar` from the OAuth profile and returns the
|
||||||
|
existing user unchanged in all other respects (preserving admin flag,
|
||||||
|
notification channels, etc.).
|
||||||
|
- Logs a one-line INFO message on first provision.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Changes: `hbd/server/http.py`
|
||||||
|
|
||||||
|
### Two new route handlers
|
||||||
|
|
||||||
|
**`GET /login/oauth/gitea`**
|
||||||
|
|
||||||
|
1. Checks `oauth.is_enabled(config)` — returns 404 if not.
|
||||||
|
2. Calls `oauth.make_state()`.
|
||||||
|
3. Constructs `redirect_uri` as `{request.url.origin()}/login/oauth/gitea/callback` using aiohttp's `request.url.origin()`.
|
||||||
|
4. Redirects the browser to `oauth.authorization_url(config, state, redirect_uri)`.
|
||||||
|
|
||||||
|
**`GET /login/oauth/gitea/callback`**
|
||||||
|
|
||||||
|
1. Reads `code` and `state` query params; returns 400 if either is missing.
|
||||||
|
2. Calls `oauth.validate_state(state)` — redirects to `/login` with error if
|
||||||
|
invalid (CSRF or replay protection).
|
||||||
|
3. Reconstructs the same `redirect_uri` as the redirect handler (required by OAuth2 spec for token exchange).
|
||||||
|
4. Calls `await oauth.exchange_code(config, code, redirect_uri)` to get the access token.
|
||||||
|
4. Calls `await oauth.fetch_user(config, token)` to get the Gitea user profile.
|
||||||
|
5. Calls `users_mod.provision_oauth_user(login, full_name, avatar_url)`.
|
||||||
|
6. Calls `users_mod.create_session(username)` to get a session token.
|
||||||
|
7. Sets `hbd_session` cookie (same flags as password login: httponly, Lax,
|
||||||
|
24h TTL).
|
||||||
|
8. Redirects to `/`.
|
||||||
|
9. Any `OAuthError` re-renders the login page with a generic error message.
|
||||||
|
|
||||||
|
### Login page change
|
||||||
|
|
||||||
|
When `oauth.is_enabled(config)` is `True`, the existing login form gains a
|
||||||
|
separator and a "Sign in with Gitea" link button pointing to
|
||||||
|
`/login/oauth/gitea`. The password form is always rendered regardless.
|
||||||
|
|
||||||
|
### Route registration
|
||||||
|
|
||||||
|
```python
|
||||||
|
web.get("/login/oauth/gitea", oauth_redirect),
|
||||||
|
web.get("/login/oauth/gitea/callback", oauth_callback),
|
||||||
|
```
|
||||||
|
|
||||||
|
Added alongside the existing `/login` and `/logout` routes.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data flow
|
||||||
|
|
||||||
|
```
|
||||||
|
Browser hbd Gitea
|
||||||
|
| | |
|
||||||
|
|-- GET /login ----------->| |
|
||||||
|
|<- login page (+ button) -| |
|
||||||
|
| | |
|
||||||
|
|-- GET /login/oauth/gitea>| |
|
||||||
|
|<- 302 Gitea /authorize --| |
|
||||||
|
| | |
|
||||||
|
|-- GET /login/oauth/authorize ----------------------->|
|
||||||
|
|<- 302 /login/oauth/gitea/callback?code=..&state=.. --|
|
||||||
|
| | |
|
||||||
|
|-- GET /callback -------->| |
|
||||||
|
| |-- POST /access_token ---->|
|
||||||
|
| |<- {access_token} ---------|
|
||||||
|
| |-- GET /api/v1/user ------>|
|
||||||
|
| |<- {login, name, avatar} --|
|
||||||
|
| | provision_oauth_user() |
|
||||||
|
| | create_session() |
|
||||||
|
|<- 302 / (set cookie) ----| |
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
- `test_oauth_state`: `make_state` + `validate_state` happy path; expired state
|
||||||
|
returns False; replay (double-use) returns False.
|
||||||
|
- `test_provision_oauth_user_new`: new username creates User with no password.
|
||||||
|
- `test_provision_oauth_user_existing`: existing config user updates name/avatar,
|
||||||
|
preserves admin flag and notification_channels.
|
||||||
|
- `test_oauth_callback_invalid_state`: callback with bad state redirects to login.
|
||||||
|
- Integration: mock Gitea endpoints with `aiohttp_client` fixture; full
|
||||||
|
redirect → callback → session cookie flow.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Out of scope
|
||||||
|
|
||||||
|
- Restricting login to specific Gitea organisations or teams.
|
||||||
|
- Making OAuth users admin automatically.
|
||||||
|
- Multiple OAuth providers.
|
||||||
|
- Token refresh (Gitea access tokens are long-lived; the hbd session TTL governs
|
||||||
|
re-authentication).
|
||||||
+1
-1
@@ -14,4 +14,4 @@ Install options:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ["__version__"]
|
__all__ = ["__version__"]
|
||||||
__version__ = "5.2.1"
|
__version__ = "5.2.6"
|
||||||
|
|||||||
@@ -16,6 +16,9 @@ CLIENT_DEFAULTS = {
|
|||||||
"hb_port": 50003, # Port where hbd servers listen
|
"hb_port": 50003, # Port where hbd servers listen
|
||||||
"interval": 10, # Heartbeat interval in seconds
|
"interval": 10, # Heartbeat interval in seconds
|
||||||
|
|
||||||
|
# Host identity
|
||||||
|
"owner": None, # Optional username to set as this host's owner on the server
|
||||||
|
|
||||||
# Runtime flags
|
# Runtime flags
|
||||||
"foreground": False,
|
"foreground": False,
|
||||||
"verbose": False,
|
"verbose": False,
|
||||||
|
|||||||
+37
-14
@@ -21,6 +21,7 @@ from typing import Dict, List, Optional
|
|||||||
# Import protocol and config
|
# Import protocol and config
|
||||||
from .config import load_config
|
from .config import load_config
|
||||||
from ..common.proto import dicttos, stodict
|
from ..common.proto import dicttos, stodict
|
||||||
|
from .. import __version__
|
||||||
|
|
||||||
# Import plugin system
|
# Import plugin system
|
||||||
from .plugin import PluginRegistry, PluginLoader, InfoPlugin, MonitorPlugin
|
from .plugin import PluginRegistry, PluginLoader, InfoPlugin, MonitorPlugin
|
||||||
@@ -58,6 +59,7 @@ class AsyncConnection:
|
|||||||
self._dead = False
|
self._dead = False
|
||||||
self._ever_opened = False
|
self._ever_opened = False
|
||||||
self._open_fail_count = 0 # consecutive failures before first success
|
self._open_fail_count = 0 # consecutive failures before first success
|
||||||
|
self.request_info_event: asyncio.Event = asyncio.Event()
|
||||||
|
|
||||||
self.logger = logging.getLogger(f"hbc.conn.{addr}")
|
self.logger = logging.getLogger(f"hbc.conn.{addr}")
|
||||||
|
|
||||||
@@ -137,6 +139,9 @@ class AsyncConnection:
|
|||||||
|
|
||||||
self.ackcount += 1
|
self.ackcount += 1
|
||||||
self.logger.debug(f"ACK received, RTT: {rtt:.1f}ms")
|
self.logger.debug(f"ACK received, RTT: {rtt:.1f}ms")
|
||||||
|
if msg.get("request_update"):
|
||||||
|
self.logger.info("server requested plugin info refresh")
|
||||||
|
self.request_info_event.set()
|
||||||
|
|
||||||
|
|
||||||
class HeartbeatProtocol(asyncio.DatagramProtocol):
|
class HeartbeatProtocol(asyncio.DatagramProtocol):
|
||||||
@@ -172,9 +177,8 @@ class HeartbeatProtocol(asyncio.DatagramProtocol):
|
|||||||
self.logger.error(f"Error processing datagram: {e}", exc_info=True)
|
self.logger.error(f"Error processing datagram: {e}", exc_info=True)
|
||||||
|
|
||||||
def error_received(self, exc):
|
def error_received(self, exc):
|
||||||
"""Handle protocol errors."""
|
"""Handle protocol errors — close transport so the heartbeat sender retries."""
|
||||||
self.logger.warning(f"Protocol error on {self.connection.addr}: {exc} — dropping connection")
|
self.logger.warning(f"Protocol error on {self.connection.addr}: {exc} — will retry")
|
||||||
self.connection._dead = True
|
|
||||||
self.connection.close()
|
self.connection.close()
|
||||||
|
|
||||||
|
|
||||||
@@ -338,6 +342,26 @@ async def heartbeat_sender(conn: AsyncConnection, interval: int):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _info_plugin_refresh_loop(conn: AsyncConnection, info_plugins: List):
|
||||||
|
"""Wait for server requests to re-send InfoPlugin data."""
|
||||||
|
logger = logging.getLogger("hbc.plugins")
|
||||||
|
while running:
|
||||||
|
await conn.request_info_event.wait()
|
||||||
|
if not running:
|
||||||
|
break
|
||||||
|
conn.request_info_event.clear()
|
||||||
|
logger.info("refreshing InfoPlugins on server request")
|
||||||
|
for plugin in info_plugins:
|
||||||
|
plugin._cache = None
|
||||||
|
try:
|
||||||
|
data = await plugin.collect()
|
||||||
|
if data:
|
||||||
|
await conn.sendto({"plugin": plugin.name, **data}, "PLG")
|
||||||
|
logger.info(f"Resent {plugin.name} data")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error re-collecting {plugin.name}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
async def plugin_collector(conn: AsyncConnection, registry: PluginRegistry):
|
async def plugin_collector(conn: AsyncConnection, registry: PluginRegistry):
|
||||||
"""Collect and send plugin data.
|
"""Collect and send plugin data.
|
||||||
|
|
||||||
@@ -369,16 +393,13 @@ async def plugin_collector(conn: AsyncConnection, registry: PluginRegistry):
|
|||||||
for plugin in monitor_plugins:
|
for plugin in monitor_plugins:
|
||||||
by_interval[plugin.interval].append(plugin)
|
by_interval[plugin.interval].append(plugin)
|
||||||
|
|
||||||
# Create tasks for each interval
|
# Create tasks for each interval; always include the info-refresh watcher
|
||||||
tasks = []
|
tasks = [asyncio.create_task(_info_plugin_refresh_loop(conn, info_plugins))]
|
||||||
for interval, plugins in by_interval.items():
|
for interval, plugins in by_interval.items():
|
||||||
task = asyncio.create_task(
|
tasks.append(asyncio.create_task(
|
||||||
plugin_collector_interval(conn, plugins, interval)
|
plugin_collector_interval(conn, plugins, interval)
|
||||||
)
|
))
|
||||||
tasks.append(task)
|
|
||||||
|
|
||||||
# Wait for all tasks
|
|
||||||
if tasks:
|
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@@ -464,7 +485,7 @@ async def cleanup(connections: List[AsyncConnection]):
|
|||||||
logger.info("Cleaning up connections")
|
logger.info("Cleaning up connections")
|
||||||
|
|
||||||
target = next((c for c in connections if c.transport), connections[0] if connections else None)
|
target = next((c for c in connections if c.transport), connections[0] if connections else None)
|
||||||
if target:
|
if target and send_shutdown:
|
||||||
try:
|
try:
|
||||||
await target.sendto({"shutdown": 1, "acks": target.ackcount})
|
await target.sendto({"shutdown": 1, "acks": target.ackcount})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -478,7 +499,7 @@ async def cleanup(connections: List[AsyncConnection]):
|
|||||||
|
|
||||||
async def async_main(args, config):
|
async def async_main(args, config):
|
||||||
"""Async main function."""
|
"""Async main function."""
|
||||||
global running, shutdown_event, active_tasks
|
global running, shutdown_event, active_tasks, send_shutdown
|
||||||
|
|
||||||
# Create shutdown event
|
# Create shutdown event
|
||||||
shutdown_event = asyncio.Event()
|
shutdown_event = asyncio.Event()
|
||||||
@@ -495,8 +516,7 @@ async def async_main(args, config):
|
|||||||
hb_port = config.get("hb_port", PORT)
|
hb_port = config.get("hb_port", PORT)
|
||||||
interval = config.get("interval", INTERVAL)
|
interval = config.get("interval", INTERVAL)
|
||||||
|
|
||||||
logger.info(f"Starting hbc for {iam} -> {hb_hosts}")
|
logger.info(f"hbc {__version__} on {iam} -> {hb_hosts} port={hb_port}, interval={interval}s")
|
||||||
logger.info(f"Port: {hb_port}, Interval: {interval}s")
|
|
||||||
|
|
||||||
# Create connections
|
# Create connections
|
||||||
connections = []
|
connections = []
|
||||||
@@ -526,10 +546,13 @@ async def async_main(args, config):
|
|||||||
logger.info(f"Created {len(connections)} connections")
|
logger.info(f"Created {len(connections)} connections")
|
||||||
|
|
||||||
# Send boot/message if requested
|
# Send boot/message if requested
|
||||||
|
send_shutdown = False
|
||||||
if args.boot or args.message:
|
if args.boot or args.message:
|
||||||
boot_msg = {}
|
boot_msg = {}
|
||||||
if args.boot:
|
if args.boot:
|
||||||
boot_msg["boot"] = 1
|
boot_msg["boot"] = 1
|
||||||
|
args.boot = False # Clear boot flag so we don't send it again in main loop
|
||||||
|
send_shutdown = True
|
||||||
if args.message:
|
if args.message:
|
||||||
boot_msg["service"] = "service"
|
boot_msg["service"] = "service"
|
||||||
boot_msg["msg"] = args.message
|
boot_msg["msg"] = args.message
|
||||||
|
|||||||
@@ -364,7 +364,10 @@ class PluginLoader:
|
|||||||
|
|
||||||
# Instantiate plugin with config — check plugins subdict first,
|
# Instantiate plugin with config — check plugins subdict first,
|
||||||
# then top-level keys (e.g. nagios_runner: ... at root of config).
|
# then top-level keys (e.g. nagios_runner: ... at root of config).
|
||||||
plugin_instance_config = plugins_subconfig.get(obj.name) or raw_config.get(obj.name, {})
|
plugin_instance_config = dict(plugins_subconfig.get(obj.name) or raw_config.get(obj.name) or {})
|
||||||
|
# Propagate top-level owner so os_info (and any future plugin) can report it.
|
||||||
|
if "owner" in raw_config and "owner" not in plugin_instance_config:
|
||||||
|
plugin_instance_config["owner"] = raw_config["owner"]
|
||||||
plugin = obj(config=plugin_instance_config)
|
plugin = obj(config=plugin_instance_config)
|
||||||
|
|
||||||
# Initialize plugin
|
# Initialize plugin
|
||||||
|
|||||||
@@ -62,6 +62,9 @@ class OSInfoPlugin(InfoPlugin):
|
|||||||
"hbc_version": hbc_version,
|
"hbc_version": hbc_version,
|
||||||
"hbc_type": "full",
|
"hbc_type": "full",
|
||||||
}
|
}
|
||||||
|
if self.config.get("owner"):
|
||||||
|
self.logger.debug(f"Adding owner from config: {self.config['owner']}")
|
||||||
|
data["owner"] = self.config["owner"]
|
||||||
|
|
||||||
# Add Linux-specific distribution info
|
# Add Linux-specific distribution info
|
||||||
if platform.system() == "Linux":
|
if platform.system() == "Linux":
|
||||||
|
|||||||
@@ -13,12 +13,8 @@ plugins:
|
|||||||
count: 3 # ICMP packets per ping run (default 3)
|
count: 3 # ICMP packets per ping run (default 3)
|
||||||
timeout: 5 # seconds before a host is considered unreachable (default 5)
|
timeout: 5 # seconds before a host is considered unreachable (default 5)
|
||||||
hosts:
|
hosts:
|
||||||
8.8.8.8:
|
- 8.8.8.8
|
||||||
warning: 20.0 # ms
|
- 192.168.1.1
|
||||||
critical: 100.0 # ms
|
|
||||||
192.168.1.1:
|
|
||||||
warning: 5.0
|
|
||||||
critical: 20.0
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Reported metrics per host (metric key uses the hostname with dots/colons replaced
|
Reported metrics per host (metric key uses the hostname with dots/colons replaced
|
||||||
|
|||||||
@@ -89,8 +89,18 @@ class ZFSMonitorPlugin(MonitorPlugin):
|
|||||||
name = parts[0].strip()
|
name = parts[0].strip()
|
||||||
if self._pools_filter and name not in self._pools_filter:
|
if self._pools_filter and name not in self._pools_filter:
|
||||||
continue
|
continue
|
||||||
|
health = parts[1].strip()
|
||||||
|
if health == "ONLINE":
|
||||||
|
status = 0
|
||||||
|
elif health in ("DEGRADED", "ONLINE with errors"):
|
||||||
|
status = 1
|
||||||
|
elif health in ("FAULTED", "OFFLINE", "UNAVAIL"):
|
||||||
|
status = 2
|
||||||
|
else:
|
||||||
|
status = 3 # unknown status
|
||||||
pools[name] = {
|
pools[name] = {
|
||||||
"health": parts[1].strip(),
|
"health": health,
|
||||||
|
"status": status,
|
||||||
"size": _int(parts[2]),
|
"size": _int(parts[2]),
|
||||||
"alloc": _int(parts[3]),
|
"alloc": _int(parts[3]),
|
||||||
"free": _int(parts[4]),
|
"free": _int(parts[4]),
|
||||||
|
|||||||
@@ -134,6 +134,30 @@ thresholds:
|
|||||||
hysteresis: 0.1
|
hysteresis: 0.1
|
||||||
enabled: true
|
enabled: true
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# ZFS Monitor Thresholds
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
zfs_monitor:
|
||||||
|
# Pool health check — built-in default; shown here for reference/override.
|
||||||
|
# status is 0 (ONLINE) or 1 (DEGRADED) or 2 (SUSPENDED, FAULTED, UNAVAIL…).
|
||||||
|
# Use '*' to apply the same rule to every pool, or name a specific pool.
|
||||||
|
pools:
|
||||||
|
'*':
|
||||||
|
status:
|
||||||
|
warning: 1 # Alert WARNING when pool is DEGRADED
|
||||||
|
critical: 2 # Alert CRITICAL when pool is SUSPENDED/FAULTED/UNAVAIL
|
||||||
|
operator: ">"
|
||||||
|
hysteresis: 0.0 # No hysteresis — a degraded pool is always critical
|
||||||
|
display: "ZFS pool {pool_name} is {health}"
|
||||||
|
|
||||||
|
# Per-pool capacity thresholds (optional; add pools you care about)
|
||||||
|
# tank:
|
||||||
|
# capacity:
|
||||||
|
# warning: 75.0 # Warn at 75% used
|
||||||
|
# critical: 90.0 # Critical at 90% used
|
||||||
|
# operator: ">"
|
||||||
|
# hysteresis: 0.05
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------
|
||||||
# Network Monitor Thresholds
|
# Network Monitor Thresholds
|
||||||
# ----------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------
|
||||||
|
|||||||
+17
-1
@@ -34,6 +34,9 @@ SERVER_DEFAULTS = {
|
|||||||
"users": {}, # username -> {full_name, avatar, password, admin, notification_channels}
|
"users": {}, # username -> {full_name, avatar, password, admin, notification_channels}
|
||||||
"default_owner": None, # Username that owns hosts with no explicit owner
|
"default_owner": None, # Username that owns hosts with no explicit owner
|
||||||
|
|
||||||
|
# OAuth2 providers
|
||||||
|
"oauth": {}, # oauth.gitea.{url,client_id,client_secret}
|
||||||
|
|
||||||
# Host management
|
# Host management
|
||||||
"hosts": {}, # Unified host definitions
|
"hosts": {}, # Unified host definitions
|
||||||
"dyndnshosts": [], # Hosts with dynamic DNS (legacy)
|
"dyndnshosts": [], # Hosts with dynamic DNS (legacy)
|
||||||
@@ -101,9 +104,22 @@ THRESHOLD_DEFAULTS = {
|
|||||||
'display': '{check_name} {output}',
|
'display': '{check_name} {output}',
|
||||||
'operator': "nagios"
|
'operator': "nagios"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
'zfs_monitor': {
|
||||||
|
'pools': {
|
||||||
|
'*': {
|
||||||
|
'status': {
|
||||||
|
'warning': 1,
|
||||||
|
'critical': 2,
|
||||||
|
'operator': '>',
|
||||||
|
'hysteresis': 0.0,
|
||||||
|
'display': 'ZFS pool {pool_name} is {health}'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_config(path=None):
|
def load_config(path=None):
|
||||||
@@ -309,7 +325,7 @@ def get_host_access(config, hostname) -> dict:
|
|||||||
"""
|
"""
|
||||||
host_cfg = get_host_config(config, hostname)
|
host_cfg = get_host_config(config, hostname)
|
||||||
|
|
||||||
owner = host_cfg.get("owner") or get_default_owner(config)
|
owner = host_cfg.get("owner") # or get_default_owner(config)
|
||||||
|
|
||||||
managers = host_cfg.get("managers", [])
|
managers = host_cfg.get("managers", [])
|
||||||
if isinstance(managers, str):
|
if isinstance(managers, str):
|
||||||
|
|||||||
+69
-1
@@ -16,6 +16,7 @@ from . import data
|
|||||||
from . import notify as notify_mod
|
from . import notify as notify_mod
|
||||||
from . import settings as settings_mod
|
from . import settings as settings_mod
|
||||||
from . import users as users_mod
|
from . import users as users_mod
|
||||||
|
from . import oauth as oauth_mod
|
||||||
from . import ws as ws_mod
|
from . import ws as ws_mod
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -538,6 +539,7 @@ async def start(
|
|||||||
"name": hostname,
|
"name": hostname,
|
||||||
"plugins": list(host.plugin_data.keys()),
|
"plugins": list(host.plugin_data.keys()),
|
||||||
"is_owner": _can_own_host(current_user, host),
|
"is_owner": _can_own_host(current_user, host),
|
||||||
|
"owner": host.owner,
|
||||||
})
|
})
|
||||||
|
|
||||||
tmpl = env.get_template("plugins.html")
|
tmpl = env.get_template("plugins.html")
|
||||||
@@ -620,6 +622,18 @@ async def start(
|
|||||||
)
|
)
|
||||||
raise resp
|
raise resp
|
||||||
error = "Invalid username or password."
|
error = "Invalid username or password."
|
||||||
|
elif request.rel_url.query.get("error"):
|
||||||
|
error = "Sign-in failed. Please try again."
|
||||||
|
|
||||||
|
gitea_button = ""
|
||||||
|
if oauth_mod.is_enabled(config):
|
||||||
|
logo_url = config.get("oauth", {}).get("gitea", {}).get("logo", "")
|
||||||
|
logo_img = f'<img src="{logo_url}" alt="" class="gitea-logo">' if logo_url else ""
|
||||||
|
gitea_button = f"""
|
||||||
|
<div class="divider">or</div>
|
||||||
|
<a href="/login/oauth/gitea" class="gitea-btn">
|
||||||
|
{logo_img}Sign in with Gitea
|
||||||
|
</a>"""
|
||||||
|
|
||||||
html = f"""<!DOCTYPE html>
|
html = f"""<!DOCTYPE html>
|
||||||
<html>
|
<html>
|
||||||
@@ -640,6 +654,14 @@ async def start(
|
|||||||
button:hover {{ background: #0055aa; }}
|
button:hover {{ background: #0055aa; }}
|
||||||
.error {{ color: #c00; font-size: .9em; margin-bottom: .8em; }}
|
.error {{ color: #c00; font-size: .9em; margin-bottom: .8em; }}
|
||||||
.field {{ margin-bottom: .9em; }}
|
.field {{ margin-bottom: .9em; }}
|
||||||
|
.divider {{ text-align: center; margin: 1.2em 0 .8em; color: #999;
|
||||||
|
font-size: .85em; border-top: 1px solid #eee; padding-top: .8em; }}
|
||||||
|
.gitea-btn {{ display: flex; align-items: center; justify-content: center;
|
||||||
|
gap: .5em; width: 100%; padding: .6em; background: #16191d;
|
||||||
|
color: #fff; border-radius: 4px; font-size: 1em; text-align: center;
|
||||||
|
text-decoration: none; box-sizing: border-box; }}
|
||||||
|
.gitea-btn:hover {{ background: #4e7d1e; }}
|
||||||
|
.gitea-logo {{ height: 1.2em; width: auto; vertical-align: middle; }}
|
||||||
</style>
|
</style>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
@@ -650,7 +672,7 @@ async def start(
|
|||||||
<div class="field"><label>Username</label><input name="username" autofocus></div>
|
<div class="field"><label>Username</label><input name="username" autofocus></div>
|
||||||
<div class="field"><label>Password</label><input name="password" type="password"></div>
|
<div class="field"><label>Password</label><input name="password" type="password"></div>
|
||||||
<button type="submit">Sign in</button>
|
<button type="submit">Sign in</button>
|
||||||
</form>
|
</form>{gitea_button}
|
||||||
</div>
|
</div>
|
||||||
</body>
|
</body>
|
||||||
</html>"""
|
</html>"""
|
||||||
@@ -896,6 +918,50 @@ async def start(
|
|||||||
)
|
)
|
||||||
return web.Response(text=body, content_type="text/html")
|
return web.Response(text=body, content_type="text/html")
|
||||||
|
|
||||||
|
def _oauth_redirect_uri(request) -> str:
|
||||||
|
base = config.get("base_url", "").rstrip("/") or str(request.url.origin())
|
||||||
|
return f"{base}/login/oauth/gitea/callback"
|
||||||
|
|
||||||
|
async def oauth_gitea_redirect(request):
|
||||||
|
"""GET /login/oauth/gitea — kick off the Gitea OAuth2 flow."""
|
||||||
|
if not oauth_mod.is_enabled(config):
|
||||||
|
return web.Response(status=404, text="OAuth not configured")
|
||||||
|
state = oauth_mod.make_state()
|
||||||
|
raise web.HTTPFound(oauth_mod.authorization_url(config, state, _oauth_redirect_uri(request)))
|
||||||
|
|
||||||
|
async def oauth_gitea_callback(request):
|
||||||
|
"""GET /login/oauth/gitea/callback — handle Gitea's redirect back."""
|
||||||
|
if not oauth_mod.is_enabled(config):
|
||||||
|
return web.Response(status=404, text="OAuth not configured")
|
||||||
|
code = request.rel_url.query.get("code", "")
|
||||||
|
state = request.rel_url.query.get("state", "")
|
||||||
|
if not code or not state:
|
||||||
|
return web.Response(status=400, text="Missing code or state")
|
||||||
|
if not oauth_mod.validate_state(state):
|
||||||
|
logger.warning("OAuth: invalid or expired state token from %s", request.remote)
|
||||||
|
raise web.HTTPFound("/login?error=1")
|
||||||
|
try:
|
||||||
|
token = await oauth_mod.exchange_code(config, code, _oauth_redirect_uri(request))
|
||||||
|
profile = await oauth_mod.fetch_user(config, token)
|
||||||
|
except oauth_mod.OAuthError as exc:
|
||||||
|
logger.warning("OAuth error: %s", exc)
|
||||||
|
raise web.HTTPFound("/login?error=1")
|
||||||
|
user = users_mod.provision_oauth_user(
|
||||||
|
profile["login"],
|
||||||
|
profile["full_name"],
|
||||||
|
profile["avatar_url"],
|
||||||
|
)
|
||||||
|
session_token = users_mod.create_session(user.username)
|
||||||
|
resp = web.HTTPFound("/")
|
||||||
|
resp.set_cookie(
|
||||||
|
SESSION_COOKIE,
|
||||||
|
session_token,
|
||||||
|
max_age=users_mod.SESSION_TTL,
|
||||||
|
httponly=True,
|
||||||
|
samesite="Lax",
|
||||||
|
)
|
||||||
|
raise resp
|
||||||
|
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.add_routes(
|
app.add_routes(
|
||||||
[
|
[
|
||||||
@@ -907,6 +973,8 @@ async def start(
|
|||||||
web.get("/logout", web_logout),
|
web.get("/logout", web_logout),
|
||||||
web.post("/api/0/auth/login", api_login),
|
web.post("/api/0/auth/login", api_login),
|
||||||
web.post("/api/0/auth/logout", api_logout),
|
web.post("/api/0/auth/logout", api_logout),
|
||||||
|
web.get("/login/oauth/gitea", oauth_gitea_redirect),
|
||||||
|
web.get("/login/oauth/gitea/callback", oauth_gitea_callback),
|
||||||
# Users
|
# Users
|
||||||
web.get("/api/0/users", api_users),
|
web.get("/api/0/users", api_users),
|
||||||
web.get("/api/0/users/me", api_user_self),
|
web.get("/api/0/users/me", api_user_self),
|
||||||
|
|||||||
+2
-1
@@ -475,7 +475,8 @@ def run(config, config_path=None):
|
|||||||
if config.get("debug", 0) > 0:
|
if config.get("debug", 0) > 0:
|
||||||
log_level = logging.DEBUG
|
log_level = logging.DEBUG
|
||||||
logging.basicConfig(level=log_level)
|
logging.basicConfig(level=log_level)
|
||||||
logging.getLogger("aiohttp.access").setLevel(logging.DEBUG)
|
if not config.get("debug", 0):
|
||||||
|
logging.getLogger("aiohttp.access").propagate = False
|
||||||
load_pickled_hosts(config, hbdclass)
|
load_pickled_hosts(config, hbdclass)
|
||||||
|
|
||||||
notify_mod.initlog(logfile=config.get("logfile", "messages.log"))
|
notify_mod.initlog(logfile=config.get("logfile", "messages.log"))
|
||||||
|
|||||||
+14
-5
@@ -106,11 +106,18 @@ def closelog():
|
|||||||
|
|
||||||
def eventlog(host, lvl, m, service=None):
|
def eventlog(host, lvl, m, service=None):
|
||||||
ts = time.time()
|
ts = time.time()
|
||||||
|
msg = {
|
||||||
|
"ts": ts,
|
||||||
|
"host": host or None,
|
||||||
|
"level": lvl,
|
||||||
|
"service": service,
|
||||||
|
"message": m,
|
||||||
|
}
|
||||||
|
data.msgs.append(msg)
|
||||||
s = f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))} {lvl} "
|
s = f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))} {lvl} "
|
||||||
if host:
|
if host:
|
||||||
s += f"{host} "
|
s += f"{host} "
|
||||||
s += m
|
s += m
|
||||||
data.msgs.append(s)
|
|
||||||
logger.info(s)
|
logger.info(s)
|
||||||
if logf:
|
if logf:
|
||||||
try:
|
try:
|
||||||
@@ -118,7 +125,7 @@ def eventlog(host, lvl, m, service=None):
|
|||||||
logf.flush()
|
logf.flush()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("failed to write to logfile: %s", e)
|
logger.warning("failed to write to logfile: %s", e)
|
||||||
msg_to_websockets("message", s)
|
msg_to_websockets("message", msg)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -134,9 +141,11 @@ def _send_pushover(channel_cfg: dict, notif: Notification) -> bool:
|
|||||||
logger.warning("pushover: missing token or user")
|
logger.warning("pushover: missing token or user")
|
||||||
return False
|
return False
|
||||||
params: dict = {"token": token, "user": user, "title": notif.title, "message": notif.body}
|
params: dict = {"token": token, "user": user, "title": notif.title, "message": notif.body}
|
||||||
|
if channel_cfg.get("sound"):
|
||||||
|
params["sound"] = channel_cfg["sound"]
|
||||||
if notif.url:
|
if notif.url:
|
||||||
params["url"] = notif.url
|
params["url"] = notif.url
|
||||||
params["url_title"] = "Plugin metrics"
|
params["url_title"] = "Heartbeat"
|
||||||
conn = http.client.HTTPSConnection("api.pushover.net:443")
|
conn = http.client.HTTPSConnection("api.pushover.net:443")
|
||||||
try:
|
try:
|
||||||
conn.request(
|
conn.request(
|
||||||
@@ -209,7 +218,7 @@ def _send_mattermost(channel_cfg: dict, notif: Notification) -> bool:
|
|||||||
return False
|
return False
|
||||||
text = f"**{notif.title}**\n{notif.body}"
|
text = f"**{notif.title}**\n{notif.body}"
|
||||||
if notif.url:
|
if notif.url:
|
||||||
text += f"\n[Plugin metrics]({notif.url})"
|
text += f"\n[Plugin metrics] {notif.url}"
|
||||||
ses = {"url": host, "scheme": "http", "basepath": "/api/v4", "port": 8065}
|
ses = {"url": host, "scheme": "http", "basepath": "/api/v4", "port": 8065}
|
||||||
mm = Driver(ses)
|
mm = Driver(ses)
|
||||||
payload: dict = {"text": text, "channel": channel, "username": channel_cfg.get("username", "hbd")}
|
payload: dict = {"text": text, "channel": channel, "username": channel_cfg.get("username", "hbd")}
|
||||||
@@ -392,7 +401,7 @@ def _build_url(host_name: str) -> str:
|
|||||||
base_url = _config.get("base_url", "").rstrip("/")
|
base_url = _config.get("base_url", "").rstrip("/")
|
||||||
if not base_url:
|
if not base_url:
|
||||||
return ""
|
return ""
|
||||||
return f"{base_url}/plugins#{host_name}"
|
return f"{base_url}/alerts?filter={host_name}"
|
||||||
|
|
||||||
|
|
||||||
async def send_notification(host_name: str, notif: Notification) -> dict:
|
async def send_notification(host_name: str, notif: Notification) -> dict:
|
||||||
|
|||||||
@@ -0,0 +1,142 @@
|
|||||||
|
"""Gitea OAuth2 support.
|
||||||
|
|
||||||
|
Config shape (in ~/.hb.yaml):
|
||||||
|
|
||||||
|
oauth:
|
||||||
|
gitea:
|
||||||
|
url: https://git.example.com
|
||||||
|
client_id: <client-id>
|
||||||
|
client_secret: <client-secret>
|
||||||
|
|
||||||
|
Register a Gitea OAuth2 application at:
|
||||||
|
Gitea → Settings → Applications → OAuth2
|
||||||
|
Set the redirect URI to:
|
||||||
|
https://<hbd-host>/login/oauth/gitea/callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
STATE_TTL = 600 # 10 minutes
|
||||||
|
|
||||||
|
# state_token -> expiry timestamp
|
||||||
|
_states: dict[str, float] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def make_state() -> str:
|
||||||
|
"""Generate a CSRF state token, store it with TTL, and return it."""
|
||||||
|
_purge_states()
|
||||||
|
token = secrets.token_hex(32)
|
||||||
|
_states[token] = time.time() + STATE_TTL
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def validate_state(state: str) -> bool:
|
||||||
|
"""Return True if *state* is known and unexpired; always removes it."""
|
||||||
|
expiry = _states.pop(state, None)
|
||||||
|
if expiry is None:
|
||||||
|
return False
|
||||||
|
return time.time() < expiry
|
||||||
|
|
||||||
|
|
||||||
|
def _purge_states() -> None:
|
||||||
|
"""Remove all expired CSRF state tokens from the in-memory store."""
|
||||||
|
now = time.time()
|
||||||
|
expired = [k for k, exp in list(_states.items()) if exp < now]
|
||||||
|
for k in expired:
|
||||||
|
del _states[k]
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthError(Exception):
|
||||||
|
"""Raised when the OAuth2 flow fails for any reason."""
|
||||||
|
|
||||||
|
|
||||||
|
def _gitea_cfg(config: dict) -> dict:
|
||||||
|
"""Return the gitea sub-dict or {} if absent/incomplete."""
|
||||||
|
return config.get("oauth", {}).get("gitea", {})
|
||||||
|
|
||||||
|
|
||||||
|
def is_enabled(config: dict) -> bool:
|
||||||
|
"""Return True when all three required Gitea OAuth keys are present."""
|
||||||
|
g = _gitea_cfg(config)
|
||||||
|
return bool(g.get("url") and g.get("client_id") and g.get("client_secret"))
|
||||||
|
|
||||||
|
|
||||||
|
def authorization_url(config: dict, state: str, redirect_uri: str) -> str:
|
||||||
|
"""Return the Gitea OAuth2 authorization URL to redirect the browser to."""
|
||||||
|
g = _gitea_cfg(config)
|
||||||
|
if not (g.get("url") and g.get("client_id") and g.get("client_secret")):
|
||||||
|
raise OAuthError("Gitea OAuth2 is not configured")
|
||||||
|
params = urllib.parse.urlencode({
|
||||||
|
"client_id": g["client_id"],
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": "user:email",
|
||||||
|
"state": state,
|
||||||
|
})
|
||||||
|
return f"{g['url'].rstrip('/')}/login/oauth/authorize?{params}"
|
||||||
|
|
||||||
|
|
||||||
|
async def exchange_code(config: dict, code: str, redirect_uri: str) -> str:
|
||||||
|
"""Exchange an authorization *code* for a Gitea access token.
|
||||||
|
|
||||||
|
Returns the access token string. Raises OAuthError on any failure.
|
||||||
|
"""
|
||||||
|
g = _gitea_cfg(config)
|
||||||
|
if not (g.get("url") and g.get("client_id") and g.get("client_secret")):
|
||||||
|
raise OAuthError("Gitea OAuth2 is not configured")
|
||||||
|
url = f"{g['url'].rstrip('/')}/login/oauth/access_token"
|
||||||
|
payload = {
|
||||||
|
"client_id": g["client_id"],
|
||||||
|
"client_secret": g["client_secret"],
|
||||||
|
"code": code,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
}
|
||||||
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
async with session.post(url, json=payload, headers={"Accept": "application/json"}) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
text = await resp.text()
|
||||||
|
raise OAuthError(f"Token exchange failed ({resp.status}): {text}")
|
||||||
|
data = await resp.json()
|
||||||
|
token = data.get("access_token")
|
||||||
|
if not token:
|
||||||
|
raise OAuthError(f"No access_token in response: {data}")
|
||||||
|
except aiohttp.ClientError as exc:
|
||||||
|
raise OAuthError(f"Token exchange network error: {exc}") from exc
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_user(config: dict, token: str) -> dict:
|
||||||
|
"""Fetch the authenticated user's profile from Gitea.
|
||||||
|
|
||||||
|
Returns a dict with keys: login, full_name, avatar_url.
|
||||||
|
Raises OAuthError on any failure.
|
||||||
|
"""
|
||||||
|
g = _gitea_cfg(config)
|
||||||
|
if not (g.get("url") and g.get("client_id") and g.get("client_secret")):
|
||||||
|
raise OAuthError("Gitea OAuth2 is not configured")
|
||||||
|
url = f"{g['url'].rstrip('/')}/api/v1/user"
|
||||||
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
async with session.get(url, headers={"Authorization": f"token {token}"}) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
text = await resp.text()
|
||||||
|
raise OAuthError(f"User fetch failed ({resp.status}): {text}")
|
||||||
|
data = await resp.json()
|
||||||
|
except aiohttp.ClientError as exc:
|
||||||
|
raise OAuthError(f"User fetch network error: {exc}") from exc
|
||||||
|
return {
|
||||||
|
"login": data.get("login", ""),
|
||||||
|
"full_name": data.get("full_name", ""),
|
||||||
|
"avatar_url": data.get("avatar_url", ""),
|
||||||
|
}
|
||||||
@@ -94,6 +94,24 @@
|
|||||||
border-color: #2196f3;
|
border-color: #2196f3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.filter-input {
|
||||||
|
padding: 7px 12px;
|
||||||
|
border: 2px solid #ddd;
|
||||||
|
border-radius: 20px;
|
||||||
|
font-size: 0.9em;
|
||||||
|
outline: none;
|
||||||
|
width: 200px;
|
||||||
|
transition: border-color 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.filter-input:focus {
|
||||||
|
border-color: #2196f3;
|
||||||
|
}
|
||||||
|
|
||||||
|
.filter-input.invalid {
|
||||||
|
border-color: #f44336;
|
||||||
|
}
|
||||||
|
|
||||||
.alerts-container {
|
.alerts-container {
|
||||||
background: white;
|
background: white;
|
||||||
border-radius: 8px;
|
border-radius: 8px;
|
||||||
@@ -184,9 +202,9 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
.alert-metric {
|
.alert-metric {
|
||||||
color: #666;
|
color: #0066cc;
|
||||||
font-family: 'Courier New', monospace;
|
font-size: 1.1em;
|
||||||
font-size: 0.9em;
|
font-weight: normal;
|
||||||
}
|
}
|
||||||
|
|
||||||
.alert-details {
|
.alert-details {
|
||||||
@@ -316,6 +334,7 @@
|
|||||||
<button class="filter-button active" onclick="filterAlerts('all')">All</button>
|
<button class="filter-button active" onclick="filterAlerts('all')">All</button>
|
||||||
<button class="filter-button" onclick="filterAlerts('critical')">Critical Only</button>
|
<button class="filter-button" onclick="filterAlerts('critical')">Critical Only</button>
|
||||||
<button class="filter-button" onclick="filterAlerts('warning')">Warning Only</button>
|
<button class="filter-button" onclick="filterAlerts('warning')">Warning Only</button>
|
||||||
|
<input id="host-filter" class="filter-input" type="text" placeholder="host filter (regex)" oninput="onHostFilterInput(this)">
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="alerts-container">
|
<div class="alerts-container">
|
||||||
@@ -332,6 +351,7 @@
|
|||||||
<script>
|
<script>
|
||||||
let currentFilter = 'all';
|
let currentFilter = 'all';
|
||||||
let allAlerts = [];
|
let allAlerts = [];
|
||||||
|
let hostFilterRe = null;
|
||||||
|
|
||||||
async function loadAlerts() {
|
async function loadAlerts() {
|
||||||
try {
|
try {
|
||||||
@@ -366,10 +386,13 @@
|
|||||||
// Filter alerts based on current filter
|
// Filter alerts based on current filter
|
||||||
let filteredAlerts = alerts;
|
let filteredAlerts = alerts;
|
||||||
if (currentFilter !== 'all') {
|
if (currentFilter !== 'all') {
|
||||||
filteredAlerts = alerts.filter(alert =>
|
filteredAlerts = filteredAlerts.filter(alert =>
|
||||||
alert.level.toLowerCase() === currentFilter
|
alert.level.toLowerCase() === currentFilter
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
if (hostFilterRe) {
|
||||||
|
filteredAlerts = filteredAlerts.filter(alert => hostFilterRe.test(alert.hostname));
|
||||||
|
}
|
||||||
|
|
||||||
if (filteredAlerts.length === 0) {
|
if (filteredAlerts.length === 0) {
|
||||||
if (currentFilter === 'all' && alerts.length === 0) {
|
if (currentFilter === 'all' && alerts.length === 0) {
|
||||||
@@ -438,8 +461,8 @@
|
|||||||
<div class="alert-header">
|
<div class="alert-header">
|
||||||
<span class="alert-level ${level}">${alert.level}</span>
|
<span class="alert-level ${level}">${alert.level}</span>
|
||||||
<a class="alert-hostname" href="/plugins#${alert.hostname}">${alert.hostname}</a>
|
<a class="alert-hostname" href="/plugins#${alert.hostname}">${alert.hostname}</a>
|
||||||
|
<span class="alert-metric">${(alert.metric_path.includes('.') ? alert.metric_path.slice(alert.metric_path.indexOf('.') + 1) : alert.metric_path).replace(/_status_code$/, '')}</span>
|
||||||
</div>
|
</div>
|
||||||
<div class="alert-metric">${alert.metric_path}</div>
|
|
||||||
<div class="alert-details">
|
<div class="alert-details">
|
||||||
<span>${valueText}</span>
|
<span>${valueText}</span>
|
||||||
<span class="alert-duration">Active for ${duration}</span>
|
<span class="alert-duration">Active for ${duration}</span>
|
||||||
@@ -538,9 +561,36 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function onHostFilterInput(input) {
|
||||||
|
const val = input.value.trim();
|
||||||
|
if (!val) {
|
||||||
|
hostFilterRe = null;
|
||||||
|
input.classList.remove('invalid');
|
||||||
|
} else {
|
||||||
|
try {
|
||||||
|
hostFilterRe = new RegExp(val, 'i');
|
||||||
|
input.classList.remove('invalid');
|
||||||
|
} catch (_) {
|
||||||
|
hostFilterRe = null;
|
||||||
|
input.classList.add('invalid');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
renderAlerts(allAlerts);
|
||||||
|
}
|
||||||
|
|
||||||
// Auto-refresh every 15 seconds
|
// Auto-refresh every 15 seconds
|
||||||
setInterval(loadAlerts, 15000);
|
setInterval(loadAlerts, 15000);
|
||||||
|
|
||||||
|
// Initialise filter from URL query string (?filter=...)
|
||||||
|
(function () {
|
||||||
|
const param = new URLSearchParams(window.location.search).get('filter');
|
||||||
|
if (param) {
|
||||||
|
const input = document.getElementById('host-filter');
|
||||||
|
input.value = param;
|
||||||
|
onHostFilterInput(input);
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
// Initial load
|
// Initial load
|
||||||
loadAlerts();
|
loadAlerts();
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -214,7 +214,7 @@
|
|||||||
ctx.restore();
|
ctx.restore();
|
||||||
}
|
}
|
||||||
|
|
||||||
hand((m + s / 60) / 60 * Math.PI * 2 - Math.PI / 2,
|
hand((sFrac >= 58.5 ? m + 1 : m) / 60 * Math.PI * 2 - Math.PI / 2,
|
||||||
R * 0.88, -R * 0.12, SIZE * 0.027, '#222'); /* minute */
|
R * 0.88, -R * 0.12, SIZE * 0.027, '#222'); /* minute */
|
||||||
hand((h + m / 60) / 12 * Math.PI * 2 - Math.PI / 2,
|
hand((h + m / 60) / 12 * Math.PI * 2 - Math.PI / 2,
|
||||||
R * 0.58, -R * 0.12, SIZE * 0.039, '#222'); /* hour */
|
R * 0.58, -R * 0.12, SIZE * 0.039, '#222'); /* hour */
|
||||||
|
|||||||
@@ -183,11 +183,24 @@
|
|||||||
line-height: 1.0;
|
line-height: 1.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#messages div {
|
#messages .log-entry {
|
||||||
padding: 5px 0;
|
padding: 5px 0;
|
||||||
border-bottom: 1px solid #f0f0f0;
|
border-bottom: 1px solid #f0f0f0;
|
||||||
|
display: flex;
|
||||||
|
gap: 0.5em;
|
||||||
|
align-items: baseline;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.log-ts { color: #888; white-space: nowrap; }
|
||||||
|
.log-level { font-weight: bold; min-width: 6em; }
|
||||||
|
.log-host { font-weight: 600; }
|
||||||
|
.log-service { color: #888; }
|
||||||
|
|
||||||
|
.log-warning .log-level { color: #b8860b; }
|
||||||
|
.log-critical .log-level { color: #c00; }
|
||||||
|
.log-recover .log-level { color: #2a7a2a; }
|
||||||
|
.log-info .log-level { color: #555; }
|
||||||
|
|
||||||
/* Modal for connection status messages */
|
/* Modal for connection status messages */
|
||||||
.connection-modal {
|
.connection-modal {
|
||||||
display: none;
|
display: none;
|
||||||
@@ -460,7 +473,20 @@
|
|||||||
update_table(state.data);
|
update_table(state.data);
|
||||||
} else if (state.type == "message") {
|
} else if (state.type == "message") {
|
||||||
var msgs = document.getElementById("messages");
|
var msgs = document.getElementById("messages");
|
||||||
msgs.insertAdjacentHTML("afterbegin", "<div>" + state.data + "</div>");
|
var msg = state.data;
|
||||||
|
var _d = new Date(msg.ts * 1000);
|
||||||
|
function _p(n) { return n < 10 ? '0' + n : '' + n; }
|
||||||
|
var ts_str = _d.getFullYear() + '-' + _p(_d.getMonth()+1) + '-' + _p(_d.getDate())
|
||||||
|
+ ' ' + _p(_d.getHours()) + ':' + _p(_d.getMinutes()) + ':' + _p(_d.getSeconds());
|
||||||
|
var lvl = (msg.level || "INFO").toLowerCase();
|
||||||
|
var html = '<div class="log-entry log-' + lvl + '">';
|
||||||
|
html += '<span class="log-ts">' + ts_str + '</span>';
|
||||||
|
html += '<span class="log-level">' + (msg.level || "") + '</span>';
|
||||||
|
if (msg.host) html += '<span class="log-host">' + msg.host + '</span>';
|
||||||
|
if (msg.service) html += '<span class="log-service">' + msg.service + '</span>';
|
||||||
|
html += '<span class="log-msg">' + msg.message + '</span>';
|
||||||
|
html += '</div>';
|
||||||
|
msgs.insertAdjacentHTML("afterbegin", html);
|
||||||
}
|
}
|
||||||
cnt++;
|
cnt++;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -416,7 +416,8 @@
|
|||||||
<span class="host-name">{{ host.name }}</span>
|
<span class="host-name">{{ host.name }}</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="glance-strip" id="glance-{{ host.name }}">
|
<div class="glance-strip" id="glance-{{ host.name }}" data-owner="{{ host.owner or '' }}">
|
||||||
|
{% if current_user and current_user.admin and host.owner %}<span class="glance-chip neutral">{{ host.owner }}</span>{% endif %}
|
||||||
<span class="glance-loading">—</span>
|
<span class="glance-loading">—</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -480,6 +481,7 @@
|
|||||||
const GLANCE_PLUGINS = ['cpu_monitor','memory_monitor','disk_monitor',
|
const GLANCE_PLUGINS = ['cpu_monitor','memory_monitor','disk_monitor',
|
||||||
'network_monitor','nagios_runner','os_info'];
|
'network_monitor','nagios_runner','os_info'];
|
||||||
const SKIP_FIELDS = new Set(['id','name']);
|
const SKIP_FIELDS = new Set(['id','name']);
|
||||||
|
const CURRENT_USER_ADMIN = {{ 'true' if current_user and current_user.admin else 'false' }};
|
||||||
|
|
||||||
// ── Cache ───────────────────────────────────────────────────────────────
|
// ── Cache ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -558,6 +560,12 @@
|
|||||||
|
|
||||||
const chips = [];
|
const chips = [];
|
||||||
|
|
||||||
|
// Owner (admin only, static from server)
|
||||||
|
const owner = strip.dataset.owner;
|
||||||
|
if (CURRENT_USER_ADMIN && owner) {
|
||||||
|
chips.push(`<span class="glance-chip neutral">${owner}</span>`);
|
||||||
|
}
|
||||||
|
|
||||||
// CPU
|
// CPU
|
||||||
const cpu = getCache(hostname, 'cpu_monitor');
|
const cpu = getCache(hostname, 'cpu_monitor');
|
||||||
if (cpu) {
|
if (cpu) {
|
||||||
|
|||||||
+118
-10
@@ -575,10 +575,13 @@ class ThresholdChecker:
|
|||||||
if not isinstance(threshold_config, dict):
|
if not isinstance(threshold_config, dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Handle nested metrics (e.g., partitions./.percent)
|
# Handle nested metrics (e.g., partitions./.percent or pools.*.status)
|
||||||
if metric_name == "partitions":
|
if metric_name == "partitions":
|
||||||
self._parse_partition_thresholds(plugin_name, threshold_config, target_dict)
|
self._parse_partition_thresholds(plugin_name, threshold_config, target_dict)
|
||||||
continue
|
continue
|
||||||
|
if metric_name == "pools":
|
||||||
|
self._parse_pool_thresholds(plugin_name, threshold_config, target_dict)
|
||||||
|
continue
|
||||||
|
|
||||||
metric_path = f"{plugin_name}.{metric_name}"
|
metric_path = f"{plugin_name}.{metric_name}"
|
||||||
|
|
||||||
@@ -664,6 +667,56 @@ class ThresholdChecker:
|
|||||||
|
|
||||||
target_dict[metric_path] = threshold
|
target_dict[metric_path] = threshold
|
||||||
|
|
||||||
|
def _parse_pool_thresholds(
|
||||||
|
self,
|
||||||
|
plugin_name: str,
|
||||||
|
pools: Dict[str, Any],
|
||||||
|
target_dict: Optional[Dict[str, ThresholdConfig]] = None,
|
||||||
|
):
|
||||||
|
"""Parse ZFS pool thresholds. Pool names may be literal or '*' (all pools).
|
||||||
|
|
||||||
|
Config shape::
|
||||||
|
|
||||||
|
zfs_monitor:
|
||||||
|
pools:
|
||||||
|
'*':
|
||||||
|
status:
|
||||||
|
warning: 1
|
||||||
|
critical: 2
|
||||||
|
operator: '>'
|
||||||
|
tank:
|
||||||
|
capacity:
|
||||||
|
warning: 80
|
||||||
|
critical: 90
|
||||||
|
"""
|
||||||
|
if target_dict is None:
|
||||||
|
target_dict = self.thresholds
|
||||||
|
|
||||||
|
for pool_name, metrics in pools.items():
|
||||||
|
if not isinstance(metrics, dict):
|
||||||
|
continue
|
||||||
|
for metric_name, threshold_config in metrics.items():
|
||||||
|
if not isinstance(threshold_config, dict):
|
||||||
|
continue
|
||||||
|
metric_path = f"{plugin_name}.{pool_name}.{metric_name}"
|
||||||
|
warning = threshold_config.get("warning")
|
||||||
|
critical = threshold_config.get("critical")
|
||||||
|
operator = threshold_config.get("operator", ">")
|
||||||
|
hysteresis = threshold_config.get("hysteresis", 0.02)
|
||||||
|
enabled = threshold_config.get("enabled", True)
|
||||||
|
display = threshold_config.get("display")
|
||||||
|
if warning is None and critical is None:
|
||||||
|
continue
|
||||||
|
target_dict[metric_path] = ThresholdConfig(
|
||||||
|
metric_path=metric_path,
|
||||||
|
warning=warning,
|
||||||
|
critical=critical,
|
||||||
|
operator=operator,
|
||||||
|
hysteresis=hysteresis,
|
||||||
|
enabled=enabled,
|
||||||
|
display=display,
|
||||||
|
)
|
||||||
|
|
||||||
def _parse_rtt_thresholds(
|
def _parse_rtt_thresholds(
|
||||||
self,
|
self,
|
||||||
rtt_thresholds: Dict[str, Any],
|
rtt_thresholds: Dict[str, Any],
|
||||||
@@ -967,6 +1020,44 @@ class ThresholdChecker:
|
|||||||
# Get host-specific thresholds
|
# Get host-specific thresholds
|
||||||
thresholds = self.get_thresholds_for_host(host_name)
|
thresholds = self.get_thresholds_for_host(host_name)
|
||||||
|
|
||||||
|
# ZFS pool health checks
|
||||||
|
if plugin_name == "zfs_monitor" and "pools" in data:
|
||||||
|
pools = data["pools"]
|
||||||
|
if isinstance(pools, dict):
|
||||||
|
for pool_name, pool_metrics in pools.items():
|
||||||
|
if not isinstance(pool_metrics, dict):
|
||||||
|
continue
|
||||||
|
# Synthesize status from health string for older clients
|
||||||
|
# that predate the status field.
|
||||||
|
pool_metrics_effective = dict(pool_metrics)
|
||||||
|
if "health" in pool_metrics and "status" not in pool_metrics:
|
||||||
|
pool_metrics_effective["status"] = 0 if pool_metrics["health"] == "ONLINE" else 1
|
||||||
|
for metric_name, value in pool_metrics_effective.items():
|
||||||
|
# Try specific pool name first, then wildcard '*'
|
||||||
|
metric_path = f"{plugin_name}.{pool_name}.{metric_name}"
|
||||||
|
wildcard_path = f"{plugin_name}.*.{metric_name}"
|
||||||
|
threshold = thresholds.get(metric_path) or thresholds.get(wildcard_path)
|
||||||
|
if threshold is None:
|
||||||
|
continue
|
||||||
|
if metric_path not in alert_states:
|
||||||
|
alert_states[metric_path] = AlertState(metric_path)
|
||||||
|
alert_state = alert_states[metric_path]
|
||||||
|
new_level = threshold.evaluate_with_hysteresis(value, alert_state.level)
|
||||||
|
threshold_value = None
|
||||||
|
if new_level == AlertLevel.CRITICAL and threshold.critical is not None:
|
||||||
|
threshold_value = threshold.critical
|
||||||
|
elif new_level == AlertLevel.WARNING and threshold.warning is not None:
|
||||||
|
threshold_value = threshold.warning
|
||||||
|
alert_state.hysteresis = threshold.hysteresis if new_level != AlertLevel.OK else None
|
||||||
|
pool_context = dict(pool_metrics_effective)
|
||||||
|
pool_context["pool_name"] = pool_name
|
||||||
|
old_level = alert_state.level
|
||||||
|
if alert_state.update(new_level, value, threshold_value, threshold.operator.value):
|
||||||
|
state_changes.append((metric_path, old_level, new_level, value))
|
||||||
|
self._apply_grace(host_name, alert_state, metric_path, old_level, new_level, value, threshold, pool_context, metric_name=pool_name)
|
||||||
|
elif new_level != AlertLevel.OK:
|
||||||
|
self._check_pending_or_renotify(host_name, alert_state, metric_path, value, threshold, pool_context, metric_name=pool_name)
|
||||||
|
|
||||||
# Look for partition data in disk_monitor
|
# Look for partition data in disk_monitor
|
||||||
if plugin_name == "disk_monitor" and "partitions" in data:
|
if plugin_name == "disk_monitor" and "partitions" in data:
|
||||||
partitions = data["partitions"]
|
partitions = data["partitions"]
|
||||||
@@ -1044,8 +1135,8 @@ class ThresholdChecker:
|
|||||||
# Format operator symbol
|
# Format operator symbol
|
||||||
op_symbol = threshold.operator.value
|
op_symbol = threshold.operator.value
|
||||||
|
|
||||||
# Short metric label: strip the plugin-name prefix for readability
|
# Short metric label: strip the plugin-name prefix and _status_code suffix
|
||||||
short_path = metric_path.partition(".")[2] or metric_path
|
short_path = (metric_path.partition(".")[2] or metric_path).removesuffix("_status_code")
|
||||||
|
|
||||||
# Use a display-friendly value (inf is the sentinel for "overdue")
|
# Use a display-friendly value (inf is the sentinel for "overdue")
|
||||||
import math
|
import math
|
||||||
@@ -1109,11 +1200,16 @@ class ThresholdChecker:
|
|||||||
if host is not None and not host.watched:
|
if host is not None and not host.watched:
|
||||||
eventlog(host_name, lvl, message, service="threshold")
|
eventlog(host_name, lvl, message, service="threshold")
|
||||||
return
|
return
|
||||||
|
short_path = (metric_path.partition(".")[2] or metric_path).removesuffix("_status_code")
|
||||||
|
title = f"[{lvl}] {host_name} {short_path}"
|
||||||
|
# Strip the "metric = " prefix from message so body is just the value/detail
|
||||||
|
prefix = short_path + " = "
|
||||||
|
body = message[len(prefix):] if message.startswith(prefix) else message
|
||||||
asyncio.get_event_loop().create_task(notify_mod.send_notification(
|
asyncio.get_event_loop().create_task(notify_mod.send_notification(
|
||||||
host_name,
|
host_name,
|
||||||
notify_mod.Notification(
|
notify_mod.Notification(
|
||||||
title=f"[{lvl}] {host_name}",
|
title=title,
|
||||||
body=message,
|
body=body,
|
||||||
level=lvl,
|
level=lvl,
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
@@ -1297,6 +1393,17 @@ class ThresholdChecker:
|
|||||||
else:
|
else:
|
||||||
self._check_renotify(host_name, alert_state, metric_path, value, threshold, plugin_data, check_name=check_name, metric_name=metric_name)
|
self._check_renotify(host_name, alert_state, metric_path, value, threshold, plugin_data, check_name=check_name, metric_name=metric_name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _human_duration(seconds: float) -> str:
|
||||||
|
s = int(seconds)
|
||||||
|
if s < 120:
|
||||||
|
return f"{s}s"
|
||||||
|
if s < 3600:
|
||||||
|
return f"{s // 60}m {s % 60}s"
|
||||||
|
h, rem = divmod(s, 3600)
|
||||||
|
m = rem // 60
|
||||||
|
return f"{h}h {m}m" if m else f"{h}h"
|
||||||
|
|
||||||
def _check_renotify(
|
def _check_renotify(
|
||||||
self,
|
self,
|
||||||
host_name: str,
|
host_name: str,
|
||||||
@@ -1344,7 +1451,7 @@ class ThresholdChecker:
|
|||||||
|
|
||||||
# Format operator symbol
|
# Format operator symbol
|
||||||
op_symbol = threshold.operator.value
|
op_symbol = threshold.operator.value
|
||||||
short_path = metric_path.partition(".")[2] or metric_path
|
short_path = (metric_path.partition(".")[2] or metric_path).removesuffix("_status_code")
|
||||||
|
|
||||||
# Time to re-notify
|
# Time to re-notify
|
||||||
if threshold_value is not None:
|
if threshold_value is not None:
|
||||||
@@ -1358,9 +1465,10 @@ class ThresholdChecker:
|
|||||||
check_name=check_name,
|
check_name=check_name,
|
||||||
metric_name=metric_name,
|
metric_name=metric_name,
|
||||||
)
|
)
|
||||||
message = f"REMINDER ({alert_state.level.name}): {host_name} - {short_path} = {value} {threshold_info}, ongoing for {int(now - alert_state.since)}s"
|
body = f"{value} {threshold_info}, ongoing for {self._human_duration(now - alert_state.since)}"
|
||||||
else:
|
else:
|
||||||
message = f"REMINDER ({alert_state.level.name}): {host_name} - {short_path} = {value} (ongoing for {int(now - alert_state.since)}s)"
|
body = f"{value} (ongoing for {self._human_duration(now - alert_state.since)})"
|
||||||
|
message = f"REMINDER ({alert_state.level.name}): {host_name} - {short_path} = {body}"
|
||||||
|
|
||||||
from . import hbdclass
|
from . import hbdclass
|
||||||
host = hbdclass.Host.hosts.get(host_name)
|
host = hbdclass.Host.hosts.get(host_name)
|
||||||
@@ -1368,8 +1476,8 @@ class ThresholdChecker:
|
|||||||
asyncio.get_event_loop().create_task(notify_mod.send_notification(
|
asyncio.get_event_loop().create_task(notify_mod.send_notification(
|
||||||
host_name,
|
host_name,
|
||||||
notify_mod.Notification(
|
notify_mod.Notification(
|
||||||
title=f"[REMINDER/{alert_state.level.name}] {host_name}",
|
title=f"[REMINDER/{alert_state.level.name}] {host_name} {short_path}",
|
||||||
body=message,
|
body=body,
|
||||||
level=alert_state.level.name,
|
level=alert_state.level.name,
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
|
|||||||
+11
-1
@@ -350,8 +350,10 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
|
|||||||
|
|
||||||
if msg.get("ID") == "HTB":
|
if msg.get("ID") == "HTB":
|
||||||
host.doesack = msg.get("acks", -1)
|
host.doesack = msg.get("acks", -1)
|
||||||
# send ACK back
|
# send ACK back; ask client to resend plugin info when we have none yet
|
||||||
rmsg = {"time": time.time()}
|
rmsg = {"time": time.time()}
|
||||||
|
if not host.plugin_data:
|
||||||
|
rmsg["request_update"] = 1
|
||||||
opkt = dicttos("ACK", rmsg)
|
opkt = dicttos("ACK", rmsg)
|
||||||
try:
|
try:
|
||||||
transport.sendto(opkt, addr)
|
transport.sendto(opkt, addr)
|
||||||
@@ -368,6 +370,14 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
|
|||||||
if k not in ("ID", "plugin", "id", "name")}
|
if k not in ("ID", "plugin", "id", "name")}
|
||||||
# Store plugin data with timestamp
|
# Store plugin data with timestamp
|
||||||
host.add_plugin_data(plugin_name, plugin_data, timestamp=now)
|
host.add_plugin_data(plugin_name, plugin_data, timestamp=now)
|
||||||
|
|
||||||
|
# If os_info reports an owner and none is configured server-side, apply it
|
||||||
|
if plugin_name == "os_info":
|
||||||
|
config_owner = config_mod.get_host_access(cfg, uname).get("owner")
|
||||||
|
default_owner = config_mod.get_default_owner(cfg)
|
||||||
|
inferred_owner = plugin_data.get("owner", config_owner or default_owner)
|
||||||
|
host.owner = inferred_owner
|
||||||
|
logger.info(f"owner for {uname} is '{host.owner}")
|
||||||
if DEBUG > 1:
|
if DEBUG > 1:
|
||||||
print(f"Stored plugin data for {uname}: {plugin_name}")
|
print(f"Stored plugin data for {uname}: {plugin_name}")
|
||||||
|
|
||||||
|
|||||||
@@ -146,9 +146,14 @@ def load_users(config: dict) -> dict:
|
|||||||
Returns the new ``users`` dict.
|
Returns the new ``users`` dict.
|
||||||
"""
|
"""
|
||||||
global users
|
global users
|
||||||
|
old_users = dict(users) # snapshot before rebuild
|
||||||
users_cfg = config.get("users", {})
|
users_cfg = config.get("users", {})
|
||||||
if not isinstance(users_cfg, dict):
|
if not isinstance(users_cfg, dict):
|
||||||
users = {}
|
users = {}
|
||||||
|
# Preserve OAuth-provisioned users (password_hash == "") that aren't in config.
|
||||||
|
for username, existing_user in old_users.items():
|
||||||
|
if not existing_user.password_hash and username not in users:
|
||||||
|
users[username] = existing_user
|
||||||
return users
|
return users
|
||||||
|
|
||||||
result: dict = {}
|
result: dict = {}
|
||||||
@@ -166,6 +171,10 @@ def load_users(config: dict) -> dict:
|
|||||||
)
|
)
|
||||||
|
|
||||||
users = result
|
users = result
|
||||||
|
# Preserve OAuth-provisioned users (password_hash == "") that aren't in config.
|
||||||
|
for username, existing_user in old_users.items():
|
||||||
|
if not existing_user.password_hash and username not in users:
|
||||||
|
users[username] = existing_user
|
||||||
logger.info("Loaded %d user(s) from config", len(users))
|
logger.info("Loaded %d user(s) from config", len(users))
|
||||||
return users
|
return users
|
||||||
|
|
||||||
@@ -187,6 +196,26 @@ def authenticate(username: str, password: str) -> "User | None":
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def provision_oauth_user(username: str, full_name: str, avatar: str) -> "User":
|
||||||
|
"""Create or update a user sourced from an OAuth2 provider.
|
||||||
|
|
||||||
|
New users are inserted with no password_hash — they can only authenticate
|
||||||
|
via OAuth. Existing users (e.g. defined in config with a password) have
|
||||||
|
their display name and avatar refreshed; all other attributes are preserved.
|
||||||
|
"""
|
||||||
|
user = users.get(username)
|
||||||
|
if user is None:
|
||||||
|
user = User(username=username, full_name=full_name, avatar=avatar)
|
||||||
|
users[username] = user
|
||||||
|
logger.info("Provisioned OAuth user %r", username)
|
||||||
|
else:
|
||||||
|
if full_name:
|
||||||
|
user.full_name = full_name
|
||||||
|
if avatar:
|
||||||
|
user.avatar = avatar
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Session management
|
# Session management
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
+5
-1
@@ -85,10 +85,12 @@ async def handler(request):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending initial hosts: %s", e)
|
logger.error("Error sending initial hosts: %s", e)
|
||||||
|
|
||||||
# Send recent messages
|
# Send recent messages, filtered to hosts this user may see
|
||||||
if data.msgs:
|
if data.msgs:
|
||||||
try:
|
try:
|
||||||
for m in data.msgs:
|
for m in data.msgs:
|
||||||
|
host_name = m.get("host") if isinstance(m, dict) else None
|
||||||
|
if not host_name or _user_can_see_host(user, host_name):
|
||||||
await ws.send_str(json.dumps({"type": "message", "data": m}))
|
await ws.send_str(json.dumps({"type": "message", "data": m}))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending initial messages: %s", e)
|
logger.error("Error sending initial messages: %s", e)
|
||||||
@@ -128,6 +130,8 @@ def broadcast(typ: str, payload) -> bool:
|
|||||||
host_name: Optional[str] = None
|
host_name: Optional[str] = None
|
||||||
if typ in ("host", "plugin"):
|
if typ in ("host", "plugin"):
|
||||||
host_name = payload.get("raw_name") or payload.get("host") or payload.get("name")
|
host_name = payload.get("raw_name") or payload.get("host") or payload.get("name")
|
||||||
|
elif typ == "message" and isinstance(payload, dict):
|
||||||
|
host_name = payload.get("host")
|
||||||
|
|
||||||
jmsg = json.dumps({"type": typ, "data": payload})
|
jmsg = json.dumps({"type": typ, "data": payload})
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "hbd"
|
name = "hbd"
|
||||||
version = "5.2.1"
|
version = "5.2.6"
|
||||||
description = "Heartbeat monitoring system — client (hbc) and server (hbd)"
|
description = "Heartbeat monitoring system — client (hbc) and server (hbd)"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
|
|||||||
@@ -0,0 +1,2 @@
|
|||||||
|
hbc_mini
|
||||||
|
hbc_mini_dbg
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
CC ?= cc
|
||||||
|
CFLAGS = -O2 -Wall -Wextra -std=c11
|
||||||
|
LDFLAGS = -lz -lpthread -lm
|
||||||
|
TARGET = hbc_mini
|
||||||
|
SRC = hbc_mini.c
|
||||||
|
|
||||||
|
# FreeBSD/NetBSD keep zlib in base; no extra flags needed.
|
||||||
|
# On some NetBSD installs pthreads may need -lpthread from pkgsrc.
|
||||||
|
|
||||||
|
.PHONY: all clean debug
|
||||||
|
|
||||||
|
all: $(TARGET)
|
||||||
|
|
||||||
|
$(TARGET): $(SRC)
|
||||||
|
$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)
|
||||||
|
|
||||||
|
debug: $(SRC)
|
||||||
|
$(CC) -g -fsanitize=address,undefined -o $(TARGET)_dbg $< $(LDFLAGS)
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -f $(TARGET) $(TARGET)_dbg
|
||||||
File diff suppressed because it is too large
Load Diff
+36
-15
@@ -41,7 +41,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
# updated by scripts/bumpminor.sh
|
# updated by scripts/bumpminor.sh
|
||||||
__version__ = "5.2.1"
|
__version__ = "5.2.6"
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Protocol (mirrors hbd/common/proto.py)
|
# Protocol (mirrors hbd/common/proto.py)
|
||||||
@@ -114,6 +114,7 @@ def _stodict(data: bytes) -> Dict[str, Any]:
|
|||||||
_DEFAULTS: Dict[str, Any] = {
|
_DEFAULTS: Dict[str, Any] = {
|
||||||
"hb_port": 50003,
|
"hb_port": 50003,
|
||||||
"interval": 10,
|
"interval": 10,
|
||||||
|
"owner": None,
|
||||||
"plugins": {},
|
"plugins": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,6 +240,8 @@ class OSInfoPlugin(InfoPlugin):
|
|||||||
"hbc_version": __version__,
|
"hbc_version": __version__,
|
||||||
"hbc_type": "mini",
|
"hbc_type": "mini",
|
||||||
}
|
}
|
||||||
|
if self.config.get("owner"):
|
||||||
|
data["owner"] = self.config["owner"]
|
||||||
if platform.system() == "Linux":
|
if platform.system() == "Linux":
|
||||||
data.update(_linux_distro())
|
data.update(_linux_distro())
|
||||||
elif platform.system() == "Darwin":
|
elif platform.system() == "Darwin":
|
||||||
@@ -716,7 +719,9 @@ async def _load_plugins(cfg: Dict[str, Any]) -> List[Plugin]:
|
|||||||
plugins_cfg: Dict[str, Any] = cfg.get("plugins", {})
|
plugins_cfg: Dict[str, Any] = cfg.get("plugins", {})
|
||||||
loaded: List[Plugin] = []
|
loaded: List[Plugin] = []
|
||||||
for cls in _ALL_PLUGIN_CLASSES:
|
for cls in _ALL_PLUGIN_CLASSES:
|
||||||
plugin_cfg = plugins_cfg.get(cls.name) or cfg.get(cls.name, {})
|
plugin_cfg = dict(plugins_cfg.get(cls.name) or cfg.get(cls.name) or {})
|
||||||
|
if "owner" in cfg and "owner" not in plugin_cfg:
|
||||||
|
plugin_cfg["owner"] = cfg["owner"]
|
||||||
plugin: Plugin = cls(config=plugin_cfg)
|
plugin: Plugin = cls(config=plugin_cfg)
|
||||||
try:
|
try:
|
||||||
ok = await plugin.initialize()
|
ok = await plugin.initialize()
|
||||||
@@ -786,7 +791,7 @@ class _HeartbeatProtocol(asyncio.DatagramProtocol):
|
|||||||
msg_id = msg.get("ID")
|
msg_id = msg.get("ID")
|
||||||
now = time.time()
|
now = time.time()
|
||||||
if msg_id == "ACK":
|
if msg_id == "ACK":
|
||||||
self._conn._handle_ack(now)
|
self._conn._handle_ack(msg, now)
|
||||||
elif msg_id == "CMD":
|
elif msg_id == "CMD":
|
||||||
asyncio.create_task(_handle_command(self._conn, msg))
|
asyncio.create_task(_handle_command(self._conn, msg))
|
||||||
elif msg_id == "UPD":
|
elif msg_id == "UPD":
|
||||||
@@ -797,8 +802,7 @@ class _HeartbeatProtocol(asyncio.DatagramProtocol):
|
|||||||
self._log.error("datagram error: %s", e)
|
self._log.error("datagram error: %s", e)
|
||||||
|
|
||||||
def error_received(self, exc):
|
def error_received(self, exc):
|
||||||
self._log.warning("protocol error on %s: %s — dropping connection", self._conn.addr, exc)
|
self._log.warning("protocol error on %s: %s — will retry", self._conn.addr, exc)
|
||||||
self._conn._dead = True
|
|
||||||
self._conn.close()
|
self._conn.close()
|
||||||
|
|
||||||
|
|
||||||
@@ -814,6 +818,7 @@ class AsyncConnection:
|
|||||||
self.rtts: List[float] = [0.0]
|
self.rtts: List[float] = [0.0]
|
||||||
self._transport: Optional[asyncio.DatagramTransport] = None
|
self._transport: Optional[asyncio.DatagramTransport] = None
|
||||||
self._dead = False
|
self._dead = False
|
||||||
|
self._request_info: asyncio.Event = asyncio.Event()
|
||||||
self._log = logging.getLogger(f"hbc.conn.{addr}")
|
self._log = logging.getLogger(f"hbc.conn.{addr}")
|
||||||
|
|
||||||
async def open(self) -> bool:
|
async def open(self) -> bool:
|
||||||
@@ -832,12 +837,14 @@ class AsyncConnection:
|
|||||||
self._transport.close()
|
self._transport.close()
|
||||||
self._transport = None
|
self._transport = None
|
||||||
|
|
||||||
def _handle_ack(self, now: float):
|
def _handle_ack(self, msg: Dict[str, Any], now: float):
|
||||||
rtt = (now - self.lastsend) * 1000.0
|
rtt = (now - self.lastsend) * 1000.0
|
||||||
self.rtts.append(rtt)
|
self.rtts.append(rtt)
|
||||||
if len(self.rtts) > 10:
|
if len(self.rtts) > 10:
|
||||||
self.rtts.pop(0)
|
self.rtts.pop(0)
|
||||||
self.ackcount += 1
|
self.ackcount += 1
|
||||||
|
if msg.get("request_update"):
|
||||||
|
self._request_info.set()
|
||||||
|
|
||||||
async def sendto(self, msg: Dict[str, Any], msg_id: str = "HTB"):
|
async def sendto(self, msg: Dict[str, Any], msg_id: str = "HTB"):
|
||||||
if self._dead:
|
if self._dead:
|
||||||
@@ -970,6 +977,19 @@ async def _run_monitor_group(conn: AsyncConnection, plugins: List[Plugin], inter
|
|||||||
await _sleep(interval)
|
await _sleep(interval)
|
||||||
|
|
||||||
|
|
||||||
|
async def _info_refresh_loop(conn: AsyncConnection, info: List[Plugin]):
|
||||||
|
log = logging.getLogger("hbc.plugins")
|
||||||
|
while _running:
|
||||||
|
await conn._request_info.wait()
|
||||||
|
if not _running:
|
||||||
|
break
|
||||||
|
conn._request_info.clear()
|
||||||
|
log.info("refreshing InfoPlugins on server request")
|
||||||
|
for plugin in info:
|
||||||
|
plugin._cache = None
|
||||||
|
await _run_info_plugins(conn, info)
|
||||||
|
|
||||||
|
|
||||||
async def _plugin_collector(conn: AsyncConnection, plugins: List[Plugin]):
|
async def _plugin_collector(conn: AsyncConnection, plugins: List[Plugin]):
|
||||||
info = [p for p in plugins if isinstance(p, InfoPlugin)]
|
info = [p for p in plugins if isinstance(p, InfoPlugin)]
|
||||||
monitor = [p for p in plugins if isinstance(p, MonitorPlugin)]
|
monitor = [p for p in plugins if isinstance(p, MonitorPlugin)]
|
||||||
@@ -980,12 +1000,10 @@ async def _plugin_collector(conn: AsyncConnection, plugins: List[Plugin]):
|
|||||||
for p in monitor:
|
for p in monitor:
|
||||||
by_interval[p.interval].append(p)
|
by_interval[p.interval].append(p)
|
||||||
|
|
||||||
if by_interval:
|
tasks = [asyncio.create_task(_info_refresh_loop(conn, info))]
|
||||||
await asyncio.gather(
|
tasks += [asyncio.create_task(_run_monitor_group(conn, grp, iv))
|
||||||
*[asyncio.create_task(_run_monitor_group(conn, grp, iv))
|
for iv, grp in by_interval.items()]
|
||||||
for iv, grp in by_interval.items()],
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
return_exceptions=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -1029,7 +1047,7 @@ def _reconfigure_syslog(level: int):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
async def _async_main(args, cfg: Dict[str, Any]) -> int:
|
async def _async_main(args, cfg: Dict[str, Any]) -> int:
|
||||||
global _running, _shutdown_event, _active_tasks
|
global _running, _shutdown_event, _active_tasks, send_shutdown
|
||||||
_running = True
|
_running = True
|
||||||
_shutdown_event = asyncio.Event()
|
_shutdown_event = asyncio.Event()
|
||||||
_active_tasks = []
|
_active_tasks = []
|
||||||
@@ -1039,7 +1057,7 @@ async def _async_main(args, cfg: Dict[str, Any]) -> int:
|
|||||||
port = cfg.get("hb_port", PORT)
|
port = cfg.get("hb_port", PORT)
|
||||||
interval = cfg.get("interval", INTERVAL)
|
interval = cfg.get("interval", INTERVAL)
|
||||||
|
|
||||||
log.info("starting: %s -> %s port=%d interval=%ds", iam, args.hosts, port, interval)
|
log.info("hbc_mini %s on %s -> %s port=%d interval=%ds",__version__, iam, args.hosts, port, interval)
|
||||||
|
|
||||||
connections: List[AsyncConnection] = []
|
connections: List[AsyncConnection] = []
|
||||||
conn_id = 1
|
conn_id = 1
|
||||||
@@ -1060,10 +1078,13 @@ async def _async_main(args, cfg: Dict[str, Any]) -> int:
|
|||||||
return 1
|
return 1
|
||||||
|
|
||||||
# Boot / one-shot message
|
# Boot / one-shot message
|
||||||
|
send_shutdown = False
|
||||||
if args.boot or args.message:
|
if args.boot or args.message:
|
||||||
bmsg: Dict[str, Any] = {"acks": 0}
|
bmsg: Dict[str, Any] = {"acks": 0}
|
||||||
if args.boot:
|
if args.boot:
|
||||||
bmsg["boot"] = 1
|
bmsg["boot"] = 1
|
||||||
|
args.boot = False # don't repeat on restart
|
||||||
|
send_shutdown = True
|
||||||
if args.message:
|
if args.message:
|
||||||
bmsg["service"] = "service"
|
bmsg["service"] = "service"
|
||||||
bmsg["msg"] = args.message
|
bmsg["msg"] = args.message
|
||||||
@@ -1101,7 +1122,7 @@ async def _async_main(args, cfg: Dict[str, Any]) -> int:
|
|||||||
|
|
||||||
log.info("shutting down")
|
log.info("shutting down")
|
||||||
target = next((c for c in connections if c._transport), connections[0] if connections else None)
|
target = next((c for c in connections if c._transport), connections[0] if connections else None)
|
||||||
if target:
|
if target and send_shutdown:
|
||||||
try:
|
try:
|
||||||
await target.sendto({"shutdown": 1, "acks": target.ackcount})
|
await target.sendto({"shutdown": 1, "acks": target.ackcount})
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -0,0 +1,324 @@
|
|||||||
|
import time as time_mod
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from hbd.server import oauth
|
||||||
|
from hbd.server import users as users_mod
|
||||||
|
from hbd.server.users import User
|
||||||
|
|
||||||
|
|
||||||
|
CFG_OFF = {}
|
||||||
|
CFG_ON = {
|
||||||
|
"oauth": {
|
||||||
|
"gitea": {
|
||||||
|
"url": "https://git.example.com",
|
||||||
|
"client_id": "cid",
|
||||||
|
"client_secret": "csec",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CFG_PARTIAL = {"oauth": {"gitea": {"url": "https://git.example.com"}}}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_oauth_states():
|
||||||
|
oauth._states.clear()
|
||||||
|
yield
|
||||||
|
oauth._states.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_users_dict():
|
||||||
|
original = dict(users_mod.users)
|
||||||
|
yield
|
||||||
|
users_mod.users = original
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_enabled_when_all_keys_present():
|
||||||
|
assert oauth.is_enabled(CFG_ON) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_enabled_false_when_no_oauth_key():
|
||||||
|
assert oauth.is_enabled(CFG_OFF) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_enabled_false_when_partial_config():
|
||||||
|
assert oauth.is_enabled(CFG_PARTIAL) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_state_returns_unique_tokens():
|
||||||
|
s1 = oauth.make_state()
|
||||||
|
s2 = oauth.make_state()
|
||||||
|
assert s1 != s2
|
||||||
|
assert len(s1) == 64 # 32 bytes hex
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_valid():
|
||||||
|
state = oauth.make_state()
|
||||||
|
assert oauth.validate_state(state) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_consumed_on_use():
|
||||||
|
state = oauth.make_state()
|
||||||
|
oauth.validate_state(state)
|
||||||
|
assert oauth.validate_state(state) is False # replay rejected
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_unknown():
|
||||||
|
assert oauth.validate_state("notastate") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_state_expired(monkeypatch):
|
||||||
|
state = oauth.make_state()
|
||||||
|
# Wind expiry into the past
|
||||||
|
monkeypatch.setitem(oauth._states, state, time_mod.time() - 1000)
|
||||||
|
assert oauth.validate_state(state) is False
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_users(entries=None):
|
||||||
|
users_mod.users = entries or {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_oauth_user_new():
|
||||||
|
_reset_users()
|
||||||
|
user = users_mod.provision_oauth_user("gituser", "Git User", "https://example.com/avatar.png")
|
||||||
|
assert user.username == "gituser"
|
||||||
|
assert user.full_name == "Git User"
|
||||||
|
assert user.avatar == "https://example.com/avatar.png"
|
||||||
|
assert user.admin is False
|
||||||
|
assert user.password_hash == ""
|
||||||
|
assert "gituser" in users_mod.users
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_oauth_user_no_password_login():
|
||||||
|
_reset_users()
|
||||||
|
user = users_mod.provision_oauth_user("gituser", "Git User", "")
|
||||||
|
assert user.check_password("anything") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_oauth_user_existing_updates_profile():
|
||||||
|
existing = User(
|
||||||
|
username="alice",
|
||||||
|
full_name="Old Name",
|
||||||
|
avatar="old.png",
|
||||||
|
password_hash="pbkdf2:sha256:1:salt:abc",
|
||||||
|
admin=True,
|
||||||
|
notification_channels=["chan1"],
|
||||||
|
)
|
||||||
|
_reset_users({"alice": existing})
|
||||||
|
user = users_mod.provision_oauth_user("alice", "New Name", "new.png")
|
||||||
|
assert user.full_name == "New Name"
|
||||||
|
assert user.avatar == "new.png"
|
||||||
|
# Preserved
|
||||||
|
assert user.admin is True
|
||||||
|
assert user.password_hash == "pbkdf2:sha256:1:salt:abc"
|
||||||
|
assert user.notification_channels == ["chan1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_oauth_user_does_not_overwrite_with_empty():
|
||||||
|
existing = User(username="bob", full_name="Bob", avatar="bob.png")
|
||||||
|
_reset_users({"bob": existing})
|
||||||
|
user = users_mod.provision_oauth_user("bob", "", "")
|
||||||
|
assert user.full_name == "Bob"
|
||||||
|
assert user.avatar == "bob.png"
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_oauth_user_survives_config_reload():
|
||||||
|
_reset_users()
|
||||||
|
users_mod.provision_oauth_user("oauthonly", "OAuth Only", "https://example.com/a.png")
|
||||||
|
assert "oauthonly" in users_mod.users
|
||||||
|
# Reload with empty config — OAuth user should survive
|
||||||
|
users_mod.load_users({})
|
||||||
|
assert "oauthonly" in users_mod.users
|
||||||
|
|
||||||
|
|
||||||
|
def test_authorization_url_shape():
|
||||||
|
state = "teststate"
|
||||||
|
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
||||||
|
url = oauth.authorization_url(CFG_ON, state, redirect_uri)
|
||||||
|
parsed = urlparse(url)
|
||||||
|
qs = parse_qs(parsed.query)
|
||||||
|
assert parsed.scheme == "https"
|
||||||
|
assert parsed.netloc == "git.example.com"
|
||||||
|
assert parsed.path == "/login/oauth/authorize"
|
||||||
|
assert qs["client_id"] == ["cid"]
|
||||||
|
assert qs["state"] == ["teststate"]
|
||||||
|
assert qs["redirect_uri"] == [redirect_uri]
|
||||||
|
assert qs["scope"] == ["user:email"]
|
||||||
|
assert qs["response_type"] == ["code"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exchange_code_returns_token():
|
||||||
|
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={"access_token": "tok123"})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.post = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
token = await oauth.exchange_code(CFG_ON, "mycode", redirect_uri)
|
||||||
|
assert token == "tok123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exchange_code_raises_on_error_status():
|
||||||
|
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 401
|
||||||
|
mock_response.text = AsyncMock(return_value="unauthorized")
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.post = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
with pytest.raises(oauth.OAuthError):
|
||||||
|
await oauth.exchange_code(CFG_ON, "badcode", redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_user_returns_profile():
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={
|
||||||
|
"login": "alice",
|
||||||
|
"full_name": "Alice Smith",
|
||||||
|
"avatar_url": "https://git.example.com/avatars/alice.png",
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
profile = await oauth.fetch_user(CFG_ON, "tok123")
|
||||||
|
assert profile == {
|
||||||
|
"login": "alice",
|
||||||
|
"full_name": "Alice Smith",
|
||||||
|
"avatar_url": "https://git.example.com/avatars/alice.png",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exchange_code_raises_when_no_access_token():
|
||||||
|
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={"error": "bad_request"})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.post = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
with pytest.raises(oauth.OAuthError):
|
||||||
|
await oauth.exchange_code(CFG_ON, "mycode", redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_user_raises_on_error_status():
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 401
|
||||||
|
mock_response.text = AsyncMock(return_value="unauthorized")
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
with pytest.raises(oauth.OAuthError):
|
||||||
|
await oauth.fetch_user(CFG_ON, "tok123")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration-style tests: callback logic chain
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_invalid_state_rejects():
|
||||||
|
"""Verify validate_state returns False for unknown state tokens."""
|
||||||
|
fake_state = "this-is-not-a-real-state"
|
||||||
|
assert oauth.validate_state(fake_state) is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_oauth_flow_chain():
|
||||||
|
"""Integration-style test: state → exchange → fetch → provision chain."""
|
||||||
|
redirect_uri = "https://hbd.example.com/login/oauth/gitea/callback"
|
||||||
|
|
||||||
|
# Step 1: create a state token
|
||||||
|
state = oauth.make_state()
|
||||||
|
assert oauth.validate_state(state) is True # consumed; replay would return False
|
||||||
|
|
||||||
|
# Step 2: exchange code → token (mocked)
|
||||||
|
mock_token_response = AsyncMock()
|
||||||
|
mock_token_response.status = 200
|
||||||
|
mock_token_response.json = AsyncMock(return_value={"access_token": "flow_token"})
|
||||||
|
|
||||||
|
mock_user_response = AsyncMock()
|
||||||
|
mock_user_response.status = 200
|
||||||
|
mock_user_response.json = AsyncMock(return_value={
|
||||||
|
"login": "flowuser",
|
||||||
|
"full_name": "Flow User",
|
||||||
|
"avatar_url": "https://git.example.com/avatars/flow.png",
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.post = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_token_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
mock_session.get = MagicMock(return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_user_response),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("hbd.server.oauth.aiohttp.ClientSession", return_value=AsyncMock(
|
||||||
|
__aenter__=AsyncMock(return_value=mock_session),
|
||||||
|
__aexit__=AsyncMock(return_value=False),
|
||||||
|
)):
|
||||||
|
token = await oauth.exchange_code(CFG_ON, "authcode", redirect_uri)
|
||||||
|
profile = await oauth.fetch_user(CFG_ON, token)
|
||||||
|
|
||||||
|
assert token == "flow_token"
|
||||||
|
assert profile["login"] == "flowuser"
|
||||||
|
|
||||||
|
# Step 3: provision user
|
||||||
|
_reset_users()
|
||||||
|
user = users_mod.provision_oauth_user(
|
||||||
|
profile["login"], profile["full_name"], profile["avatar_url"]
|
||||||
|
)
|
||||||
|
assert user.username == "flowuser"
|
||||||
|
assert user.check_password("anything") is False
|
||||||
Reference in New Issue
Block a user