From 2b1db156c7ea7e0be543ab91813922b95eb043cb Mon Sep 17 00:00:00 2001 From: TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> Date: Wed, 1 Apr 2026 16:00:38 +0900 Subject: feat: add SQLAlchemy ORM models and Alembic migration setup Add SA 2.0 declarative models (CandleRow, SignalRow, OrderRow, TradeRow, PositionRow, PortfolioSnapshotRow) mirroring existing asyncpg tables. Set up Alembic with async PostgreSQL support and add migrate/migrate-down/ migrate-new Makefile targets. Update shared dependencies with sqlalchemy, alembic, structlog, prometheus-client, pyyaml, aiohttp, and rich. --- shared/tests/test_sa_models.py | 244 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 shared/tests/test_sa_models.py (limited to 'shared/tests') diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py new file mode 100644 index 0000000..de994c5 --- /dev/null +++ b/shared/tests/test_sa_models.py @@ -0,0 +1,244 @@ +"""Tests for SQLAlchemy ORM models.""" + +import pytest +from sqlalchemy import inspect + + +def test_base_metadata_has_all_tables(): + from shared.sa_models import Base + + table_names = set(Base.metadata.tables.keys()) + expected = { + "candles", + "signals", + "orders", + "trades", + "positions", + "portfolio_snapshots", + } + assert expected == table_names + + +class TestCandleRow: + def test_table_name(self): + from shared.sa_models import CandleRow + + assert CandleRow.__tablename__ == "candles" + + def test_columns(self): + from shared.sa_models import CandleRow + + mapper = inspect(CandleRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "symbol", + "timeframe", + "open_time", + "open", + "high", + "low", + "close", + "volume", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import CandleRow + + mapper = inspect(CandleRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["symbol", "timeframe", "open_time"] + + +class TestSignalRow: + def test_table_name(self): + from shared.sa_models import SignalRow + + assert SignalRow.__tablename__ == "signals" + + def test_columns(self): + from shared.sa_models import SignalRow + + mapper = inspect(SignalRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "strategy", + "symbol", + "side", + "price", + "quantity", + "reason", + "created_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import SignalRow + + mapper = inspect(SignalRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + +class TestOrderRow: + def test_table_name(self): + from shared.sa_models import OrderRow + + assert OrderRow.__tablename__ == "orders" + + def test_columns(self): + from shared.sa_models import OrderRow + + mapper = inspect(OrderRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "signal_id", + "symbol", + "side", + "type", + "price", + "quantity", + "status", + "created_at", + "filled_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import OrderRow + + mapper = inspect(OrderRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_signal_id_foreign_key(self): + from shared.sa_models import OrderRow + + table = OrderRow.__table__ + fk_cols = { + fk.parent.name: fk.target_fullname for fk in table.foreign_keys + } + assert fk_cols == {"signal_id": "signals.id"} + + +class TestTradeRow: + def test_table_name(self): + from shared.sa_models import TradeRow + + assert TradeRow.__tablename__ == "trades" + + def test_columns(self): + from shared.sa_models import TradeRow + + mapper = inspect(TradeRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "order_id", + "symbol", + "side", + "price", + "quantity", + "fee", + "traded_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import TradeRow + + mapper = inspect(TradeRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_order_id_foreign_key(self): + from shared.sa_models import TradeRow + + table = TradeRow.__table__ + fk_cols = { + fk.parent.name: fk.target_fullname for fk in table.foreign_keys + } + assert fk_cols == {"order_id": "orders.id"} + + +class TestPositionRow: + def test_table_name(self): + from shared.sa_models import PositionRow + + assert PositionRow.__tablename__ == "positions" + + def test_columns(self): + from shared.sa_models import PositionRow + + mapper = inspect(PositionRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "symbol", + "quantity", + "avg_entry_price", + "current_price", + "updated_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import PositionRow + + mapper = inspect(PositionRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["symbol"] + + +class TestPortfolioSnapshotRow: + def test_table_name(self): + from shared.sa_models import PortfolioSnapshotRow + + assert PortfolioSnapshotRow.__tablename__ == "portfolio_snapshots" + + def test_columns(self): + from shared.sa_models import PortfolioSnapshotRow + + mapper = inspect(PortfolioSnapshotRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "total_value", + "realized_pnl", + "unrealized_pnl", + "snapshot_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import PortfolioSnapshotRow + + mapper = inspect(PortfolioSnapshotRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_id_is_autoincrement(self): + from shared.sa_models import PortfolioSnapshotRow + + table = PortfolioSnapshotRow.__table__ + id_col = table.c.id + assert id_col.autoincrement is True or id_col.autoincrement == "auto" + + +class TestStatusDefault: + def test_order_status_server_default(self): + from shared.sa_models import OrderRow + + table = OrderRow.__table__ + status_col = table.c.status + assert status_col.server_default is not None + assert status_col.server_default.arg == "PENDING" + + def test_trade_fee_server_default(self): + from shared.sa_models import TradeRow + + table = TradeRow.__table__ + fee_col = table.c.fee + assert fee_col.server_default is not None + assert fee_col.server_default.arg == "0" -- cgit v1.2.3