summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
Diffstat (limited to 'shared')
-rw-r--r--shared/alembic.ini36
-rw-r--r--shared/alembic/env.py66
-rw-r--r--shared/alembic/script.py.mako26
-rw-r--r--shared/alembic/versions/.gitkeep0
-rw-r--r--shared/pyproject.toml7
-rw-r--r--shared/src/shared/sa_models.py93
-rw-r--r--shared/tests/test_sa_models.py244
7 files changed, 472 insertions, 0 deletions
diff --git a/shared/alembic.ini b/shared/alembic.ini
new file mode 100644
index 0000000..2c4fd1f
--- /dev/null
+++ b/shared/alembic.ini
@@ -0,0 +1,36 @@
+[alembic]
+script_location = alembic
+sqlalchemy.url = postgresql+asyncpg://postgres:postgres@localhost:5432/trading
+
+[loggers]
+keys = root,sqlalchemy,alembic
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARN
+handlers = console
+
+[logger_sqlalchemy]
+level = WARN
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S
diff --git a/shared/alembic/env.py b/shared/alembic/env.py
new file mode 100644
index 0000000..14303f6
--- /dev/null
+++ b/shared/alembic/env.py
@@ -0,0 +1,66 @@
+"""Alembic environment configuration for async PostgreSQL migrations."""
+
+import asyncio
+from logging.config import fileConfig
+
+from alembic import context
+from sqlalchemy import pool
+from sqlalchemy.ext.asyncio import async_engine_from_config
+
+from shared.sa_models import Base
+
+config = context.config
+
+if config.config_file_name is not None:
+ fileConfig(config.config_file_name)
+
+target_metadata = Base.metadata
+
+
+def run_migrations_offline() -> None:
+ """Run migrations in 'offline' mode.
+
+ Configures the context with just a URL and not an Engine.
+ Calls to context.execute() here emit the given string to the script output.
+ """
+ url = config.get_main_option("sqlalchemy.url")
+ context.configure(
+ url=url,
+ target_metadata=target_metadata,
+ literal_binds=True,
+ dialect_opts={"paramstyle": "named"},
+ )
+
+ with context.begin_transaction():
+ context.run_migrations()
+
+
+def do_run_migrations(connection):
+ context.configure(connection=connection, target_metadata=target_metadata)
+ with context.begin_transaction():
+ context.run_migrations()
+
+
+async def run_async_migrations() -> None:
+ """Run migrations in 'online' mode using an async engine."""
+ connectable = async_engine_from_config(
+ config.get_section(config.config_ini_section, {}),
+ prefix="sqlalchemy.",
+ poolclass=pool.NullPool,
+ )
+
+ async with connectable.connect() as connection:
+ await connection.run_sync(do_run_migrations)
+
+ await connectable.dispose()
+
+
+def run_migrations_online() -> None:
+ """Run migrations in 'online' mode."""
+ asyncio.run(run_async_migrations())
+
+
+if context.is_offline_mode():
+ run_migrations_offline()
+else:
+ run_migrations_online()
diff --git a/shared/alembic/script.py.mako b/shared/alembic/script.py.mako
new file mode 100644
index 0000000..fbc4b07
--- /dev/null
+++ b/shared/alembic/script.py.mako
@@ -0,0 +1,26 @@
+"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision: str = ${repr(up_revision)}
+down_revision: Union[str, None] = ${repr(down_revision)}
+branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
+depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
+
+
+def upgrade() -> None:
+ ${upgrades if upgrades else "pass"}
+
+
+def downgrade() -> None:
+ ${downgrades if downgrades else "pass"}
diff --git a/shared/alembic/versions/.gitkeep b/shared/alembic/versions/.gitkeep
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/shared/alembic/versions/.gitkeep
diff --git a/shared/pyproject.toml b/shared/pyproject.toml
index bd09d3e..c36f00b 100644
--- a/shared/pyproject.toml
+++ b/shared/pyproject.toml
@@ -8,6 +8,13 @@ dependencies = [
"pydantic-settings>=2.0",
"redis>=5.0",
"asyncpg>=0.29",
+ "sqlalchemy[asyncio]>=2.0",
+ "alembic>=1.13",
+ "structlog>=24.0",
+ "prometheus-client>=0.20",
+ "pyyaml>=6.0",
+ "aiohttp>=3.9",
+ "rich>=13.0",
]
[project.optional-dependencies]
diff --git a/shared/src/shared/sa_models.py b/shared/src/shared/sa_models.py
new file mode 100644
index 0000000..0537846
--- /dev/null
+++ b/shared/src/shared/sa_models.py
@@ -0,0 +1,93 @@
+"""SQLAlchemy 2.0 ORM models mirroring the existing asyncpg tables."""
+
+from datetime import datetime
+from decimal import Decimal
+
+from sqlalchemy import DateTime, ForeignKey, Numeric, Text
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+
+
+class Base(DeclarativeBase):
+ pass
+
+
+class CandleRow(Base):
+ __tablename__ = "candles"
+
+ symbol: Mapped[str] = mapped_column(Text, primary_key=True)
+ timeframe: Mapped[str] = mapped_column(Text, primary_key=True)
+ open_time: Mapped[datetime] = mapped_column(DateTime(timezone=True), primary_key=True)
+ open: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ high: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ low: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ close: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ volume: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+
+
+class SignalRow(Base):
+ __tablename__ = "signals"
+
+ id: Mapped[str] = mapped_column(Text, primary_key=True)
+ strategy: Mapped[str] = mapped_column(Text, nullable=False)
+ symbol: Mapped[str] = mapped_column(Text, nullable=False)
+ side: Mapped[str] = mapped_column(Text, nullable=False)
+ price: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ reason: Mapped[str | None] = mapped_column(Text)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
+
+
+class OrderRow(Base):
+ __tablename__ = "orders"
+
+ id: Mapped[str] = mapped_column(Text, primary_key=True)
+ signal_id: Mapped[str | None] = mapped_column(
+ Text, ForeignKey("signals.id")
+ )
+ symbol: Mapped[str] = mapped_column(Text, nullable=False)
+ side: Mapped[str] = mapped_column(Text, nullable=False)
+ type: Mapped[str] = mapped_column(Text, nullable=False)
+ price: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ status: Mapped[str] = mapped_column(
+ Text, nullable=False, server_default="PENDING"
+ )
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
+ filled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
+
+
+class TradeRow(Base):
+ __tablename__ = "trades"
+
+ id: Mapped[str] = mapped_column(Text, primary_key=True)
+ order_id: Mapped[str | None] = mapped_column(
+ Text, ForeignKey("orders.id")
+ )
+ symbol: Mapped[str] = mapped_column(Text, nullable=False)
+ side: Mapped[str] = mapped_column(Text, nullable=False)
+ price: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ fee: Mapped[Decimal] = mapped_column(
+ Numeric, nullable=False, server_default="0"
+ )
+ traded_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
+
+
+class PositionRow(Base):
+ __tablename__ = "positions"
+
+ symbol: Mapped[str] = mapped_column(Text, primary_key=True)
+ quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ avg_entry_price: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ current_price: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
+
+
+class PortfolioSnapshotRow(Base):
+ __tablename__ = "portfolio_snapshots"
+
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+ total_value: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ realized_pnl: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ unrealized_pnl: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
+ snapshot_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py
new file mode 100644
index 0000000..de994c5
--- /dev/null
+++ b/shared/tests/test_sa_models.py
@@ -0,0 +1,244 @@
+"""Tests for SQLAlchemy ORM models."""
+
+import pytest
+from sqlalchemy import inspect
+
+
+def test_base_metadata_has_all_tables():
+ from shared.sa_models import Base
+
+ table_names = set(Base.metadata.tables.keys())
+ expected = {
+ "candles",
+ "signals",
+ "orders",
+ "trades",
+ "positions",
+ "portfolio_snapshots",
+ }
+ assert expected == table_names
+
+
+class TestCandleRow:
+ def test_table_name(self):
+ from shared.sa_models import CandleRow
+
+ assert CandleRow.__tablename__ == "candles"
+
+ def test_columns(self):
+ from shared.sa_models import CandleRow
+
+ mapper = inspect(CandleRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "symbol",
+ "timeframe",
+ "open_time",
+ "open",
+ "high",
+ "low",
+ "close",
+ "volume",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import CandleRow
+
+ mapper = inspect(CandleRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["symbol", "timeframe", "open_time"]
+
+
+class TestSignalRow:
+ def test_table_name(self):
+ from shared.sa_models import SignalRow
+
+ assert SignalRow.__tablename__ == "signals"
+
+ def test_columns(self):
+ from shared.sa_models import SignalRow
+
+ mapper = inspect(SignalRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "id",
+ "strategy",
+ "symbol",
+ "side",
+ "price",
+ "quantity",
+ "reason",
+ "created_at",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import SignalRow
+
+ mapper = inspect(SignalRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["id"]
+
+
+class TestOrderRow:
+ def test_table_name(self):
+ from shared.sa_models import OrderRow
+
+ assert OrderRow.__tablename__ == "orders"
+
+ def test_columns(self):
+ from shared.sa_models import OrderRow
+
+ mapper = inspect(OrderRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "id",
+ "signal_id",
+ "symbol",
+ "side",
+ "type",
+ "price",
+ "quantity",
+ "status",
+ "created_at",
+ "filled_at",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import OrderRow
+
+ mapper = inspect(OrderRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["id"]
+
+ def test_signal_id_foreign_key(self):
+ from shared.sa_models import OrderRow
+
+ table = OrderRow.__table__
+ fk_cols = {
+ fk.parent.name: fk.target_fullname for fk in table.foreign_keys
+ }
+ assert fk_cols == {"signal_id": "signals.id"}
+
+
+class TestTradeRow:
+ def test_table_name(self):
+ from shared.sa_models import TradeRow
+
+ assert TradeRow.__tablename__ == "trades"
+
+ def test_columns(self):
+ from shared.sa_models import TradeRow
+
+ mapper = inspect(TradeRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "id",
+ "order_id",
+ "symbol",
+ "side",
+ "price",
+ "quantity",
+ "fee",
+ "traded_at",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import TradeRow
+
+ mapper = inspect(TradeRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["id"]
+
+ def test_order_id_foreign_key(self):
+ from shared.sa_models import TradeRow
+
+ table = TradeRow.__table__
+ fk_cols = {
+ fk.parent.name: fk.target_fullname for fk in table.foreign_keys
+ }
+ assert fk_cols == {"order_id": "orders.id"}
+
+
+class TestPositionRow:
+ def test_table_name(self):
+ from shared.sa_models import PositionRow
+
+ assert PositionRow.__tablename__ == "positions"
+
+ def test_columns(self):
+ from shared.sa_models import PositionRow
+
+ mapper = inspect(PositionRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "symbol",
+ "quantity",
+ "avg_entry_price",
+ "current_price",
+ "updated_at",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import PositionRow
+
+ mapper = inspect(PositionRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["symbol"]
+
+
+class TestPortfolioSnapshotRow:
+ def test_table_name(self):
+ from shared.sa_models import PortfolioSnapshotRow
+
+ assert PortfolioSnapshotRow.__tablename__ == "portfolio_snapshots"
+
+ def test_columns(self):
+ from shared.sa_models import PortfolioSnapshotRow
+
+ mapper = inspect(PortfolioSnapshotRow)
+ cols = {c.key for c in mapper.column_attrs}
+ expected = {
+ "id",
+ "total_value",
+ "realized_pnl",
+ "unrealized_pnl",
+ "snapshot_at",
+ }
+ assert expected == cols
+
+ def test_primary_key(self):
+ from shared.sa_models import PortfolioSnapshotRow
+
+ mapper = inspect(PortfolioSnapshotRow)
+ pk_cols = [c.name for c in mapper.mapper.primary_key]
+ assert pk_cols == ["id"]
+
+ def test_id_is_autoincrement(self):
+ from shared.sa_models import PortfolioSnapshotRow
+
+ table = PortfolioSnapshotRow.__table__
+ id_col = table.c.id
+ assert id_col.autoincrement is True or id_col.autoincrement == "auto"
+
+
+class TestStatusDefault:
+ def test_order_status_server_default(self):
+ from shared.sa_models import OrderRow
+
+ table = OrderRow.__table__
+ status_col = table.c.status
+ assert status_col.server_default is not None
+ assert status_col.server_default.arg == "PENDING"
+
+ def test_trade_fee_server_default(self):
+ from shared.sa_models import TradeRow
+
+ table = TradeRow.__table__
+ fee_col = table.c.fee
+ assert fee_col.server_default is not None
+ assert fee_col.server_default.arg == "0"