Files
gandalf/db.py

366 lines
12 KiB
Python
Raw Normal View History

"""Database operations for Gandalf network monitor."""
import json
import logging
import threading
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Optional
import pymysql
import pymysql.cursors
# Per-thread storage: get_conn() keeps each thread's pymysql connection in `_local.conn`.
_local = threading.local()
# Lazily filled by _config() with the 'database' section of config.json.
_config_cache: Optional[dict] = None
# Module-level logger.
logger = logging.getLogger(__name__)
def _config() -> dict:
    """Return the 'database' section of config.json, loading it on first use.

    The parsed section is memoised in the module-level ``_config_cache``, so
    the file is normally read once per process. No lock is taken, so two
    threads making the very first call may both read the file (harmless: they
    store equivalent data). The path is relative to the process's current
    working directory.

    Raises:
        FileNotFoundError: config.json is not in the working directory.
        json.JSONDecodeError: the file is not valid JSON.
        KeyError: the JSON has no top-level 'database' key.
    """
    global _config_cache
    if _config_cache is None:
        # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open('config.json', encoding='utf-8') as f:
            _config_cache = json.load(f)['database']
    return _config_cache
@contextmanager
def get_conn():
    """Yield this thread's cached pymysql connection, (re)connecting as needed.

    On a thread's first use a new connection is opened from the 'database'
    config section (autocommit on, DictCursor rows, utf8mb4, 10 s connect
    timeout) and stored as ``_local.conn``. Subsequent uses reuse it after
    ``ping(reconnect=True)``, which re-establishes a dropped link. The
    connection is deliberately left open when the ``with`` block exits.
    """
    settings = _config()
    cached = getattr(_local, 'conn', None)
    if cached is not None:
        # Reconnects transparently if the server closed the idle connection.
        cached.ping(reconnect=True)
    else:
        cached = _local.conn = pymysql.connect(
            host=settings['host'],
            port=settings.get('port', 3306),
            user=settings['user'],
            password=settings['password'],
            database=settings['name'],
            autocommit=True,
            cursorclass=pymysql.cursors.DictCursor,
            connect_timeout=10,
            charset='utf8mb4',
        )
    yield cached
# ---------------------------------------------------------------------------
# Monitor state (key/value store)
# ---------------------------------------------------------------------------
def set_state(key: str, value) -> None:
    """Store *value* under *key* in the monitor_state key/value table.

    Strings are stored verbatim; anything else is JSON-encoded first, with
    ``default=str`` stringifying objects json cannot encode natively. An
    existing key is overwritten and its updated_at set to NOW().
    """
    payload = value if isinstance(value, str) else json.dumps(value, default=str)
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            """INSERT INTO monitor_state (key_name, value)
               VALUES (%s, %s)
               ON DUPLICATE KEY UPDATE value=VALUES(value), updated_at=NOW()""",
            (key, payload),
        )
def get_state(key: str, default=None):
    """Return the stored ``value`` column for *key*, or *default* if no row exists.

    The value is returned exactly as stored; it is not JSON-decoded here.
    """
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute('SELECT value FROM monitor_state WHERE key_name=%s', (key,))
        found = cur.fetchone()
    if found is None:
        return default
    return found['value']
# ---------------------------------------------------------------------------
# Interface baseline tracking
# ---------------------------------------------------------------------------
def get_baseline() -> dict:
    """Return the persisted interface baseline as a dict.

    The baseline is stored as JSON under the 'interface_baseline' key of
    monitor_state. A missing or empty value yields ``{}``. A value that fails
    to parse, or that parses to something other than a JSON object, is logged
    and also yields ``{}`` so the caller starts from a fresh baseline instead
    of crashing.
    """
    raw = get_state('interface_baseline')
    if not raw:
        return {}
    try:
        baseline = json.loads(raw)
    except (TypeError, ValueError):
        # json.JSONDecodeError (and UnicodeDecodeError for bytes) subclass
        # ValueError; TypeError covers a stored value that isn't str/bytes.
        logger.error('Failed to parse interface_baseline JSON; resetting baseline')
        return {}
    if not isinstance(baseline, dict):
        # e.g. 'null' or a list was stored; honour the `-> dict` contract.
        logger.error('interface_baseline is not a JSON object; resetting baseline')
        return {}
    return baseline
