diff options
Diffstat (limited to 'shared/tests/test_sa_models.py')
| -rw-r--r-- | shared/tests/test_sa_models.py | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py new file mode 100644 index 0000000..de994c5 --- /dev/null +++ b/shared/tests/test_sa_models.py @@ -0,0 +1,244 @@ +"""Tests for SQLAlchemy ORM models.""" + +import pytest +from sqlalchemy import inspect + + +def test_base_metadata_has_all_tables(): + from shared.sa_models import Base + + table_names = set(Base.metadata.tables.keys()) + expected = { + "candles", + "signals", + "orders", + "trades", + "positions", + "portfolio_snapshots", + } + assert expected == table_names + + +class TestCandleRow: + def test_table_name(self): + from shared.sa_models import CandleRow + + assert CandleRow.__tablename__ == "candles" + + def test_columns(self): + from shared.sa_models import CandleRow + + mapper = inspect(CandleRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "symbol", + "timeframe", + "open_time", + "open", + "high", + "low", + "close", + "volume", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import CandleRow + + mapper = inspect(CandleRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["symbol", "timeframe", "open_time"] + + +class TestSignalRow: + def test_table_name(self): + from shared.sa_models import SignalRow + + assert SignalRow.__tablename__ == "signals" + + def test_columns(self): + from shared.sa_models import SignalRow + + mapper = inspect(SignalRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "strategy", + "symbol", + "side", + "price", + "quantity", + "reason", + "created_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import SignalRow + + mapper = inspect(SignalRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + +class TestOrderRow: + def test_table_name(self): + from shared.sa_models import OrderRow + + assert OrderRow.__tablename__ == "orders" + + def test_columns(self): + from shared.sa_models import OrderRow + + mapper = inspect(OrderRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "signal_id", + "symbol", + "side", + "type", + "price", + "quantity", + "status", + "created_at", + "filled_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import OrderRow + + mapper = inspect(OrderRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_signal_id_foreign_key(self): + from shared.sa_models import OrderRow + + table = OrderRow.__table__ + fk_cols = { + fk.parent.name: fk.target_fullname for fk in table.foreign_keys + } + assert fk_cols == {"signal_id": "signals.id"} + + +class TestTradeRow: + def test_table_name(self): + from shared.sa_models import TradeRow + + assert TradeRow.__tablename__ == "trades" + + def test_columns(self): + from shared.sa_models import TradeRow + + mapper = inspect(TradeRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "order_id", + "symbol", + "side", + "price", + "quantity", + "fee", + "traded_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import TradeRow + + mapper = inspect(TradeRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_order_id_foreign_key(self): + from shared.sa_models import TradeRow + + table = TradeRow.__table__ + fk_cols = { + fk.parent.name: fk.target_fullname for fk in table.foreign_keys + } + assert fk_cols == {"order_id": "orders.id"} + + +class TestPositionRow: + def test_table_name(self): + from shared.sa_models import PositionRow + + assert PositionRow.__tablename__ == "positions" + + def test_columns(self): + from shared.sa_models import PositionRow + + mapper = inspect(PositionRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "symbol", + "quantity", + "avg_entry_price", + "current_price", + "updated_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import PositionRow + + mapper = inspect(PositionRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["symbol"] + + +class TestPortfolioSnapshotRow: + def test_table_name(self): + from shared.sa_models import PortfolioSnapshotRow + + assert PortfolioSnapshotRow.__tablename__ == "portfolio_snapshots" + + def test_columns(self): + from shared.sa_models import PortfolioSnapshotRow + + mapper = inspect(PortfolioSnapshotRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "total_value", + "realized_pnl", + "unrealized_pnl", + "snapshot_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import PortfolioSnapshotRow + + mapper = inspect(PortfolioSnapshotRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_id_is_autoincrement(self): + from shared.sa_models import PortfolioSnapshotRow + + table = PortfolioSnapshotRow.__table__ + id_col = table.c.id + assert id_col.autoincrement is True or id_col.autoincrement == "auto" + + +class TestStatusDefault: + def test_order_status_server_default(self): + from shared.sa_models import OrderRow + + table = OrderRow.__table__ + status_col = table.c.status + assert status_col.server_default is not None + assert status_col.server_default.arg == "PENDING" + + def test_trade_fee_server_default(self): + from shared.sa_models import TradeRow + + table = TradeRow.__table__ + fee_col = table.c.fee + assert fee_col.server_default is not None + assert fee_col.server_default.arg == "0" |
