# Saga Orchestrator with State Machine (Python)
import enum
from dataclasses import dataclass, field
class SagaState(enum.Enum):
    """States a saga execution can be in.

    Every member's value is its own name as a string; that string is
    what ends up in the persisted record.
    """

    # Forward path, listed in the order the orchestrator's transition
    # table moves through them.
    STARTED = "STARTED"
    PAYMENT_PENDING = "PAYMENT_PENDING"
    PAYMENT_CONFIRMED = "PAYMENT_CONFIRMED"
    INVENTORY_RESERVED = "INVENTORY_RESERVED"
    SHIPPING_INITIATED = "SHIPPING_INITIATED"
    COMPLETED = "COMPLETED"

    # Failure path: rollback in progress, then the final failed state.
    COMPENSATING = "COMPENSATING"
    FAILED = "FAILED"
@dataclass
class SagaExecution:
    """Mutable record of a single saga run.

    The orchestrator advances ``state`` and appends to
    ``completed_steps`` as it goes, calling :meth:`persist` after each
    change.
    """

    # Key under which the record is saved; also the prefix of every
    # idempotency key the orchestrator builds.
    saga_id: str
    # Passed as the first argument to every service call.
    order_id: str
    # Current position in the state machine.
    state: SagaState = SagaState.STARTED
    # Names of the actions that have succeeded so far, oldest first.
    completed_steps: list = field(default_factory=list)

    def persist(self, store):
        """Save a snapshot of this execution under ``saga_id``.

        Args:
            store: Object exposing ``save(key, record)``.

        The record carries ``order_id`` as well as the progress fields:
        without it a saga reloaded after a crash cannot be resumed,
        because every service call needs the order id.
        """
        store.save(self.saga_id, {
            "order_id": self.order_id,
            "state": self.state.value,
            # Copy the list so a store that keeps the dict by reference
            # (e.g. an in-memory one) does not see later appends.
            "completed_steps": list(self.completed_steps),
        })
class SagaOrchestrator:
    """Drive a :class:`SagaExecution` through ``TRANSITIONS``.

    Each forward action is looked up by name on ``self.services`` and
    awaited with the order id and an idempotency key of the form
    ``"<saga_id>:<action>"``. On any failure the steps completed so far
    are undone in reverse order via ``COMPENSATIONS``.

    ``self.services`` and ``self.store`` must be set before calling
    :meth:`execute`, either through the constructor or by assignment.
    """

    # current state -> (service method to call, state entered on success)
    TRANSITIONS = {
        SagaState.STARTED: ("process_payment", SagaState.PAYMENT_PENDING),
        SagaState.PAYMENT_PENDING: ("confirm_payment", SagaState.PAYMENT_CONFIRMED),
        SagaState.PAYMENT_CONFIRMED: ("reserve_inventory", SagaState.INVENTORY_RESERVED),
        SagaState.INVENTORY_RESERVED: ("initiate_shipping", SagaState.SHIPPING_INITIATED),
        SagaState.SHIPPING_INITIATED: ("finalize", SagaState.COMPLETED),
    }

    # forward action -> service method that undoes it
    # ("finalize" has no entry and is skipped during compensation)
    COMPENSATIONS = {
        "initiate_shipping": "cancel_shipment",
        "reserve_inventory": "release_inventory",
        "confirm_payment": "refund_payment",
        "process_payment": "void_authorization",
    }

    def __init__(self, services=None, store=None):
        """Optionally bind the collaborators.

        Args:
            services: Object exposing the async action and compensation
                methods named in ``TRANSITIONS`` / ``COMPENSATIONS``.
            store: Object exposing ``save(...)`` and an awaitable
                ``send_to_dlq(...)``.

        Arguments left as ``None`` are not assigned, so attributes
        supplied by a subclass or set after construction are not
        overwritten.
        """
        if services is not None:
            self.services = services
        if store is not None:
            self.store = store

    async def execute(self, saga: SagaExecution):
        """Run the saga forward from its current state.

        Returns:
            The same ``saga`` object, in ``COMPLETED`` on success or
            ``FAILED`` after compensation.
        """
        # A saga reloaded while a rollback was in flight must finish
        # that rollback; otherwise it would stay COMPENSATING forever.
        # Re-running compensations reuses the same idempotency keys.
        if saga.state is SagaState.COMPENSATING:
            await self._compensate(saga)
            return saga

        while saga.state in self.TRANSITIONS:
            action, next_state = self.TRANSITIONS[saga.state]
            try:
                await getattr(self.services, action)(
                    saga.order_id,
                    idempotency_key=f"{saga.saga_id}:{action}",
                )
                saga.completed_steps.append(action)
                saga.state = next_state
                saga.persist(self.store)
            except Exception:
                # Function-scope import: the surrounding import block
                # does not bring in logging.
                import logging

                # Keep the cause; the caller only sees the FAILED state.
                logging.getLogger(__name__).exception(
                    "[%s] step %s failed; compensating",
                    saga.saga_id, action,
                )
                saga.state = SagaState.COMPENSATING
                saga.persist(self.store)
                await self._compensate(saga)
                return saga
        return saga

    async def _compensate(self, saga: SagaExecution):
        """Undo completed steps newest-first, then mark the saga FAILED.

        A compensation that raises is handed to ``store.send_to_dlq``
        and the walk continues with the next step.
        """
        for step in reversed(saga.completed_steps):
            comp = self.COMPENSATIONS.get(step)
            if comp:
                try:
                    await getattr(self.services, comp)(
                        saga.order_id,
                        idempotency_key=f"{saga.saga_id}:{comp}",
                    )
                except Exception:
                    await self.store.send_to_dlq(saga.saga_id, comp)
        saga.state = SagaState.FAILED
        saga.persist(self.store)