def set_baseline(baseline: dict) -> None:
    """Persist *baseline* as JSON under the 'interface_baseline' state key.

    Encoded here with plain ``json.dumps`` (no ``default=``), so values json
    cannot serialize raise TypeError rather than being stringified.
    """
    encoded = json.dumps(baseline)
    set_state('interface_baseline', encoded)
# ---------------------------------------------------------------------------
# Network events
# ---------------------------------------------------------------------------
def upsert_event(
    event_type: str,
    severity: str,
    source_type: str,
    target_name: str,
    target_detail: str,
    description: str,
) -> tuple:
    """Record an occurrence of a network event, reusing an open event if one exists.

    An event is identified by (event_type, target_name, target_detail), with a
    None/empty *target_detail* normalized to ''. If an unresolved event with
    that identity exists, its last_seen is set to NOW(), consecutive_failures
    is incremented and description replaced; severity and source_type are left
    as first recorded. Otherwise a new row is inserted.

    Returns:
        (event_id, is_new, consecutive_failures). For a newly inserted event
        this is (lastrowid, True, 1).
    """
    detail = target_detail or ''
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            """SELECT id, consecutive_failures FROM network_events
               WHERE event_type=%s AND target_name=%s AND target_detail=%s
               AND resolved_at IS NULL LIMIT 1""",
            (event_type, target_name, detail),
        )
        open_event = cur.fetchone()

        if open_event is None:
            cur.execute(
                """INSERT INTO network_events
                   (event_type, severity, source_type, target_name, target_detail, description)
                   VALUES (%s, %s, %s, %s, %s, %s)""",
                (event_type, severity, source_type, target_name, detail, description),
            )
            return cur.lastrowid, True, 1

        # NOTE: the new count is derived client-side from the SELECT above; with
        # autocommit and no row lock, concurrent callers for the same target
        # could read the same starting value.
        failures = open_event['consecutive_failures'] + 1
        cur.execute(
            """UPDATE network_events
               SET last_seen=NOW(), consecutive_failures=%s, description=%s
               WHERE id=%s""",
            (failures, description, open_event['id']),
        )
        return open_event['id'], False, failures
def resolve_event(event_type: str, target_name: str, target_detail: str = '') -> None:
    """Set resolved_at=NOW() on every open event matching the given identity.

    A None/empty *target_detail* is normalized to '' (the same normalization
    upsert_event applies when storing events).
    """
    match = (event_type, target_name, target_detail or '')
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            """UPDATE network_events SET resolved_at=NOW()
               WHERE event_type=%s AND target_name=%s AND target_detail=%s
               AND resolved_at IS NULL""",
            match,
        )
def set_ticket_id(event_id: int, ticket_id: str) -> None:
    """Attach *ticket_id* to the network event with primary key *event_id*."""
    sql = 'UPDATE network_events SET ticket_id=%s WHERE id=%s'
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, (ticket_id, event_id))
def get_active_events(limit: int = 200, offset: int = 0) -> list:
    """Return one page of unresolved events as dict rows.

    Ordered by severity ('critical', 'warning', 'info' via MySQL FIELD()), then
    newest first_seen first. first_seen/last_seen values that have an
    ``isoformat`` method are replaced in place by their ISO-8601 string.

    NOTE: FIELD() yields 0 for a severity not in the list, so such rows sort
    ahead of 'critical'.
    """
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            """SELECT * FROM network_events
               WHERE resolved_at IS NULL
               ORDER BY
                   FIELD(severity,'critical','warning','info'),
                   first_seen DESC
               LIMIT %s OFFSET %s""",
            (limit, offset),
        )
        events = cur.fetchall()
    for event in events:
        for column in ('first_seen', 'last_seen'):
            stamp = event.get(column)
            if stamp and hasattr(stamp, 'isoformat'):
                event[column] = stamp.isoformat()
    return events
def count_active_events() -> int:
    """Return how many events are unresolved (total for get_active_events paging)."""
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT COUNT(*) AS n FROM network_events WHERE resolved_at IS NULL"
        )
        result = cur.fetchone()
    return result['n']
