# Source code for database.db_manager

"""Database manager for storing and querying simulation data (SQLite)."""

import sqlite3
from contextlib import contextmanager
from typing import List, Optional, Tuple

# Whitelist of data-table names. Table names cannot be bound as SQL
# parameters, so DatabaseManager interpolates them into query strings and
# checks them against this tuple first (DatabaseManager._validate_table).
VALID_TABLES = ("measurements", "forecast", "scheduling")

# Metrics seeded into the `metric` table by
# DatabaseManager.initialize_default_metrics(), as
# (metric_name, unit, data_type) tuples. Record values in the data tables are
# stored as TEXT; DatabaseManager itself does not convert values by data_type.
# NOTE: on a fresh database, metric_ids are assigned in this order
# (AUTOINCREMENT), so append new metrics at the end instead of reordering.
DEFAULT_METRICS = (
    # Power flows and limits
    ("power_active_consumption", "kW", "float"),
    ("power_active_production", "kW", "float"),
    ("power_active_import", "kW", "float"),
    ("power_active_export", "kW", "float"),
    ("available_power", "kW", "float"),
    ("contracted_power", "kW", "float"),
    # Operating point and state
    ("power_active", "kW", "float"),
    ("power_setpoint", "kW", "float"),
    ("soc", "%", "float"),
    ("state", "", "str"),
    # BESS
    ("bess_discharge", "kW", "float"),
    ("bess_charge", "kW", "float"),
    ("energy_stored", "kWh", "float"),
    ("bess_available", "kWh", "float"),
    ("bess_capacity", "kWh", "float"),
    # Boats
    ("boat_required_energy", "kWh", "float"),
    ("boat_available", "", "int"),
    # Weather
    ("temperature", "°C", "float"),
    ("humidity", "%", "float"),
    ("dew_point", "°C", "float"),
    ("precipitation", "mm", "float"),
    ("weather_code", "", "int"),
    ("cloud_cover", "%", "float"),
    ("wind_speed", "m/s", "float"),
    ("wind_direction", "°", "float"),
    ("ghi", "W/m²", "float"),
    ("direct_radiation", "W/m²", "float"),
    ("dhi", "W/m²", "float"),
    ("dni", "W/m²", "float"),
    # Aggregate balance
    ("net_balance", "kW", "float"),
)


