diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 16:00:38 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 16:02:03 +0900 |
| commit | 2b1db156c7ea7e0be543ab91813922b95eb043cb (patch) | |
| tree | fd19abb0845d96e160c68817190b33a0f6c0034d | |
| parent | 33b14aaa2344b0fd95d1629627c3d135b24ae102 (diff) | |
feat: add SQLAlchemy ORM models and Alembic migration setup
Add SA 2.0 declarative models (CandleRow, SignalRow, OrderRow, TradeRow,
PositionRow, PortfolioSnapshotRow) mirroring existing asyncpg tables.
Set up Alembic with async PostgreSQL support and add migrate/migrate-down/
migrate-new Makefile targets. Update shared dependencies with sqlalchemy,
alembic, structlog, prometheus-client, pyyaml, aiohttp, and rich.
| -rw-r--r-- | Makefile | 11 | ||||
| -rw-r--r-- | shared/alembic.ini | 36 | ||||
| -rw-r--r-- | shared/alembic/env.py | 66 | ||||
| -rw-r--r-- | shared/alembic/script.py.mako | 26 | ||||
| -rw-r--r-- | shared/alembic/versions/.gitkeep | 0 | ||||
| -rw-r--r-- | shared/pyproject.toml | 7 | ||||
| -rw-r--r-- | shared/src/shared/sa_models.py | 93 | ||||
| -rw-r--r-- | shared/tests/test_sa_models.py | 244 |
8 files changed, 482 insertions, 1 deletions
@@ -1,4 +1,4 @@ -.PHONY: infra up down logs test lint format +.PHONY: infra up down logs test lint format migrate migrate-down migrate-new infra: docker compose up -d redis postgres @@ -22,3 +22,12 @@ lint: format: ruff check --fix . ruff format . + +migrate: + cd shared && alembic upgrade head + +migrate-down: + cd shared && alembic downgrade -1 + +migrate-new: + cd shared && alembic revision --autogenerate -m "$(msg)" diff --git a/shared/alembic.ini b/shared/alembic.ini new file mode 100644 index 0000000..2c4fd1f --- /dev/null +++ b/shared/alembic.ini @@ -0,0 +1,36 @@ +[alembic] +script_location = alembic +sqlalchemy.url = postgresql+asyncpg://postgres:postgres@localhost:5432/trading + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/shared/alembic/env.py b/shared/alembic/env.py new file mode 100644 index 0000000..14303f6 --- /dev/null +++ b/shared/alembic/env.py @@ -0,0 +1,66 @@ +"""Alembic environment configuration for async PostgreSQL migrations.""" + +import asyncio +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import pool +from sqlalchemy.ext.asyncio import async_engine_from_config + +from shared.sa_models import Base + +config = context.config + +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + Configures the context with just a URL and not an Engine. + Calls to context.execute() here emit the given string to the script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection): + context.configure(connection=connection, target_metadata=target_metadata) + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """Run migrations in 'online' mode using an async engine.""" + connectable = async_engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/shared/alembic/script.py.mako b/shared/alembic/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/shared/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/shared/alembic/versions/.gitkeep b/shared/alembic/versions/.gitkeep new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/shared/alembic/versions/.gitkeep diff --git a/shared/pyproject.toml b/shared/pyproject.toml index bd09d3e..c36f00b 100644 --- a/shared/pyproject.toml +++ b/shared/pyproject.toml @@ -8,6 +8,13 @@ dependencies = [ "pydantic-settings>=2.0", "redis>=5.0", "asyncpg>=0.29", + "sqlalchemy[asyncio]>=2.0", + "alembic>=1.13", + "structlog>=24.0", + "prometheus-client>=0.20", + "pyyaml>=6.0", + "aiohttp>=3.9", + "rich>=13.0", ] [project.optional-dependencies] diff --git a/shared/src/shared/sa_models.py b/shared/src/shared/sa_models.py new file mode 100644 index 0000000..0537846 --- /dev/null +++ b/shared/src/shared/sa_models.py @@ -0,0 +1,93 @@ +"""SQLAlchemy 2.0 ORM models mirroring the existing asyncpg tables.""" + +from datetime import datetime +from decimal import Decimal + +from sqlalchemy import DateTime, ForeignKey, Numeric, Text +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + +class Base(DeclarativeBase): + pass + + +class CandleRow(Base): + __tablename__ = "candles" + + symbol: Mapped[str] = mapped_column(Text, primary_key=True) + timeframe: Mapped[str] = mapped_column(Text, primary_key=True) + open_time: Mapped[datetime] = mapped_column(DateTime(timezone=True), primary_key=True) + open: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + high: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + low: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + close: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + volume: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + + +class SignalRow(Base): + __tablename__ = "signals" + + id: Mapped[str] = mapped_column(Text, primary_key=True) + strategy: Mapped[str] = mapped_column(Text, nullable=False) + symbol: Mapped[str] = mapped_column(Text, nullable=False) + side: Mapped[str] = mapped_column(Text, nullable=False) + price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + reason: Mapped[str | None] = mapped_column(Text) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + +class OrderRow(Base): + __tablename__ = "orders" + + id: Mapped[str] = mapped_column(Text, primary_key=True) + signal_id: Mapped[str | None] = mapped_column( + Text, ForeignKey("signals.id") + ) + symbol: Mapped[str] = mapped_column(Text, nullable=False) + side: Mapped[str] = mapped_column(Text, nullable=False) + type: Mapped[str] = mapped_column(Text, nullable=False) + price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + status: Mapped[str] = mapped_column( + Text, nullable=False, server_default="PENDING" + ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + filled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + + +class TradeRow(Base): + __tablename__ = "trades" + + id: Mapped[str] = mapped_column(Text, primary_key=True) + order_id: Mapped[str | None] = mapped_column( + Text, ForeignKey("orders.id") + ) + symbol: Mapped[str] = mapped_column(Text, nullable=False) + side: Mapped[str] = mapped_column(Text, nullable=False) + price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + fee: Mapped[Decimal] = mapped_column( + Numeric, nullable=False, server_default="0" + ) + traded_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + +class PositionRow(Base): + __tablename__ = "positions" + + symbol: Mapped[str] = mapped_column(Text, primary_key=True) + quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + avg_entry_price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + current_price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + +class PortfolioSnapshotRow(Base): + __tablename__ = "portfolio_snapshots" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + total_value: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + realized_pnl: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + unrealized_pnl: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + snapshot_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py new file mode 100644 index 0000000..de994c5 --- /dev/null +++ b/shared/tests/test_sa_models.py @@ -0,0 +1,244 @@ +"""Tests for SQLAlchemy ORM models.""" + +import pytest +from sqlalchemy import inspect + + +def test_base_metadata_has_all_tables(): + from shared.sa_models import Base + + table_names = set(Base.metadata.tables.keys()) + expected = { + "candles", + "signals", + "orders", + "trades", + "positions", + "portfolio_snapshots", + } + assert expected == table_names + + +class TestCandleRow: + def test_table_name(self): + from shared.sa_models import CandleRow + + assert CandleRow.__tablename__ == "candles" + + def test_columns(self): + from shared.sa_models import CandleRow + + mapper = inspect(CandleRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "symbol", + "timeframe", + "open_time", + "open", + "high", + "low", + "close", + "volume", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import CandleRow + + mapper = inspect(CandleRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["symbol", "timeframe", "open_time"] + + +class TestSignalRow: + def test_table_name(self): + from shared.sa_models import SignalRow + + assert SignalRow.__tablename__ == "signals" + + def test_columns(self): + from shared.sa_models import SignalRow + + mapper = inspect(SignalRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "strategy", + "symbol", + "side", + "price", + "quantity", + "reason", + "created_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import SignalRow + + mapper = inspect(SignalRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + +class TestOrderRow: + def test_table_name(self): + from shared.sa_models import OrderRow + + assert OrderRow.__tablename__ == "orders" + + def test_columns(self): + from shared.sa_models import OrderRow + + mapper = inspect(OrderRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "signal_id", + "symbol", + "side", + "type", + "price", + "quantity", + "status", + "created_at", + "filled_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import OrderRow + + mapper = inspect(OrderRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_signal_id_foreign_key(self): + from shared.sa_models import OrderRow + + table = OrderRow.__table__ + fk_cols = { + fk.parent.name: fk.target_fullname for fk in table.foreign_keys + } + assert fk_cols == {"signal_id": "signals.id"} + + +class TestTradeRow: + def test_table_name(self): + from shared.sa_models import TradeRow + + assert TradeRow.__tablename__ == "trades" + + def test_columns(self): + from shared.sa_models import TradeRow + + mapper = inspect(TradeRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "order_id", + "symbol", + "side", + "price", + "quantity", + "fee", + "traded_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import TradeRow + + mapper = inspect(TradeRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_order_id_foreign_key(self): + from shared.sa_models import TradeRow + + table = TradeRow.__table__ + fk_cols = { + fk.parent.name: fk.target_fullname for fk in table.foreign_keys + } + assert fk_cols == {"order_id": "orders.id"} + + +class TestPositionRow: + def test_table_name(self): + from shared.sa_models import PositionRow + + assert PositionRow.__tablename__ == "positions" + + def test_columns(self): + from shared.sa_models import PositionRow + + mapper = inspect(PositionRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "symbol", + "quantity", + "avg_entry_price", + "current_price", + "updated_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import PositionRow + + mapper = inspect(PositionRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["symbol"] + + +class TestPortfolioSnapshotRow: + def test_table_name(self): + from shared.sa_models import PortfolioSnapshotRow + + assert PortfolioSnapshotRow.__tablename__ == "portfolio_snapshots" + + def test_columns(self): + from shared.sa_models import PortfolioSnapshotRow + + mapper = inspect(PortfolioSnapshotRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "total_value", + "realized_pnl", + "unrealized_pnl", + "snapshot_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import PortfolioSnapshotRow + + mapper = inspect(PortfolioSnapshotRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_id_is_autoincrement(self): + from shared.sa_models import PortfolioSnapshotRow + + table = PortfolioSnapshotRow.__table__ + id_col = table.c.id + assert id_col.autoincrement is True or id_col.autoincrement == "auto" + + +class TestStatusDefault: + def test_order_status_server_default(self): + from shared.sa_models import OrderRow + + table = OrderRow.__table__ + status_col = table.c.status + assert status_col.server_default is not None + assert status_col.server_default.arg == "PENDING" + + def test_trade_fee_server_default(self): + from shared.sa_models import TradeRow + + table = TradeRow.__table__ + fee_col = table.c.fee + assert fee_col.server_default is not None + assert fee_col.server_default.arg == "0" |