def get_recent_resolved(hours: int = 24, limit: int = 50) -> list:
    """Return up to *limit* events resolved within the last *hours* hours.

    Newest resolution first. first_seen, last_seen and resolved_at values that
    have an ``isoformat`` method are replaced in place by ISO-8601 strings.
    """
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            """SELECT * FROM network_events
               WHERE resolved_at IS NOT NULL
               AND resolved_at > DATE_SUB(NOW(), INTERVAL %s HOUR)
               ORDER BY resolved_at DESC LIMIT %s""",
            (hours, limit),
        )
        resolved = cur.fetchall()
    stamp_columns = ('first_seen', 'last_seen', 'resolved_at')
    for event in resolved:
        for column in stamp_columns:
            stamp = event.get(column)
            if stamp and hasattr(stamp, 'isoformat'):
                event[column] = stamp.isoformat()
    return resolved
def get_status_summary() -> dict:
    """Return unresolved-event counts keyed 'critical', 'warning', 'info' (0 if none).

    Severities outside those three are counted by the query but not reported.
    """
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            """SELECT severity, COUNT(*) as cnt FROM network_events
               WHERE resolved_at IS NULL GROUP BY severity"""
        )
        grouped = cur.fetchall()
    per_severity = {}
    for entry in grouped:
        per_severity[entry['severity']] = entry['cnt']
    return {level: per_severity.get(level, 0) for level in ('critical', 'warning', 'info')}
# ---------------------------------------------------------------------------
# Suppression rules
# ---------------------------------------------------------------------------
def get_active_suppressions() -> list:
    """Return active, unexpired suppression rules, newest first.

    A rule is unexpired when expires_at is NULL or later than the DB's NOW().
    created_at/expires_at values with an ``isoformat`` method are converted
    in place to ISO-8601 strings.
    """
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            """SELECT * FROM suppression_rules
               WHERE active=TRUE AND (expires_at IS NULL OR expires_at > NOW())
               ORDER BY created_at DESC"""
        )
        rules = cur.fetchall()
    for rule in rules:
        for column in ('created_at', 'expires_at'):
            stamp = rule.get(column)
            if stamp and hasattr(stamp, 'isoformat'):
                rule[column] = stamp.isoformat()
    return rules
def get_suppression_history(limit: int = 50) -> list:
    """Return the *limit* most recently created suppression rules, active or not.

    created_at/expires_at values with an ``isoformat`` method are converted in
    place to ISO-8601 strings.
    """
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            'SELECT * FROM suppression_rules ORDER BY created_at DESC LIMIT %s',
            (limit,),
        )
        history = cur.fetchall()
    for rule in history:
        for column in ('created_at', 'expires_at'):
            stamp = rule.get(column)
            if stamp and hasattr(stamp, 'isoformat'):
                rule[column] = stamp.isoformat()
    return history
def create_suppression(
    target_type: str,
    target_name: str,
    target_detail: str,
    reason: str,
    suppressed_by: str,
    expires_minutes: Optional[int] = None,
) -> int:
    """Insert a new active suppression rule and return its id.

    Args:
        target_type: rule scope; 'all' is treated by the matchers as global.
        target_name: target the rule applies to (None stored as '').
        target_detail: optional sub-target; None/'' means the whole target.
        reason: free-text justification.
        suppressed_by: who created the rule.
        expires_minutes: lifetime in minutes. None or 0 means the rule never
            expires. The value is coerced with int(), so a non-numeric value
            raises ValueError before touching the DB.

    Returns:
        The auto-increment id of the inserted row.
    """
    # The expiry is computed on the DB server from NOW(), the same clock and
    # session time zone that get_active_suppressions(), is_suppressed() and
    # cleanup_expired_suppressions() compare against. The previous Python-side
    # datetime.utcnow() only agreed with NOW() when the MySQL session time
    # zone was UTC.
    minutes = int(expires_minutes) if expires_minutes else None
    with get_conn() as conn, conn.cursor() as cur:
        # DATE_ADD() returns NULL when the interval is NULL, so minutes=None
        # stores expires_at=NULL (no expiry).
        cur.execute(
            """INSERT INTO suppression_rules
               (target_type, target_name, target_detail, reason, suppressed_by, expires_at, active)
               VALUES (%s, %s, %s, %s, %s, DATE_ADD(NOW(), INTERVAL %s MINUTE), TRUE)""",
            (target_type, target_name or '', target_detail or '', reason, suppressed_by, minutes),
        )
        return cur.lastrowid
