diff options
Diffstat (limited to 'shared')
| -rw-r--r-- | shared/alembic.ini | 36 | ||||
| -rw-r--r-- | shared/alembic/env.py | 66 | ||||
| -rw-r--r-- | shared/alembic/script.py.mako | 26 | ||||
| -rw-r--r-- | shared/alembic/versions/.gitkeep | 0 | ||||
| -rw-r--r-- | shared/pyproject.toml | 7 | ||||
| -rw-r--r-- | shared/src/shared/sa_models.py | 93 | ||||
| -rw-r--r-- | shared/tests/test_sa_models.py | 244 |
7 files changed, 472 insertions, 0 deletions
diff --git a/shared/alembic.ini b/shared/alembic.ini new file mode 100644 index 0000000..2c4fd1f --- /dev/null +++ b/shared/alembic.ini @@ -0,0 +1,36 @@ +[alembic] +script_location = alembic +sqlalchemy.url = postgresql+asyncpg://postgres:postgres@localhost:5432/trading + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/shared/alembic/env.py b/shared/alembic/env.py new file mode 100644 index 0000000..14303f6 --- /dev/null +++ b/shared/alembic/env.py @@ -0,0 +1,66 @@ +"""Alembic environment configuration for async PostgreSQL migrations.""" + +import asyncio +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import pool +from sqlalchemy.ext.asyncio import async_engine_from_config + +from shared.sa_models import Base + +config = context.config + +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + Configures the context with just a URL and not an Engine. + Calls to context.execute() here emit the given string to the script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection): + context.configure(connection=connection, target_metadata=target_metadata) + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """Run migrations in 'online' mode using an async engine.""" + connectable = async_engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/shared/alembic/script.py.mako b/shared/alembic/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/shared/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/shared/alembic/versions/.gitkeep b/shared/alembic/versions/.gitkeep new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/shared/alembic/versions/.gitkeep diff --git a/shared/pyproject.toml b/shared/pyproject.toml index bd09d3e..c36f00b 100644 --- a/shared/pyproject.toml +++ b/shared/pyproject.toml @@ -8,6 +8,13 @@ dependencies = [ "pydantic-settings>=2.0", "redis>=5.0", "asyncpg>=0.29", + "sqlalchemy[asyncio]>=2.0", + "alembic>=1.13", + "structlog>=24.0", + "prometheus-client>=0.20", + "pyyaml>=6.0", + "aiohttp>=3.9", + "rich>=13.0", ] [project.optional-dependencies] diff --git a/shared/src/shared/sa_models.py b/shared/src/shared/sa_models.py new file mode 100644 index 0000000..0537846 --- /dev/null +++ b/shared/src/shared/sa_models.py @@ -0,0 +1,93 @@ +"""SQLAlchemy 2.0 ORM models mirroring the existing asyncpg tables.""" + +from datetime import datetime +from decimal import Decimal + +from sqlalchemy import DateTime, ForeignKey, Numeric, Text +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + +class Base(DeclarativeBase): + pass + + +class CandleRow(Base): + __tablename__ = "candles" + + symbol: Mapped[str] = mapped_column(Text, primary_key=True) + timeframe: Mapped[str] = mapped_column(Text, primary_key=True) + open_time: Mapped[datetime] = mapped_column(DateTime(timezone=True), primary_key=True) + open: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + high: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + low: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + close: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + volume: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + + +class SignalRow(Base): + __tablename__ = "signals" + + id: Mapped[str] = mapped_column(Text, primary_key=True) + strategy: Mapped[str] = mapped_column(Text, nullable=False) + symbol: Mapped[str] = mapped_column(Text, nullable=False) + side: Mapped[str] = mapped_column(Text, nullable=False) + price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + reason: Mapped[str | None] = mapped_column(Text) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + +class OrderRow(Base): + __tablename__ = "orders" + + id: Mapped[str] = mapped_column(Text, primary_key=True) + signal_id: Mapped[str | None] = mapped_column( + Text, ForeignKey("signals.id") + ) + symbol: Mapped[str] = mapped_column(Text, nullable=False) + side: Mapped[str] = mapped_column(Text, nullable=False) + type: Mapped[str] = mapped_column(Text, nullable=False) + price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + status: Mapped[str] = mapped_column( + Text, nullable=False, server_default="PENDING" + ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + filled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + + +class TradeRow(Base): + __tablename__ = "trades" + + id: Mapped[str] = mapped_column(Text, primary_key=True) + order_id: Mapped[str | None] = mapped_column( + Text, ForeignKey("orders.id") + ) + symbol: Mapped[str] = mapped_column(Text, nullable=False) + side: Mapped[str] = mapped_column(Text, nullable=False) + price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + fee: Mapped[Decimal] = mapped_column( + Numeric, nullable=False, server_default="0" + ) + traded_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + +class PositionRow(Base): + __tablename__ = "positions" + + symbol: Mapped[str] = mapped_column(Text, primary_key=True) + quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + avg_entry_price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + current_price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + +class PortfolioSnapshotRow(Base): + __tablename__ = "portfolio_snapshots" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + total_value: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + realized_pnl: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + unrealized_pnl: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + snapshot_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py new file mode 100644 index 0000000..de994c5 --- /dev/null +++ b/shared/tests/test_sa_models.py @@ -0,0 +1,244 @@ +"""Tests for SQLAlchemy ORM models.""" + +import pytest +from sqlalchemy import inspect + + +def test_base_metadata_has_all_tables(): + from shared.sa_models import Base + + table_names = set(Base.metadata.tables.keys()) + expected = { + "candles", + "signals", + "orders", + "trades", + "positions", + "portfolio_snapshots", + } + assert expected == table_names + + +class TestCandleRow: + def test_table_name(self): + from shared.sa_models import CandleRow + + assert CandleRow.__tablename__ == "candles" + + def test_columns(self): + from shared.sa_models import CandleRow + + mapper = inspect(CandleRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "symbol", + "timeframe", + "open_time", + "open", + "high", + "low", + "close", + "volume", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import CandleRow + + mapper = inspect(CandleRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["symbol", "timeframe", "open_time"] + + +class TestSignalRow: + def test_table_name(self): + from shared.sa_models import SignalRow + + assert SignalRow.__tablename__ == "signals" + + def test_columns(self): + from shared.sa_models import SignalRow + + mapper = inspect(SignalRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "strategy", + "symbol", + "side", + "price", + "quantity", + "reason", + "created_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import SignalRow + + mapper = inspect(SignalRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + +class TestOrderRow: + def test_table_name(self): + from shared.sa_models import OrderRow + + assert OrderRow.__tablename__ == "orders" + + def test_columns(self): + from shared.sa_models import OrderRow + + mapper = inspect(OrderRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "signal_id", + "symbol", + "side", + "type", + "price", + "quantity", + "status", + "created_at", + "filled_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import OrderRow + + mapper = inspect(OrderRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_signal_id_foreign_key(self): + from shared.sa_models import OrderRow + + table = OrderRow.__table__ + fk_cols = { + fk.parent.name: fk.target_fullname for fk in table.foreign_keys + } + assert fk_cols == {"signal_id": "signals.id"} + + +class TestTradeRow: + def test_table_name(self): + from shared.sa_models import TradeRow + + assert TradeRow.__tablename__ == "trades" + + def test_columns(self): + from shared.sa_models import TradeRow + + mapper = inspect(TradeRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "order_id", + "symbol", + "side", + "price", + "quantity", + "fee", + "traded_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import TradeRow + + mapper = inspect(TradeRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_order_id_foreign_key(self): + from shared.sa_models import TradeRow + + table = TradeRow.__table__ + fk_cols = { + fk.parent.name: fk.target_fullname for fk in table.foreign_keys + } + assert fk_cols == {"order_id": "orders.id"} + + +class TestPositionRow: + def test_table_name(self): + from shared.sa_models import PositionRow + + assert PositionRow.__tablename__ == "positions" + + def test_columns(self): + from shared.sa_models import PositionRow + + mapper = inspect(PositionRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "symbol", + "quantity", + "avg_entry_price", + "current_price", + "updated_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import PositionRow + + mapper = inspect(PositionRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["symbol"] + + +class TestPortfolioSnapshotRow: + def test_table_name(self): + from shared.sa_models import PortfolioSnapshotRow + + assert PortfolioSnapshotRow.__tablename__ == "portfolio_snapshots" + + def test_columns(self): + from shared.sa_models import PortfolioSnapshotRow + + mapper = inspect(PortfolioSnapshotRow) + cols = {c.key for c in mapper.column_attrs} + expected = { + "id", + "total_value", + "realized_pnl", + "unrealized_pnl", + "snapshot_at", + } + assert expected == cols + + def test_primary_key(self): + from shared.sa_models import PortfolioSnapshotRow + + mapper = inspect(PortfolioSnapshotRow) + pk_cols = [c.name for c in mapper.mapper.primary_key] + assert pk_cols == ["id"] + + def test_id_is_autoincrement(self): + from shared.sa_models import PortfolioSnapshotRow + + table = PortfolioSnapshotRow.__table__ + id_col = table.c.id + assert id_col.autoincrement is True or id_col.autoincrement == "auto" + + +class TestStatusDefault: + def test_order_status_server_default(self): + from shared.sa_models import OrderRow + + table = OrderRow.__table__ + status_col = table.c.status + assert status_col.server_default is not None + assert status_col.server_default.arg == "PENDING" + + def test_trade_fee_server_default(self): + from shared.sa_models import TradeRow + + table = TradeRow.__table__ + fee_col = table.c.fee + assert fee_col.server_default is not None + assert fee_col.server_default.arg == "0" |