# Compensating Transaction Handler (Python)
import asyncio, logging, time
from dataclasses import dataclass
from typing import Callable, Awaitable
@dataclass
class CompensationStep:
    """One compensation handler together with its retry settings."""

    # Label used in log lines, idempotency keys and DLQ messages.
    name: str
    # Awaitable handler that performs the compensation.
    execute: Callable[..., Awaitable]
    # Number of attempts made before the step is dead-lettered.
    max_retries: int = 3
    # Backoff base in seconds; doubled after each failed attempt.
    base_delay: float = 1.0
class CompensationChain:
    """Registered compensation steps, run in reverse with retries.

    Steps are executed last-registered-first. A step that still fails
    after ``max_retries`` attempts is published to the dead letter
    queue client, and the chain moves on to the next step.
    """

    def __init__(self, dlq_client):
        """
        Args:
            dlq_client: Object exposing an awaitable ``publish(message)``.
        """
        self.steps = []
        self.dlq = dlq_client
        self.log = logging.getLogger("compensation")

    def add(self, name, handler, max_retries=3, base_delay=1.0):
        """Register a step and return ``self`` so calls can be chained.

        Args:
            name: Label used in logs, idempotency keys and DLQ messages.
            handler: Awaitable called as
                ``handler(context=..., idempotency_key=...)``.
            max_retries: Total attempts before dead-lettering.
            base_delay: Backoff base in seconds.
        """
        self.steps.append(
            CompensationStep(name, handler, max_retries, base_delay)
        )
        return self

    async def execute(self, saga_id: str, context: dict):
        """Run every registered step, newest first."""
        for step in reversed(self.steps):
            await self._retry_step(saga_id, step, context)

    async def _retry_step(self, saga_id, step, context):
        """Attempt one step up to ``step.max_retries`` times, then DLQ.

        With ``max_retries <= 0`` the handler is never called and the
        step goes straight to the dead letter queue.
        """
        import random  # function-scope: not in the file's import block

        # One key for ALL attempts. A key that changes per attempt makes
        # every retry look like a new request downstream, so a retry
        # after a timed-out-but-successful call would compensate twice.
        idem_key = f"{saga_id}:comp:{step.name}"

        for attempt in range(step.max_retries):
            try:
                await step.execute(context=context, idempotency_key=idem_key)
            except Exception as e:
                self.log.warning(
                    "[%s] %s retry %d: %s",
                    saga_id, step.name, attempt + 1, e,
                )
                # Back off only when another attempt follows; sleeping
                # after the last failure just delays the DLQ hand-off.
                if attempt + 1 < step.max_retries:
                    # Real randomness: hash() of a fixed key is constant
                    # within a process, so it does not spread retries.
                    jitter = random.uniform(0.7, 1.3)
                    delay = min(
                        step.base_delay * (2 ** attempt) * jitter, 30.0
                    )
                    await asyncio.sleep(delay)
            else:
                self.log.info("[%s] %s compensated", saga_id, step.name)
                return

        # Exhausted retries -- dead letter queue
        self.log.critical("[%s] %s -> DLQ", saga_id, step.name)
        await self.dlq.publish({
            "saga_id": saga_id, "step": step.name,
            "context": context, "failed_at": time.time(),
        })