def deactivate_suppression(sup_id: int) -> None:
    """Mark suppression rule *sup_id* inactive (the row is kept for history)."""
    statement = 'UPDATE suppression_rules SET active=FALSE WHERE id=%s'
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(statement, (sup_id,))
def cleanup_expired_suppressions() -> int:
    """Deactivate active rules whose expires_at has passed; return how many changed.

    Rules with a NULL expires_at are never touched. An info line is logged only
    when at least one rule was deactivated.
    """
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            """UPDATE suppression_rules
               SET active=FALSE
               WHERE active=TRUE AND expires_at IS NOT NULL AND expires_at <= NOW()"""
        )
        deactivated = cur.rowcount
    if deactivated:
        logger.info(f'Deactivated {deactivated} expired suppression(s)')
    return deactivated
def purge_old_resolved_events(days: int = 90) -> int:
    """Delete events resolved more than *days* days ago; return the number deleted.

    Unresolved events are never deleted. An info line is logged only when at
    least one row was removed.
    """
    with get_conn() as conn, conn.cursor() as cur:
        cur.execute(
            """DELETE FROM network_events
               WHERE resolved_at IS NOT NULL
               AND resolved_at < DATE_SUB(NOW(), INTERVAL %s DAY)""",
            (days,),
        )
        deleted = cur.rowcount
    if deleted:
        logger.info(f'Purged {deleted} old resolved event(s) (>{days}d)')
    return deleted
def check_suppressed(suppressions: list, target_type: str, target_name: str, target_detail: str = '') -> bool:
    """Return True if any rule in *suppressions* covers the given target.

    In-memory counterpart of is_suppressed(): evaluates a list of rule rows
    already loaded (e.g. from get_active_suppressions()) so that checking many
    targets needs no per-call DB query. Matching follows is_suppressed():

    * a rule with target_type 'all' suppresses everything;
    * otherwise an empty *target_name* is never suppressed;
    * a rule with the same type and name and an empty/NULL target_detail
      suppresses every detail of that target;
    * a rule with a non-empty target_detail suppresses only that exact detail.

    The rules' active/expires_at fields are not checked here; pass only rules
    that are currently in force.

    NOTE(review): is_suppressed() compares in SQL, where string equality may be
    case-insensitive depending on the column collation; ``==`` here is
    case-sensitive — confirm the two agree for your data.
    """
    if any(rule['target_type'] == 'all' for rule in suppressions):
        return True
    if not target_name:
        # Parity with is_suppressed(), which stops after the global check when
        # there is no target name.
        return False
    for rule in suppressions:
        if rule['target_type'] != target_type or rule['target_name'] != target_name:
            continue
        rule_detail = rule.get('target_detail') or ''
        if not rule_detail:
            # Target-wide rule: covers every detail of this target.
            return True
        if target_detail and rule_detail == target_detail:
            return True
    return False
def is_suppressed(target_type: str, target_name: str, target_detail: str = '') -> bool:
    """Ask the DB whether an active, unexpired suppression rule covers a target.

    Runs up to three lookups, in order, stopping at the first hit:
      1. a global rule (target_type 'all');
      2. a rule for (target_type, target_name) with a NULL/empty target_detail,
         which covers every detail of that target;
      3. a rule for the exact (target_type, target_name, target_detail).
    When *target_name* is empty only lookup 1 runs; lookup 3 runs only when
    *target_detail* is non-empty. For many targets at once, load rules with
    get_active_suppressions() and use check_suppressed() instead.
    """
    lookups = [
        (
            """SELECT id FROM suppression_rules
               WHERE active=TRUE AND (expires_at IS NULL OR expires_at > NOW())
               AND target_type='all' LIMIT 1""",
            None,
        ),
    ]
    if target_name:
        # Target-wide rule (e.g. a whole host, covering all of its interfaces).
        lookups.append((
            """SELECT id FROM suppression_rules
               WHERE active=TRUE AND (expires_at IS NULL OR expires_at > NOW())
               AND target_type=%s AND target_name=%s
               AND (target_detail IS NULL OR target_detail='') LIMIT 1""",
            (target_type, target_name),
        ))
        if target_detail:
            # Rule scoped to one specific interface/device of the target.
            lookups.append((
                """SELECT id FROM suppression_rules
                   WHERE active=TRUE AND (expires_at IS NULL OR expires_at > NOW())
                   AND target_type=%s AND target_name=%s AND target_detail=%s LIMIT 1""",
                (target_type, target_name, target_detail),
            ))
    with get_conn() as conn, conn.cursor() as cur:
        for sql, params in lookups:
            cur.execute(sql, params)
            if cur.fetchone():
                return True
    return False