summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 17:56:50 +0900
committerTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 17:56:50 +0900
commit13a9b2c80bb3eb1353cf2d49bdbf7d0dbd858ccc (patch)
tree4595b83f1ba4fe5d1bdf4694f53496120956085a
parentfa7e1dc44787592da647bdda0a63310be0cfcc8b (diff)
feat(broker): add Redis consumer groups for reliable message processing
-rw-r--r--services/order-executor/src/order_executor/main.py40
-rw-r--r--services/portfolio-manager/src/portfolio_manager/main.py41
-rw-r--r--shared/src/shared/broker.py66
-rw-r--r--shared/tests/test_broker.py120
4 files changed, 251 insertions, 16 deletions
diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py
index 4a51d5d..930517e 100644
--- a/services/order-executor/src/order_executor/main.py
+++ b/services/order-executor/src/order_executor/main.py
@@ -62,7 +62,8 @@ async def run() -> None:
dry_run=config.dry_run,
)
- last_id = "$"
+ GROUP = "order-executor"
+ CONSUMER = "executor-1"
stream = "signals"
health = HealthCheckServer(
@@ -76,10 +77,35 @@ async def run() -> None:
log.info("service_started", stream=stream, dry_run=config.dry_run)
+ await broker.ensure_group(stream, GROUP)
+
+ # Process pending messages first (from previous crash)
+ pending = await broker.read_pending(stream, GROUP, CONSUMER)
+ for msg_id, msg in pending:
+ try:
+ event = Event.from_dict(msg)
+ if event.type == EventType.SIGNAL:
+ signal = event.data
+ log.info(
+ "processing_pending_signal", signal_id=str(signal.id), symbol=signal.symbol
+ )
+ await executor.execute(signal)
+ metrics.events_processed.labels(
+ service="order-executor", event_type="signal"
+ ).inc()
+ await broker.ack(stream, GROUP, msg_id)
+ except Exception as exc:
+ log.error("pending_process_failed", error=str(exc), msg_id=msg_id)
+ metrics.errors_total.labels(
+ service="order-executor", error_type="processing"
+ ).inc()
+
try:
while True:
- messages = await broker.read(stream, last_id=last_id, count=10, block=5000)
- for msg in messages:
+ messages = await broker.read_group(
+ stream, GROUP, CONSUMER, count=10, block=5000
+ )
+ for msg_id, msg in messages:
try:
event = Event.from_dict(msg)
if event.type == EventType.SIGNAL:
@@ -91,16 +117,12 @@ async def run() -> None:
metrics.events_processed.labels(
service="order-executor", event_type="signal"
).inc()
+ await broker.ack(stream, GROUP, msg_id)
except Exception as exc:
- log.error("message_processing_failed", error=str(exc))
+ log.error("message_processing_failed", error=str(exc), msg_id=msg_id)
metrics.errors_total.labels(
service="order-executor", error_type="processing"
).inc()
- if messages:
- # Advance last_id to avoid re-reading — broker.read returns decoded dicts,
- # so we track progress by re-reading with "0" for replaying or "$" for new only.
- # Since we block on "$" we get only new messages each iteration.
- pass
except Exception as exc:
log.error("fatal_error", error=str(exc))
await notifier.send_error(str(exc), "order-executor")
diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py
index a7f1a14..87e4c64 100644
--- a/services/portfolio-manager/src/portfolio_manager/main.py
+++ b/services/portfolio-manager/src/portfolio_manager/main.py
@@ -84,13 +84,43 @@ async def run() -> None:
snapshot_loop(db, tracker, notifier, config.snapshot_interval_hours, log)
)
- last_id = "$"
+ GROUP = "portfolio-manager"
+ CONSUMER = "portfolio-1"
log.info("service_started", stream=ORDERS_STREAM)
+ await broker.ensure_group(ORDERS_STREAM, GROUP)
+
+ # Process pending messages first (from previous crash)
+ pending = await broker.read_pending(ORDERS_STREAM, GROUP, CONSUMER)
+ for msg_id, msg in pending:
+ try:
+ event = Event.from_dict(msg)
+ if isinstance(event, OrderEvent):
+ order = event.data
+ tracker.apply_order(order)
+ log.info(
+ "pending_order_applied",
+ symbol=order.symbol,
+ side=str(order.side),
+ quantity=str(order.quantity),
+ price=str(order.price),
+ )
+ metrics.events_processed.labels(
+ service="portfolio-manager", event_type="order"
+ ).inc()
+ await broker.ack(ORDERS_STREAM, GROUP, msg_id)
+ except Exception as exc:
+ log.error("pending_process_failed", error=str(exc), msg_id=msg_id)
+ metrics.errors_total.labels(
+ service="portfolio-manager", error_type="processing"
+ ).inc()
+
try:
while True:
- messages = await broker.read(ORDERS_STREAM, last_id=last_id, block=1000)
- for msg in messages:
+ messages = await broker.read_group(
+ ORDERS_STREAM, GROUP, CONSUMER, count=10, block=1000
+ )
+ for msg_id, msg in messages:
try:
event = Event.from_dict(msg)
if isinstance(event, OrderEvent):
@@ -108,13 +138,12 @@ async def run() -> None:
metrics.events_processed.labels(
service="portfolio-manager", event_type="order"
).inc()
+ await broker.ack(ORDERS_STREAM, GROUP, msg_id)
except Exception as exc:
- log.exception("message_processing_failed", error=str(exc))
+ log.exception("message_processing_failed", error=str(exc), msg_id=msg_id)
metrics.errors_total.labels(
service="portfolio-manager", error_type="processing"
).inc()
- # Update last_id to the latest processed message id if broker returns ids
- # Since broker.read returns parsed payloads (not ids), we use "$" to get new msgs
except Exception as exc:
log.error("fatal_error", error=str(exc))
await notifier.send_error(str(exc), "portfolio-manager")
diff --git a/shared/src/shared/broker.py b/shared/src/shared/broker.py
index 9c6c4c6..c060c24 100644
--- a/shared/src/shared/broker.py
+++ b/shared/src/shared/broker.py
@@ -17,6 +17,70 @@ class RedisBroker:
payload = json.dumps(data)
await self._redis.xadd(stream, {"payload": payload})
+ async def ensure_group(self, stream: str, group: str) -> None:
+ """Create a consumer group if it doesn't exist."""
+ try:
+ await self._redis.xgroup_create(stream, group, id="0", mkstream=True)
+ except redis.ResponseError as e:
+ if "BUSYGROUP" not in str(e):
+ raise
+
+ async def read_group(
+ self,
+ stream: str,
+ group: str,
+ consumer: str,
+ count: int = 10,
+ block: int = 0,
+ ) -> list[tuple[str, dict[str, Any]]]:
+ """Read messages from a consumer group. Returns list of (message_id, data)."""
+ results = await self._redis.xreadgroup(
+ group, consumer, {stream: ">"}, count=count, block=block
+ )
+ messages = []
+ if results:
+ for _stream, entries in results:
+ for msg_id, fields in entries:
+ payload = fields.get(b"payload") or fields.get("payload")
+ if payload:
+ if isinstance(payload, bytes):
+ payload = payload.decode()
+ if isinstance(msg_id, bytes):
+ msg_id = msg_id.decode()
+ messages.append((msg_id, json.loads(payload)))
+ return messages
+
+ async def ack(self, stream: str, group: str, *msg_ids: str) -> None:
+ """Acknowledge messages in a consumer group."""
+ if msg_ids:
+ await self._redis.xack(stream, group, *msg_ids)
+
+ async def read_pending(
+ self,
+ stream: str,
+ group: str,
+ consumer: str,
+ count: int = 10,
+ ) -> list[tuple[str, dict[str, Any]]]:
+ """Read pending (unacknowledged) messages for this consumer."""
+ results = await self._redis.xreadgroup(
+ group, consumer, {stream: "0"}, count=count
+ )
+ messages = []
+ if results:
+ for _stream, entries in results:
+ for msg_id, fields in entries:
+ if not fields: # Empty fields means already acknowledged
+ continue
+ payload = fields.get(b"payload") or fields.get("payload")
+ if payload:
+ if isinstance(payload, bytes):
+ payload = payload.decode()
+ if isinstance(msg_id, bytes):
+ msg_id = msg_id.decode()
+ messages.append((msg_id, json.loads(payload)))
+ return messages
+
async def read(
self,
stream: str,
@@ -24,7 +88,7 @@ class RedisBroker:
count: int = 10,
block: int = 0,
) -> list[dict[str, Any]]:
- """Read messages from a Redis stream."""
+ """Read messages (original method, kept for backward compatibility)."""
results = await self._redis.xread({stream: last_id}, count=count, block=block)
messages = []
if results:
diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py
index ea8b148..c33f6ec 100644
--- a/shared/tests/test_broker.py
+++ b/shared/tests/test_broker.py
@@ -2,6 +2,7 @@
import pytest
import json
+import redis
from unittest.mock import AsyncMock, patch
@@ -68,3 +69,122 @@ async def test_broker_close():
await broker.close()
mock_redis.aclose.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_broker_ensure_group():
+ """Test that ensure_group creates a consumer group."""
+ from shared.broker import RedisBroker
+
+ mock_redis = AsyncMock()
+ broker = RedisBroker.__new__(RedisBroker)
+ broker._redis = mock_redis
+
+ await broker.ensure_group("test-stream", "test-group")
+ mock_redis.xgroup_create.assert_called_once_with(
+ "test-stream", "test-group", id="0", mkstream=True
+ )
+
+
+@pytest.mark.asyncio
+async def test_broker_ensure_group_already_exists():
+ """Test that ensure_group ignores BUSYGROUP error."""
+ from shared.broker import RedisBroker
+
+ mock_redis = AsyncMock()
+ mock_redis.xgroup_create = AsyncMock(
+ side_effect=redis.ResponseError("BUSYGROUP Consumer Group name already exists")
+ )
+ broker = RedisBroker.__new__(RedisBroker)
+ broker._redis = mock_redis
+
+ # Should not raise
+ await broker.ensure_group("test-stream", "test-group")
+
+
+@pytest.mark.asyncio
+async def test_broker_read_group():
+ """Test that read_group parses xreadgroup response correctly."""
+ from shared.broker import RedisBroker
+
+ mock_redis = AsyncMock()
+ mock_redis.xreadgroup = AsyncMock(
+ return_value=[
+ (b"stream", [(b"1-0", {b"payload": b'{"type": "test"}'})])
+ ]
+ )
+ broker = RedisBroker.__new__(RedisBroker)
+ broker._redis = mock_redis
+
+ messages = await broker.read_group("stream", "group", "consumer")
+ assert len(messages) == 1
+ assert messages[0][0] == "1-0"
+ assert messages[0][1] == {"type": "test"}
+
+
+@pytest.mark.asyncio
+async def test_broker_ack():
+ """Test that ack calls xack on the redis connection."""
+ from shared.broker import RedisBroker
+
+ mock_redis = AsyncMock()
+ broker = RedisBroker.__new__(RedisBroker)
+ broker._redis = mock_redis
+
+ await broker.ack("stream", "group", "1-0", "2-0")
+ mock_redis.xack.assert_called_once_with("stream", "group", "1-0", "2-0")
+
+
+@pytest.mark.asyncio
+async def test_broker_read_pending():
+ """Test that read_pending reads unacknowledged messages."""
+ from shared.broker import RedisBroker
+
+ mock_redis = AsyncMock()
+ mock_redis.xreadgroup = AsyncMock(
+ return_value=[
+ (b"stream", [(b"1-0", {b"payload": b'{"type": "pending"}'})])
+ ]
+ )
+ broker = RedisBroker.__new__(RedisBroker)
+ broker._redis = mock_redis
+
+ messages = await broker.read_pending("stream", "group", "consumer")
+ assert len(messages) == 1
+ assert messages[0][0] == "1-0"
+ assert messages[0][1] == {"type": "pending"}
+ # Verify it uses "0" (not ">") to read pending
+ mock_redis.xreadgroup.assert_called_once_with(
+ "group", "consumer", {"stream": "0"}, count=10
+ )
+
+
+@pytest.mark.asyncio
+async def test_broker_read_pending_skips_empty_fields():
+ """Test that read_pending skips already-acknowledged entries with empty fields."""
+ from shared.broker import RedisBroker
+
+ mock_redis = AsyncMock()
+ mock_redis.xreadgroup = AsyncMock(
+ return_value=[
+ (b"stream", [(b"1-0", {})])
+ ]
+ )
+ broker = RedisBroker.__new__(RedisBroker)
+ broker._redis = mock_redis
+
+ messages = await broker.read_pending("stream", "group", "consumer")
+ assert len(messages) == 0
+
+
+@pytest.mark.asyncio
+async def test_broker_ack_no_ids():
+ """Test that ack does nothing when no message IDs are provided."""
+ from shared.broker import RedisBroker
+
+ mock_redis = AsyncMock()
+ broker = RedisBroker.__new__(RedisBroker)
+ broker._redis = mock_redis
+
+ await broker.ack("stream", "group")
+ mock_redis.xack.assert_not_called()