# Usage
chain = CompensationChain(dlq_client=redis_dlq)
chain.add("cancel_shipment", shipping_svc.cancel)
chain.add("release_inventory", inventory_svc.release)
chain.add("refund_payment", payment_svc.refund)
await chain.execute("SAGA-12345", {"order_id": "ORD-2024-56789"})
# Distributed Tracing with OpenTelemetry (Python)
from opentelemetry import trace
from opentelemetry.trace import StatusCode
from opentelemetry.propagate import inject
import time
# Module-level tracer used for every span started in this section;
# "saga.orchestrator" is the name handed to OpenTelemetry's get_tracer.
tracer = trace.get_tracer("saga.orchestrator")
class TracedSagaOrchestrator:
    """Run saga steps in order, wrapping each in an OpenTelemetry span.

    Each step is a dict read with these keys: ``"name"``, ``"service"``,
    ``"handler"`` (awaitable), optional ``"deadline_ms"`` and optional
    ``"compensate"`` (awaitable).
    """

    async def execute_saga(self, order_id, steps):
        """Execute ``steps`` under one parent span.

        Returns:
            ``{"status": "completed", "trace_id": ...}`` when every step
            succeeds. When a step raises, the steps completed before it
            are compensated and the result is
            ``{"status": "compensated", "trace_id": ...,
            "failed_compensations": [...]}``, where the list names the
            steps whose compensation itself raised (empty when the
            rollback was clean).
        """
        with tracer.start_as_current_span(
            "saga.execute",
            attributes={"saga.order_id": order_id, "saga.steps": len(steps)},
        ) as parent:
            trace_id = format(parent.get_span_context().trace_id, "032x")
            completed = []
            for step in steps:
                try:
                    await self._traced_step(step, order_id)
                except Exception as e:
                    parent.record_exception(e)
                    parent.set_status(StatusCode.ERROR, str(e))
                    parent.set_attribute("saga.failed_at", step["name"])
                    # ``or []`` tolerates an override that returns None.
                    failed = (
                        await self._compensate_traced(completed, order_id)
                    ) or []
                    parent.set_attribute("saga.status", "compensated")
                    parent.set_attribute(
                        "saga.compensation_failures", len(failed)
                    )
                    return {
                        "status": "compensated",
                        "trace_id": trace_id,
                        # Surface rollback failures instead of
                        # reporting a clean compensation regardless.
                        "failed_compensations": failed,
                    }
                completed.append(step)
            parent.set_attribute("saga.status", "completed")
            return {"status": "completed", "trace_id": trace_id}

    async def _traced_step(self, step, order_id):
        """Run one step's handler inside its own span; return its result."""
        with tracer.start_as_current_span(
            f"saga.step.{step['name']}",
            attributes={"service.name": step["service"]},
        ) as span:
            headers = {}
            inject(headers)  # propagate trace context across boundary
            start = time.monotonic()
            try:
                return await step["handler"](
                    order_id=order_id, headers=headers,
                    deadline_ms=step.get("deadline_ms"),
                )
            finally:
                # Record the duration for failed steps as well.
                span.set_attribute(
                    "duration_ms", (time.monotonic() - start) * 1000
                )

    async def _compensate_traced(self, completed, order_id):
        """Compensate ``completed`` steps newest-first, one span each.

        A step without a ``"compensate"`` entry is skipped. A
        compensation that raises is marked on its span (error status,
        ``status="dlq"``) and the loop continues. Nothing is published
        to a queue here; "dlq" is only a span label.

        Returns:
            Names of the steps whose compensation raised.
        """
        failed = []
        with tracer.start_as_current_span("saga.compensate"):
            for step in reversed(completed):
                compensate = step.get("compensate")
                if compensate is None:
                    # Nothing to undo; not a compensation failure.
                    continue
                with tracer.start_as_current_span(
                    f"saga.compensate.{step['name']}"
                ) as comp_span:
                    try:
                        await compensate(order_id=order_id)
                    except Exception as e:
                        comp_span.record_exception(e)
                        comp_span.set_status(StatusCode.ERROR, str(e))
                        comp_span.set_attribute("status", "dlq")
                        failed.append(step["name"])
                    else:
                        comp_span.set_attribute("status", "success")
        return failed