"""Tests for SQLAlchemy ORM models."""

import pytest
from sqlalchemy import Integer, inspect

# Model classes are imported inside each test. An ImportError in
# shared.sa_models therefore fails the individual tests instead of aborting
# collection of this whole module.


def _column_keys(model):
    """Return the set of mapped column-attribute keys on *model*."""
    return {attr.key for attr in inspect(model).column_attrs}


def _primary_key_names(model):
    """Return *model*'s primary-key column names, in mapper order."""
    # inspect() on a mapped class already returns its Mapper, so
    # ``primary_key`` is read directly (``Mapper.mapper`` is just itself).
    return [col.name for col in inspect(model).primary_key]


def _foreign_key_targets(model):
    """Map each foreign-key column name on *model* to its "table.column" target."""
    return {
        fk.parent.name: fk.target_fullname for fk in model.__table__.foreign_keys
    }


def test_base_metadata_has_all_tables():
    from shared.sa_models import Base

    expected = {
        "candles",
        "signals",
        "orders",
        "trades",
        "positions",
        "portfolio_snapshots",
    }
    assert expected == set(Base.metadata.tables.keys())


class TestCandleRow:
    def test_table_name(self):
        from shared.sa_models import CandleRow

        assert CandleRow.__tablename__ == "candles"

    def test_columns(self):
        from shared.sa_models import CandleRow

        expected = {
            "symbol",
            "timeframe",
            "open_time",
            "open",
            "high",
            "low",
            "close",
            "volume",
        }
        assert expected == _column_keys(CandleRow)

    def test_primary_key(self):
        from shared.sa_models import CandleRow

        assert _primary_key_names(CandleRow) == ["symbol", "timeframe", "open_time"]


class TestSignalRow:
    def test_table_name(self):
        from shared.sa_models import SignalRow

        assert SignalRow.__tablename__ == "signals"

    def test_columns(self):
        from shared.sa_models import SignalRow

        expected = {
            "id",
            "strategy",
            "symbol",
            "side",
            "price",
            "quantity",
            "reason",
            "created_at",
        }
        assert expected == _column_keys(SignalRow)

    def test_primary_key(self):
        from shared.sa_models import SignalRow

        assert _primary_key_names(SignalRow) == ["id"]


class TestOrderRow:
    def test_table_name(self):
        from shared.sa_models import OrderRow

        assert OrderRow.__tablename__ == "orders"

    def test_columns(self):
        from shared.sa_models import OrderRow

        expected = {
            "id",
            "signal_id",
            "symbol",
            "side",
            "type",
            "price",
            "quantity",
            "status",
            "created_at",
            "filled_at",
        }
        assert expected == _column_keys(OrderRow)

    def test_primary_key(self):
        from shared.sa_models import OrderRow

        assert _primary_key_names(OrderRow) == ["id"]

    def test_signal_id_foreign_key(self):
        from shared.sa_models import OrderRow

        assert _foreign_key_targets(OrderRow) == {"signal_id": "signals.id"}


class TestTradeRow:
    def test_table_name(self):
        from shared.sa_models import TradeRow

        assert TradeRow.__tablename__ == "trades"

    def test_columns(self):
        from shared.sa_models import TradeRow

        expected = {
            "id",
            "order_id",
            "symbol",
            "side",
            "price",
            "quantity",
            "fee",
            "traded_at",
        }
        assert expected == _column_keys(TradeRow)

    def test_primary_key(self):
        from shared.sa_models import TradeRow

        assert _primary_key_names(TradeRow) == ["id"]

    def test_order_id_foreign_key(self):
        from shared.sa_models import TradeRow

        assert _foreign_key_targets(TradeRow) == {"order_id": "orders.id"}


class TestPositionRow:
    def test_table_name(self):
        from shared.sa_models import PositionRow

        assert PositionRow.__tablename__ == "positions"

    def test_columns(self):
        from shared.sa_models import PositionRow

        expected = {
            "symbol",
            "quantity",
            "avg_entry_price",
            "current_price",
            "updated_at",
        }
        assert expected == _column_keys(PositionRow)

    def test_primary_key(self):
        from shared.sa_models import PositionRow

        assert _primary_key_names(PositionRow) == ["symbol"]


class TestPortfolioSnapshotRow:
    def test_table_name(self):
        from shared.sa_models import PortfolioSnapshotRow

        assert PortfolioSnapshotRow.__tablename__ == "portfolio_snapshots"

    def test_columns(self):
        from shared.sa_models import PortfolioSnapshotRow

        expected = {
            "id",
            "total_value",
            "realized_pnl",
            "unrealized_pnl",
            "snapshot_at",
        }
        assert expected == _column_keys(PortfolioSnapshotRow)

    def test_primary_key(self):
        from shared.sa_models import PortfolioSnapshotRow

        assert _primary_key_names(PortfolioSnapshotRow) == ["id"]

    def test_id_is_autoincrement(self):
        from shared.sa_models import PortfolioSnapshotRow

        id_col = PortfolioSnapshotRow.__table__.c.id
        assert id_col.autoincrement in (True, "auto")
        # Every Column defaults to autoincrement="auto", so the check above
        # alone passes for any column. "auto" only yields autoincrement
        # semantics for an integer-typed, single-column primary key, so pin
        # the type too (the single-column PK is asserted in test_primary_key).
        assert isinstance(id_col.type, Integer)


class TestStatusDefault:
    def test_order_status_server_default(self):
        from shared.sa_models import OrderRow

        status_col = OrderRow.__table__.c.status
        assert status_col.server_default is not None
        assert status_col.server_default.arg == "PENDING"

    def test_trade_fee_server_default(self):
        from shared.sa_models import TradeRow

        fee_col = TradeRow.__table__.c.fee
        assert fee_col.server_default is not None
        assert fee_col.server_default.arg == "0"