summaryrefslogtreecommitdiff
path: root/shared/tests
diff options
context:
space:
mode:
Diffstat (limited to 'shared/tests')
-rw-r--r--shared/tests/test_sa_models.py244
1 files changed, 244 insertions, 0 deletions
diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py
new file mode 100644
index 0000000..de994c5
--- /dev/null
+++ b/shared/tests/test_sa_models.py
@@ -0,0 +1,244 @@
+"""Tests for SQLAlchemy ORM models."""
+
+import pytest
+from sqlalchemy import inspect
+
+
+def test_base_metadata_has_all_tables():
+ from shared.sa_models import Base
+
+ table_names = set(Base.metadata.tables.keys())
+ expected = {
+ "candles",
+ "signals",
+ "orders",
+ "trades",
+ "positions",
+ "portfolio_snapshots",
+ }
+ assert expected == table_names
+
+
+class TestCandleRow:
+ def test_table_name(self):
+ from shared.sa_models import CandleRow
+
+ assert CandleRow.__tablename__ == "candles"
+
+ def test_columns(self):
+ from shared.sa_models import CandleRow
+
+ mapper = inspect(CandleRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "symbol",
+ "timeframe",
+ "open_time",
+ "open",
+ "high",
+ "low",
+ "close",
+ "volume",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import CandleRow
+
+ mapper = inspect(CandleRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["symbol", "timeframe", "open_time"]
+
+
+class TestSignalRow:
+ def test_table_name(self):
+ from shared.sa_models import SignalRow
+
+ assert SignalRow.__tablename__ == "signals"
+
+ def test_columns(self):
+ from shared.sa_models import SignalRow
+
+ mapper = inspect(SignalRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "id",
+ "strategy",
+ "symbol",
+ "side",
+ "price",
+ "quantity",
+ "reason",
+ "created_at",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import SignalRow
+
+ mapper = inspect(SignalRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["id"]
+
+
+class TestOrderRow:
+ def test_table_name(self):
+ from shared.sa_models import OrderRow
+
+ assert OrderRow.__tablename__ == "orders"
+
+ def test_columns(self):
+ from shared.sa_models import OrderRow
+
+ mapper = inspect(OrderRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "id",
+ "signal_id",
+ "symbol",
+ "side",
+ "type",
+ "price",
+ "quantity",
+ "status",
+ "created_at",
+ "filled_at",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import OrderRow
+
+ mapper = inspect(OrderRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["id"]
+
+ def test_signal_id_foreign_key(self):
+ from shared.sa_models import OrderRow
+
+ table = OrderRow.__table__
+ fk_cols = {
+ fk.parent.name: fk.target_fullname for fk in table.foreign_keys
+ }
+ assert fk_cols == {"signal_id": "signals.id"}
+
+
+class TestTradeRow:
+ def test_table_name(self):
+ from shared.sa_models import TradeRow
+
+ assert TradeRow.__tablename__ == "trades"
+
+ def test_columns(self):
+ from shared.sa_models import TradeRow
+
+ mapper = inspect(TradeRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "id",
+ "order_id",
+ "symbol",
+ "side",
+ "price",
+ "quantity",
+ "fee",
+ "traded_at",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import TradeRow
+
+ mapper = inspect(TradeRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["id"]
+
+ def test_order_id_foreign_key(self):
+ from shared.sa_models import TradeRow
+
+ table = TradeRow.__table__
+ fk_cols = {
+ fk.parent.name: fk.target_fullname for fk in table.foreign_keys
+ }
+ assert fk_cols == {"order_id": "orders.id"}
+
+
+class TestPositionRow:
+ def test_table_name(self):
+ from shared.sa_models import PositionRow
+
+ assert PositionRow.__tablename__ == "positions"
+
+ def test_columns(self):
+ from shared.sa_models import PositionRow
+
+ mapper = inspect(PositionRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "symbol",
+ "quantity",
+ "avg_entry_price",
+ "current_price",
+ "updated_at",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import PositionRow
+
+ mapper = inspect(PositionRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["symbol"]
+
+
+class TestPortfolioSnapshotRow:
+ def test_table_name(self):
+ from shared.sa_models import PortfolioSnapshotRow
+
+ assert PortfolioSnapshotRow.__tablename__ == "portfolio_snapshots"
+
+ def test_columns(self):
+ from shared.sa_models import PortfolioSnapshotRow
+
+ mapper = inspect(PortfolioSnapshotRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "id",
+ "total_value",
+ "realized_pnl",
+ "unrealized_pnl",
+ "snapshot_at",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import PortfolioSnapshotRow
+
+ mapper = inspect(PortfolioSnapshotRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["id"]
+
+ def test_id_is_autoincrement(self):
+ from shared.sa_models import PortfolioSnapshotRow
+
+ table = PortfolioSnapshotRow.__table__
+ id_col = table.c.id
+ assert id_col.autoincrement is True or id_col.autoincrement == "auto"
+
+
+class TestStatusDefault:
+ def test_order_status_server_default(self):
+ from shared.sa_models import OrderRow
+
+ table = OrderRow.__table__
+ status_col = table.c.status
+ assert status_col.server_default is not None
+ assert status_col.server_default.arg == "PENDING"
+
+ def test_trade_fee_server_default(self):
+ from shared.sa_models import TradeRow
+
+ table = TradeRow.__table__
+ fee_col = table.c.fee
+ assert fee_col.server_default is not None
+ assert fee_col.server_default.arg == "0"