class DatabaseManager:
    """
    Manages SQLite database operations for the simulator.

    Every public query method opens a short-lived connection through
    :meth:`get_connection`, so each call runs in its own transaction.
    In-memory caches for source_id and metric_id avoid repeated lookups.
    :meth:`delete_source` and :meth:`delete_metric` keep the caches in sync,
    and :meth:`clear_caches` resets them.
    """

    def __init__(self, db_path: str):
        """
        Initialize the database manager. No connection is opened here.

        Args:
            db_path: Path to the SQLite database file.
        """
        self.db_path = db_path
        self._connection: Optional[sqlite3.Connection] = None
        # name -> id caches used by get_or_create_source / get_metric_id.
        self._source_cache: dict[str, int] = {}
        self._metric_cache: dict[str, int] = {}

    def connect(self):
        """
        Open a persistent database connection stored in ``self._connection``.

        The connection uses ``sqlite3.Row`` rows and has foreign keys enabled,
        the same as connections from :meth:`get_connection`. If a persistent
        connection is already open, it is closed first rather than leaked.
        """
        self.close()
        self._connection = sqlite3.connect(self.db_path)
        self._connection.row_factory = sqlite3.Row
        self._connection.execute("PRAGMA foreign_keys = ON")

    def close(self):
        """Close the persistent database connection if it is open."""
        if self._connection is not None:
            self._connection.close()
            self._connection = None

    @contextmanager
    def get_connection(self):
        """
        Context manager yielding a short-lived connection for one transaction.

        The connection has ``sqlite3.Row`` rows and foreign keys enabled.
        It commits when the ``with`` body finishes normally, rolls back if the
        body raises, and is always closed.

        Yields:
            sqlite3.Connection
        """
        conn = sqlite3.connect(self.db_path)
        try:
            # Setup runs inside the try so a failing PRAGMA still closes conn.
            conn.row_factory = sqlite3.Row
            # SQLite enforces foreign keys only when enabled per connection.
            conn.execute("PRAGMA foreign_keys = ON")
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def initialize_schema(self):
        """
        Create all tables and their indexes if they do not already exist.

        Creates the ``source`` and ``metric`` lookup tables plus the data
        tables ``measurements``, ``forecast`` and ``scheduling``. The data
        tables share one layout and are each indexed on timestamp, source_id
        and metric_id.
        """
        with self.get_connection() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS source (
                    source_id INTEGER PRIMARY KEY AUTOINCREMENT,
                    source_name TEXT NOT NULL UNIQUE,
                    source_type TEXT NOT NULL
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS metric (
                    metric_id INTEGER PRIMARY KEY AUTOINCREMENT,
                    metric_name TEXT NOT NULL UNIQUE,
                    unit TEXT NOT NULL,
                    data_type TEXT NOT NULL
                )
                """
            )
            # Keep in sync with the module-level VALID_TABLES whitelist.
            for table in ("measurements", "forecast", "scheduling"):
                self._create_data_table(conn, table)

    @staticmethod
    def _create_data_table(conn: sqlite3.Connection, table: str):
        """Create one data table (shared layout) and its three lookup indexes."""
        # `table` comes only from the fixed tuple in initialize_schema.
        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {table} (
                measurement_id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                metric_id INTEGER NOT NULL,
                value TEXT NOT NULL,
                FOREIGN KEY (source_id) REFERENCES source(source_id),
                FOREIGN KEY (metric_id) REFERENCES metric(metric_id)
            )
            """
        )
        for column in ("timestamp", "source_id", "metric_id"):
            conn.execute(
                f"CREATE INDEX IF NOT EXISTS idx_{table}_{column} "
                f"ON {table}({column})"
            )

    def initialize_default_metrics(self):
        """
        Insert all DEFAULT_METRICS into the metric table (INSERT OR IGNORE).

        Metrics whose name already exists are left unchanged, so repeated
        calls do not duplicate or overwrite rows.
        """
        with self.get_connection() as conn:
            conn.executemany(
                """
                INSERT OR IGNORE INTO metric (metric_name, unit, data_type)
                VALUES (?, ?, ?)
                """,
                DEFAULT_METRICS,
            )

    def _validate_table(self, table: str):
        """Raise ValueError if table is not one of VALID_TABLES."""
        if table not in VALID_TABLES:
            raise ValueError(f"Invalid table '{table}'. Must be one of: {VALID_TABLES}")

    def save_record(
        self, table: str, timestamp: str, source_id: int, metric_id: int, value: str
    ) -> int:
        """
        Insert a single record into the given table.

        Args:
            table: One of 'measurements', 'forecast', 'scheduling'.
            timestamp: ISO-format UTC timestamp string.
            source_id: Foreign key to source.
            metric_id: Foreign key to metric.
            value: Value stored as string.

        Returns:
            The measurement_id of the inserted row.

        Raises:
            ValueError: If ``table`` is not a valid table name.
        """
        # Validation guards the f-string interpolation of the table name.
        self._validate_table(table)
        with self.get_connection() as conn:
            cursor = conn.execute(
                f"""
                INSERT INTO {table} (timestamp, source_id, metric_id, value)
                VALUES (?, ?, ?, ?)
                """,
                (timestamp, source_id, metric_id, value),
            )
            return cursor.lastrowid

    def save_records_batch(self, table: str, records: List[Tuple[str, int, int, str]]):
        """
        Insert multiple records in one transaction.

        Args:
            table: One of 'measurements', 'forecast', 'scheduling'.
            records: List of (timestamp, source_id, metric_id, value) tuples.

        Raises:
            ValueError: If ``table`` is not a valid table name.
        """
        self._validate_table(table)
        with self.get_connection() as conn:
            conn.executemany(
                f"""
                INSERT INTO {table} (timestamp, source_id, metric_id, value)
                VALUES (?, ?, ?, ?)
                """,
                records,
            )

    def get_records(
        self,
        table: str,
        source_id: Optional[int] = None,
        metric_id: Optional[int] = None,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
    ):
        """
        Return records from the table with optional filters.

        Args:
            table: One of 'measurements', 'forecast', 'scheduling'.
            source_id: Optional filter by source_id.
            metric_id: Optional filter by metric_id.
            start_time: Optional inclusive start timestamp (ISO string).
            end_time: Optional inclusive end timestamp (ISO string).

        Returns:
            List of rows (sqlite3.Row), ordered by timestamp.

        Raises:
            ValueError: If ``table`` is not a valid table name.
        """
        self._validate_table(table)
        query = f"SELECT * FROM {table} WHERE 1=1"
        params = []
        if source_id is not None:
            query += " AND source_id = ?"
            params.append(source_id)
        if metric_id is not None:
            query += " AND metric_id = ?"
            params.append(metric_id)
        if start_time is not None:
            # Timestamps are TEXT; ISO strings compare correctly as text.
            query += " AND timestamp >= ?"
            params.append(start_time)
        if end_time is not None:
            query += " AND timestamp <= ?"
            params.append(end_time)
        query += " ORDER BY timestamp"
        with self.get_connection() as conn:
            return conn.execute(query, params).fetchall()

    def clear_records(
        self, table: str, source_id: Optional[int] = None, from_time: Optional[str] = None
    ):
        """
        Delete records from the table, optionally by source_id and/or from_time.

        Args:
            table: One of 'measurements', 'forecast', 'scheduling'.
            source_id: If set, delete only records for this source.
            from_time: If set, delete only records with timestamp >= from_time
                (ISO string).

        Raises:
            ValueError: If ``table`` is not a valid table name.
        """
        self._validate_table(table)
        conditions = []
        params = []
        if source_id is not None:
            conditions.append("source_id = ?")
            params.append(source_id)
        if from_time is not None:
            conditions.append("timestamp >= ?")
            params.append(from_time)
        query = f"DELETE FROM {table}"
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        with self.get_connection() as conn:
            conn.execute(query, params)

    def save_source(self, source_name: str, source_type: str) -> int:
        """
        Insert a source row.

        Args:
            source_name: Unique name (e.g. 'boat_001', 'charger_01').
            source_type: Type (e.g. 'boat', 'charger', 'battery').

        Returns:
            The inserted source_id.

        Raises:
            sqlite3.IntegrityError: If ``source_name`` already exists.
        """
        with self.get_connection() as conn:
            cursor = conn.execute(
                """
                INSERT INTO source (source_name, source_type)
                VALUES (?, ?)
                """,
                (source_name, source_type),
            )
            return cursor.lastrowid

    def get_source(self, source_id: Optional[int] = None, source_name: Optional[str] = None):
        """
        Return one source by source_id or source_name.

        At least one of the two should be given; source_id takes precedence.

        Returns:
            sqlite3.Row, or None if not found or if neither argument is given.
        """
        if source_id is not None:
            query, params = "SELECT * FROM source WHERE source_id = ?", (source_id,)
        elif source_name is not None:
            query, params = "SELECT * FROM source WHERE source_name = ?", (source_name,)
        else:
            return None
        with self.get_connection() as conn:
            return conn.execute(query, params).fetchone()

    def get_all_sources(self, source_type: Optional[str] = None):
        """Return all source rows, optionally filtered by source_type."""
        with self.get_connection() as conn:
            if source_type:
                cursor = conn.execute(
                    "SELECT * FROM source WHERE source_type = ?", (source_type,)
                )
            else:
                cursor = conn.execute("SELECT * FROM source")
            return cursor.fetchall()

    def delete_source(self, source_id: int):
        """
        Delete the source row with the given source_id.

        Cached name->id entries pointing at the deleted id are evicted, so a
        later get_or_create_source() re-creates the source instead of
        returning a dangling id.
        """
        with self.get_connection() as conn:
            conn.execute("DELETE FROM source WHERE source_id = ?", (source_id,))
        # Only reached if the DELETE committed (e.g. no FK violation).
        self._evict_cached_id(self._source_cache, source_id)

    def save_metric(self, metric_name: str, unit: str, data_type: str) -> int:
        """
        Insert a metric row.

        Args:
            metric_name: Unique name (e.g. 'power', 'soc', 'temperature').
            unit: Unit string (e.g. 'kW', '%', '°C').
            data_type: One of 'float', 'int', 'str', etc.

        Returns:
            The inserted metric_id.

        Raises:
            sqlite3.IntegrityError: If ``metric_name`` already exists.
        """
        with self.get_connection() as conn:
            cursor = conn.execute(
                """
                INSERT INTO metric (metric_name, unit, data_type)
                VALUES (?, ?, ?)
                """,
                (metric_name, unit, data_type),
            )
            return cursor.lastrowid

    def get_metric(self, metric_id: Optional[int] = None, metric_name: Optional[str] = None):
        """
        Return one metric by metric_id or metric_name.

        At least one of the two should be given; metric_id takes precedence.

        Returns:
            sqlite3.Row, or None if not found or if neither argument is given.
        """
        if metric_id is not None:
            query, params = "SELECT * FROM metric WHERE metric_id = ?", (metric_id,)
        elif metric_name is not None:
            query, params = "SELECT * FROM metric WHERE metric_name = ?", (metric_name,)
        else:
            return None
        with self.get_connection() as conn:
            return conn.execute(query, params).fetchone()

    def get_all_metrics(self, data_type: Optional[str] = None):
        """Return all metric rows, optionally filtered by data_type."""
        with self.get_connection() as conn:
            if data_type:
                cursor = conn.execute(
                    "SELECT * FROM metric WHERE data_type = ?", (data_type,)
                )
            else:
                cursor = conn.execute("SELECT * FROM metric")
            return cursor.fetchall()

    def delete_metric(self, metric_id: int):
        """
        Delete the metric row with the given metric_id.

        Cached name->id entries pointing at the deleted id are evicted, so
        get_metric_id() reports the metric as missing instead of returning
        a dangling id.
        """
        with self.get_connection() as conn:
            conn.execute("DELETE FROM metric WHERE metric_id = ?", (metric_id,))
        self._evict_cached_id(self._metric_cache, metric_id)

    @staticmethod
    def _evict_cached_id(cache: dict, row_id: int):
        """Remove every entry of a name->id cache whose id equals row_id."""
        for name in [name for name, cached_id in cache.items() if cached_id == row_id]:
            del cache[name]

    def get_or_create_source(self, source_name: str, source_type: str) -> int:
        """
        Return source_id for the given source, inserting the source if missing.

        Results are cached to avoid repeated lookups. If the source already
        exists, its stored source_type is kept and ``source_type`` is ignored.

        Args:
            source_name: Unique source name (e.g. 'port', 'SeaBreeze', 'FastCharger_A').
            source_type: Type (e.g. 'port', 'boat', 'charger', 'pv', 'bess', 'weather').

        Returns:
            The source_id.
        """
        if source_name in self._source_cache:
            return self._source_cache[source_name]
        source = self.get_source(source_name=source_name)
        if source:
            source_id = source["source_id"]
        else:
            source_id = self.save_source(source_name, source_type)
        self._source_cache[source_name] = source_id
        return source_id

    def get_metric_id(self, metric_name: str) -> int:
        """
        Return metric_id for the given metric name. Uses caching.

        The metric must already exist (e.g. from DEFAULT_METRICS or manual
        creation).

        Args:
            metric_name: Metric name (e.g. 'power_active', 'soc').

        Returns:
            The metric_id.

        Raises:
            ValueError: If no metric with that name exists.
        """
        if metric_name in self._metric_cache:
            return self._metric_cache[metric_name]
        metric = self.get_metric(metric_name=metric_name)
        if metric:
            metric_id = metric["metric_id"]
            self._metric_cache[metric_name] = metric_id
            return metric_id
        raise ValueError(
            f"Metric '{metric_name}' not found. Ensure it's in DEFAULT_METRICS or created manually."
        )

    def clear_caches(self):
        """Clear in-memory source_id and metric_id caches."""
        self._source_cache.clear()
        self._metric_cache.clear()