Unit of work. ORM data mapper no longer needed - Noloquideus/fastapi-template GitHub Wiki
Unit of Work Pattern
A comprehensive guide to implementing the Unit of Work pattern for managing database transactions and coordinating changes in your FastAPI Clean Architecture application.
Table of Contents
- What is Unit of Work
- Why Use Unit of Work
- How It Works
- Related Patterns
- Implementation in FastAPI
- Advanced Usage
- Best Practices
- Testing
- Common Pitfalls
What is Unit of Work
The Unit of Work pattern is designed to track changes to objects and then coordinate saving them to the database in a single transaction. It maintains a list of objects affected by a business transaction and coordinates writing out changes.
Core Concept
Instead of immediately persisting changes to the database when they occur, Unit of Work:
- Registers changes (new, modified, deleted objects)
- Coordinates persistence when explicitly committed
- Manages transaction boundaries efficiently
# Without Unit of Work - every change is committed in its own transaction
user = User(name="John")
session.add(user)
session.commit()  # First transaction: INSERT user
order = Order(user_id=user.id)
session.add(order)
session.commit()  # Second transaction: INSERT order

# With Unit of Work - coordinated persistence
with unit_of_work:
    user = unit_of_work.users.create(name="John")  # Registered, not persisted
    order = unit_of_work.orders.create(user=user)  # Registered, not persisted
# On exit: both inserts are written in a single transaction
Why Use Unit of Work
1. Transaction Lifetime Management
Controls exactly when database transactions begin and end, preventing long-running transactions.
from decimal import Decimal

async def transfer_funds(from_account_id: int, to_account_id: int, amount: Decimal):
    async with unit_of_work:
        # Short, well-defined transaction scope
        from_account = await unit_of_work.accounts.get(from_account_id)
        to_account = await unit_of_work.accounts.get(to_account_id)
        from_account.withdraw(amount)
        to_account.deposit(amount)
    # All changes committed together or not at all
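The all-or-nothing behaviour can be observed without any ORM. The following is a minimal sketch using the standard-library sqlite3 module; the accounts table, the CHECK constraint and the transfer() helper are assumptions made for illustration and are not part of the template.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE accounts ("
    "id INTEGER PRIMARY KEY, "
    "balance INTEGER NOT NULL CHECK (balance >= 0))"
)
conn.executemany("INSERT INTO accounts VALUES (?, ?)", [(1, 100), (2, 50)])
conn.commit()


def transfer(conn, from_id, to_id, amount):
    # `with conn` commits if the block succeeds and rolls back if it raises
    with conn:
        conn.execute(
            "UPDATE accounts SET balance = balance + ? WHERE id = ?",
            (amount, to_id),
        )
        conn.execute(
            "UPDATE accounts SET balance = balance - ? WHERE id = ?",
            (amount, from_id),
        )


def balances(conn):
    return dict(conn.execute("SELECT id, balance FROM accounts ORDER BY id"))


transfer(conn, 1, 2, 30)
after_success = balances(conn)

try:
    # The deposit succeeds, then the withdrawal violates the CHECK constraint
    transfer(conn, 1, 2, 500)
except sqlite3.IntegrityError:
    pass
after_failure = balances(conn)
```

The failed transfer leaves both balances exactly as they were after the successful one: the deposit that had already been applied is undone together with the rejected withdrawal.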
2. Performance Optimization
Batches database operations and can optimize queries:
# Without UoW - one commit (and one transaction) per user
for user_data in user_list:
    user = User(**user_data)
    session.add(user)
    session.commit()  # One commit per user

# With UoW - single batch operation
async with unit_of_work:
    for user_data in user_list:
        unit_of_work.users.register_new(User(**user_data))
# Single commit; the inserts can be sent as a batch
3. Change Tracking
Automatically tracks what needs to be persisted:
async with unit_of_work:
    # UoW tracks all these changes
    user = await unit_of_work.users.get(user_id)
    user.email = "user@example.com"  # Tracked as dirty
    order = Order(user_id=user.id)
    unit_of_work.orders.register_new(order)  # Tracked as new
    old_product = await unit_of_work.products.get(product_id)
    unit_of_work.products.register_deleted(old_product)  # Tracked as deleted
How It Works
Two-Phase Process
Phase 1: Registration
Objects register their changes with the Unit of Work:
- register_new() - For new objects to be inserted
- register_dirty() - For modified objects to be updated
- register_deleted() - For objects to be deleted
Phase 2: Commitment
All registered changes are persisted in a single transaction:
- commit() - Saves all changes to database
- rollback() - Discards all changes if error occurs
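The two phases can be sketched with a toy, in-memory Unit of Work. This is an illustration under stated assumptions, not the template's implementation: InMemoryUnitOfWork is a hypothetical class, a plain dict stands in for the database, and objects are identified by string keys.

```python
class InMemoryUnitOfWork:
    """Toy Unit of Work: phase 1 registers changes, phase 2 applies them."""

    def __init__(self, storage: dict):
        self.storage = storage  # stands in for the database
        self._new = {}
        self._dirty = {}
        self._deleted = set()

    # Phase 1: registration. Nothing touches the storage yet.
    def register_new(self, key, value):
        self._new[key] = value

    def register_dirty(self, key, value):
        if key in self._new:
            self._new[key] = value  # still unsaved, so just update the pending insert
        else:
            self._dirty[key] = value

    def register_deleted(self, key):
        self._dirty.pop(key, None)
        if key in self._new:
            del self._new[key]  # never persisted, so there is nothing to delete
        else:
            self._deleted.add(key)

    # Phase 2: commitment. All registered changes are applied together.
    def commit(self):
        staged = dict(self.storage)
        staged.update(self._new)
        staged.update(self._dirty)
        for key in self._deleted:
            staged.pop(key, None)
        self.storage.clear()
        self.storage.update(staged)
        self.rollback()  # reset the tracking state

    def rollback(self):
        self._new.clear()
        self._dirty.clear()
        self._deleted.clear()


db = {"order:1": "pending"}
uow = InMemoryUnitOfWork(db)
uow.register_new("user:1", "John")
uow.register_dirty("order:1", "paid")
snapshot_before_commit = dict(db)  # registration alone changes nothing
uow.commit()
```

Until commit() runs, the storage is untouched; calling rollback() instead would simply forget the registered changes.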
Workflow Diagram
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
│ Application │ │ Unit of Work │ │ Database │
└─────────────────┘ └──────────────────┘ └─────────────────┘
│ │ │
│ register_new(user) │ │
│──────────────────────>│ │
│ │ │
│ register_dirty(order) │ │
│──────────────────────>│ │
│ │ │
│ commit() │ │
│──────────────────────>│ │
│ │ BEGIN TRANSACTION │
│ │──────────────────────>│
│ │ INSERT user │
│ │──────────────────────>│
│ │ UPDATE order │
│ │──────────────────────>│
│ │ COMMIT │
│ │──────────────────────>│
Related Patterns
Data Mapper Pattern
Unit of Work doesn't access the database directly. Instead, it delegates to Data Mapper objects:
class UserMapper:
    def insert(self, user: User) -> None:
        ...  # Handle user insertion logic

    def update(self, user: User) -> None:
        ...  # Handle user update logic

    def delete(self, user: User) -> None:
        ...  # Handle user deletion logic

class UnitOfWork:
    def __init__(self):
        self.user_mapper = UserMapper()
        self.new_objects = []
        self.dirty_objects = []
        self.deleted_objects = []

    def commit(self):
        for obj in self.new_objects:
            self.user_mapper.insert(obj)
        for obj in self.dirty_objects:
            self.user_mapper.update(obj)
        for obj in self.deleted_objects:
            self.user_mapper.delete(obj)
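To make the delegation concrete, here is a runnable sketch in which the mapper owns the SQL and the Unit of Work only decides when it runs. SqliteUserMapper, MapperBackedUnitOfWork and the users table are hypothetical names chosen for this example; the standard-library sqlite3 module stands in for the real database.

```python
import sqlite3
from dataclasses import dataclass


@dataclass
class User:
    id: int
    name: str


class SqliteUserMapper:
    """Data Mapper: the only place that knows the users table."""

    def __init__(self, conn):
        self.conn = conn

    def insert(self, user):
        self.conn.execute(
            "INSERT INTO users (id, name) VALUES (?, ?)", (user.id, user.name)
        )

    def update(self, user):
        self.conn.execute(
            "UPDATE users SET name = ? WHERE id = ?", (user.name, user.id)
        )

    def delete(self, user):
        self.conn.execute("DELETE FROM users WHERE id = ?", (user.id,))


class MapperBackedUnitOfWork:
    def __init__(self, conn):
        self.conn = conn
        self.user_mapper = SqliteUserMapper(conn)
        self.new_objects, self.dirty_objects, self.deleted_objects = [], [], []

    def commit(self):
        try:
            for obj in self.new_objects:
                self.user_mapper.insert(obj)
            for obj in self.dirty_objects:
                self.user_mapper.update(obj)
            for obj in self.deleted_objects:
                self.user_mapper.delete(obj)
            self.conn.commit()
        except Exception:
            self.conn.rollback()  # undo any statements already executed
            raise
        finally:
            for tracked in (self.new_objects, self.dirty_objects, self.deleted_objects):
                tracked.clear()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")
conn.executemany("INSERT INTO users VALUES (?, ?)", [(1, "Ann"), (2, "Bob")])
conn.commit()

uow = MapperBackedUnitOfWork(conn)
uow.new_objects.append(User(3, "Cy"))
uow.dirty_objects.append(User(1, "Anna"))
uow.deleted_objects.append(User(2, "Bob"))
uow.commit()
rows = conn.execute("SELECT id, name FROM users ORDER BY id").fetchall()
```

Swapping the mapper for one that targets a different store would not require any change to the Unit of Work itself.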
Identity Map Pattern
Often used together with Unit of Work to ensure object identity:
class IdentityMap:
    def __init__(self):
        self._objects = {}

    def get(self, object_id: int, object_type: type):
        key = (object_type, object_id)
        return self._objects.get(key)

    def add(self, obj, object_id: int):
        key = (type(obj), object_id)
        self._objects[key] = obj
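A short usage sketch shows what "object identity" means in practice: loading the same row twice yields the same Python object and only one lookup. The get_user() helper, FAKE_TABLE and query_log are assumptions for illustration; the IdentityMap class is repeated so the example runs on its own.

```python
class User:
    def __init__(self, user_id, name):
        self.id = user_id
        self.name = name


class IdentityMap:
    def __init__(self):
        self._objects = {}

    def get(self, object_id, object_type):
        return self._objects.get((object_type, object_id))

    def add(self, obj, object_id):
        self._objects[(type(obj), object_id)] = obj


FAKE_TABLE = {1: "John"}  # stands in for a database table
query_log = []            # records every simulated SELECT


def get_user(identity_map, user_id):
    cached = identity_map.get(user_id, User)
    if cached is not None:
        return cached  # same instance, no second query
    query_log.append(user_id)  # pretend this is a SELECT
    user = User(user_id, FAKE_TABLE[user_id])
    identity_map.add(user, user_id)
    return user


identity_map = IdentityMap()
first = get_user(identity_map, 1)
second = get_user(identity_map, 1)
first.name = "Johnny"  # visible through both references
```

Because both variables point at one instance, a change made through one reference cannot be silently overwritten by a stale copy held elsewhere in the same unit of work.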
Implementation in FastAPI
Abstract Interface
from abc import ABC, abstractmethod
from sqlalchemy.ext.asyncio import AsyncSession

class IUnitOfWork(ABC):
    """Abstract Unit of Work interface."""

    # DAO interfaces for different entities
    organization_dao: IOrganizationDao = None
    waste_type_dao: IWasteTypeDao = None
    storage_dao: IStorageDao = None
    waste_transfer_dao: IWasteTransferDao = None

    def __init__(self):
        self.session_factory = async_session_maker

    @abstractmethod
    async def __aenter__(self):
        """Async context manager entry."""
        raise NotImplementedError

    @abstractmethod
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        raise NotImplementedError

    @abstractmethod
    async def commit(self):
        """Commit all registered changes."""
        raise NotImplementedError

    @abstractmethod
    async def rollback(self):
        """Rollback all registered changes."""
        raise NotImplementedError
Concrete Implementation
from src.infrastructure.logger import logger

class UnitOfWork(IUnitOfWork):
    """Concrete Unit of Work implementation using SQLAlchemy."""

    async def __aenter__(self):
        """Initialize session and DAOs."""
        logger.debug('Unit of work created')
        self.__session: AsyncSession = self.session_factory()
        # Initialize all DAOs with the same session
        self.organization_dao = OrganizationDao(self.__session)
        self.waste_type_dao = WasteTypeDao(self.__session)
        self.storage_dao = StorageDao(self.__session)
        self.waste_transfer_dao = WasteTransferDao(self.__session)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Handle transaction completion and cleanup."""
        logger.debug('Unit of work closing')
        try:
            if exc_type:
                logger.debug(f'Exception occurred: {exc_val}')
                logger.debug('Rolling back transaction')
                await self.rollback()
            else:
                logger.debug('Committing transaction')
                await self.__session.flush()  # Flush pending changes
                logger.debug('Session flushed')
                await self.commit()
                logger.debug('Transaction committed')
        finally:
            logger.debug('Closing session')
            await self.__session.close()

    async def commit(self):
        """Commit the current transaction."""
        logger.debug('Committing transaction')
        await self.__session.commit()

    async def rollback(self):
        """Rollback the current transaction."""
        logger.debug('Rolling back transaction')
        await self.__session.rollback()
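The commit-or-rollback decision made in __aexit__ can be exercised without a database. The sketch below is an assumption-laden stand-in, not the class above: RecordingSession is a hypothetical fake that only records which methods were awaited, and SketchUnitOfWork keeps just the control flow.

```python
import asyncio


class RecordingSession:
    """Hypothetical stand-in for AsyncSession that records calls."""

    def __init__(self):
        self.calls = []

    async def commit(self):
        self.calls.append("commit")

    async def rollback(self):
        self.calls.append("rollback")

    async def close(self):
        self.calls.append("close")


class SketchUnitOfWork:
    def __init__(self, session_factory):
        self._session_factory = session_factory

    async def __aenter__(self):
        self.session = self._session_factory()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type:
                await self.session.rollback()
            else:
                await self.session.commit()
        finally:
            await self.session.close()
        # Returning None (falsy) lets the original exception propagate


async def run_ok():
    async with SketchUnitOfWork(RecordingSession) as uow:
        pass  # business logic would go here
    return uow.session.calls


async def run_failing():
    uow = SketchUnitOfWork(RecordingSession)
    try:
        async with uow:
            raise ValueError("boom")
    except ValueError:
        pass
    return uow.session.calls


ok_calls = asyncio.run(run_ok())
failed_calls = asyncio.run(run_failing())
```

In both paths the session is closed last, and the exception raised inside the block still reaches the caller after the rollback.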
DAO (Data Access Object) Example
from abc import ABC, abstractmethod
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

class IOrganizationDao(ABC):
    """Abstract DAO interface for Organization operations."""

    @abstractmethod
    async def get_by_id(self, org_id: int) -> Optional[Organization]:
        pass

    @abstractmethod
    async def get_all(self) -> List[Organization]:
        pass

    @abstractmethod
    async def create(self, organization: Organization) -> Organization:
        pass

    @abstractmethod
    async def update(self, organization: Organization) -> Organization:
        pass

    @abstractmethod
    async def delete(self, org_id: int) -> bool:
        pass

class OrganizationDao(IOrganizationDao):
    """Concrete DAO implementation."""

    def __init__(self, session: AsyncSession):
        self._session = session

    async def get_by_id(self, org_id: int) -> Optional[Organization]:
        stmt = select(Organization).where(Organization.id == org_id)
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_all(self) -> List[Organization]:
        stmt = select(Organization)
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    async def create(self, organization: Organization) -> Organization:
        self._session.add(organization)
        # Flush so the database assigns the primary key (organization.id).
        # This sends the INSERT but does not commit - UoW handles the commit.
        await self._session.flush()
        return organization

    async def update(self, organization: Organization) -> Organization:
        # Object is already tracked by session
        # Changes will be persisted on UoW commit
        return organization

    async def delete(self, org_id: int) -> bool:
        organization = await self.get_by_id(org_id)
        if organization:
            await self._session.delete(organization)
            return True
        return False
Service Layer Integration
class OrganizationService:
    """Application service using Unit of Work."""

    def __init__(self, unit_of_work: IUnitOfWork):
        self._uow = unit_of_work

    async def create_organization_with_storage(
        self,
        org_data: dict,
        storage_data: dict
    ) -> dict:
        """Create organization and its storage in single transaction."""
        async with self._uow:
            # Create organization
            organization = Organization(**org_data)
            created_org = await self._uow.organization_dao.create(organization)
            # Create associated storage
            # Note: created_org.id is only populated once the INSERT has been flushed
            storage_data['organization_id'] = created_org.id
            storage = Storage(**storage_data)
            created_storage = await self._uow.storage_dao.create(storage)
            # Both operations committed together
            # If either fails, both are rolled back
            return {
                'organization': created_org,
                'storage': created_storage
            }

    async def transfer_waste(
        self,
        from_storage_id: int,
        to_storage_id: int,
        waste_type_id: int,
        quantity: float
    ) -> WasteTransfer:
        """Transfer waste between storages atomically."""
        async with self._uow:
            # Get entities
            from_storage = await self._uow.storage_dao.get_by_id(from_storage_id)
            to_storage = await self._uow.storage_dao.get_by_id(to_storage_id)
            waste_type = await self._uow.waste_type_dao.get_by_id(waste_type_id)
            if not all([from_storage, to_storage, waste_type]):
                raise ValueError("Invalid storage or waste type")
            # Business logic validation
            if from_storage.get_waste_quantity(waste_type_id) < quantity:
                raise ValueError("Insufficient waste quantity")
            # Update storages
            from_storage.remove_waste(waste_type_id, quantity)
            to_storage.add_waste(waste_type_id, quantity)
            await self._uow.storage_dao.update(from_storage)
            await self._uow.storage_dao.update(to_storage)
            # Create transfer record
            transfer = WasteTransfer(
                from_storage_id=from_storage_id,
                to_storage_id=to_storage_id,
                waste_type_id=waste_type_id,
                quantity=quantity
            )
            return await self._uow.waste_transfer_dao.create(transfer)
FastAPI Controller Integration
from fastapi import APIRouter, Depends, HTTPException
from dependency_injector.wiring import inject, Provide

router = APIRouter(prefix="/organizations", tags=["organizations"])

@router.post("/")
@inject
async def create_organization_with_storage(
    request: CreateOrganizationRequest,
    service: OrganizationService = Depends(Provide["organization_service"])
):
    """Create organization and storage in single transaction."""
    try:
        # .dict() is the Pydantic v1 API; Pydantic v2 renames it to model_dump()
        result = await service.create_organization_with_storage(
            org_data=request.organization.dict(),
            storage_data=request.storage.dict()
        )
        return {
            "organization_id": result['organization'].id,
            "storage_id": result['storage'].id,
            "message": "Organization and storage created successfully"
        }
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail="Internal server error") from e

@router.post("/waste-transfer")
@inject
async def transfer_waste(
    request: WasteTransferRequest,
    service: OrganizationService = Depends(Provide["organization_service"])
):
    """Transfer waste between storages."""
    try:
        transfer = await service.transfer_waste(
            from_storage_id=request.from_storage_id,
            to_storage_id=request.to_storage_id,
            waste_type_id=request.waste_type_id,
            quantity=request.quantity
        )
        return {
            "transfer_id": transfer.id,
            "message": "Waste transfer completed successfully"
        }
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
Advanced Usage
Custom Change Tracking
class AdvancedUnitOfWork(UnitOfWork):
    """UoW with explicit change tracking.

    Extends the concrete UnitOfWork so that __aenter__, __aexit__ and
    rollback are inherited. Assumes the session is reachable as
    self._session; the UnitOfWork above stores it in a name-mangled
    attribute, so adapt the attribute name accordingly.
    """

    def __init__(self):
        super().__init__()
        self._new_objects = set()
        self._dirty_objects = set()
        self._deleted_objects = set()

    def register_new(self, obj):
        """Register new object for insertion."""
        self._new_objects.add(obj)

    def register_dirty(self, obj):
        """Register object for update."""
        if obj not in self._new_objects:
            self._dirty_objects.add(obj)

    def register_deleted(self, obj):
        """Register object for deletion."""
        if obj in self._new_objects:
            # Never persisted, so there is nothing to delete
            self._new_objects.remove(obj)
        else:
            self._dirty_objects.discard(obj)
            self._deleted_objects.add(obj)

    async def commit(self):
        """Commit all tracked changes."""
        # Insert new objects
        for obj in self._new_objects:
            self._session.add(obj)
        # Delete objects
        for obj in self._deleted_objects:
            await self._session.delete(obj)
        # Dirty objects are automatically tracked by SQLAlchemy
        await self._session.commit()
        self._clear_tracking()

    def _clear_tracking(self):
        """Clear all tracking sets after commit."""
        self._new_objects.clear()
        self._dirty_objects.clear()
        self._deleted_objects.clear()
Nested Unit of Work
class NestedUnitOfWork:
    """Support for nested transactions using savepoints."""

    def __init__(self, session_factory):
        self.session_factory = session_factory
        self._session = None
        self._savepoints = []  # One entry per nested level currently open

    async def __aenter__(self):
        if self._session is not None:
            # Already inside a transaction: create savepoint for the nested level
            self._savepoints.append(await self._session.begin_nested())
        else:
            # Create main transaction
            self._session = self.session_factory()
            # Assumed helper that binds the DAOs to self._session
            self._initialize_daos()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self._savepoints:
            # Leaving a nested level: only the innermost savepoint is affected
            savepoint = self._savepoints.pop()
            if exc_type:
                await savepoint.rollback()
            else:
                await savepoint.commit()
            return
        try:
            if exc_type:
                await self._session.rollback()
            else:
                await self._session.commit()
        finally:
            await self._session.close()
            self._session = None  # The next `async with` starts a fresh transaction
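What a savepoint does at the database level can be seen with plain SQL. This sketch uses the standard-library sqlite3 module rather than SQLAlchemy; the storages table and the savepoint name inner_uow are assumptions for illustration.

```python
import sqlite3

# isolation_level=None disables sqlite3's implicit transaction handling,
# so BEGIN, SAVEPOINT and COMMIT are issued explicitly below.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE storages (name TEXT)")

conn.execute("BEGIN")                                   # outer unit of work
conn.execute("INSERT INTO storages VALUES ('outer')")

conn.execute("SAVEPOINT inner_uow")                     # nested unit of work
conn.execute("INSERT INTO storages VALUES ('inner')")
conn.execute("ROLLBACK TO SAVEPOINT inner_uow")         # undo only the nested part
conn.execute("RELEASE SAVEPOINT inner_uow")

conn.execute("COMMIT")                                  # outer work is still committed

names = [row[0] for row in conn.execute("SELECT name FROM storages")]
```

Rolling back to the savepoint discards only the nested insert; the outer transaction stays open and commits normally.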
Performance Optimization
from typing import Any, List

class OptimizedUnitOfWork(UnitOfWork):
    """UoW with performance optimizations.

    Assumes the session is reachable as self._session.
    """

    def __init__(self):
        super().__init__()
        self._batch_size = 1000
        self._deferred_operations = []

    async def bulk_insert(self, objects: List[Any]):
        """Perform bulk insert operation."""
        if len(objects) > self._batch_size:
            # Split into batches
            for i in range(0, len(objects), self._batch_size):
                batch = objects[i:i + self._batch_size]
                model = type(batch[0])
                # Skip SQLAlchemy's internal state (_sa_instance_state)
                rows = [
                    {k: v for k, v in vars(obj).items() if not k.startswith('_sa_')}
                    for obj in batch
                ]
                # AsyncSession has no bulk_* methods; run them on the sync session
                await self._session.run_sync(
                    lambda s: s.bulk_insert_mappings(model, rows)
                )
        else:
            self._session.add_all(objects)

    async def bulk_update(self, model_class, updates: List[dict]):
        """Perform bulk update operation."""
        await self._session.run_sync(
            lambda s: s.bulk_update_mappings(model_class, updates)
        )

    def defer_operation(self, operation_func, *args, **kwargs):
        """Defer operation until commit."""
        self._deferred_operations.append((operation_func, args, kwargs))

    async def commit(self):
        """Execute deferred operations then commit."""
        for operation_func, args, kwargs in self._deferred_operations:
            await operation_func(*args, **kwargs)
        await super().commit()
        self._deferred_operations.clear()
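The batching idea is independent of the ORM. As a minimal sketch, the hypothetical chunked() helper below splits a list into fixed-size slices, and sqlite3's executemany() inserts each slice inside one transaction; the users table and the batch size of 1000 are assumptions mirroring the example above.

```python
import sqlite3
from typing import Iterator, List, Tuple


def chunked(items: List[Tuple], size: int) -> Iterator[List[Tuple]]:
    """Yield consecutive slices of at most `size` items."""
    if size <= 0:
        raise ValueError("size must be positive")
    for start in range(0, len(items), size):
        yield items[start:start + size]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (name TEXT)")

rows = [(f"user-{i}",) for i in range(2500)]
batch_sizes = []

# One transaction for the whole job, one executemany() call per batch
with conn:
    for batch in chunked(rows, 1000):
        conn.executemany("INSERT INTO users VALUES (?)", batch)
        batch_sizes.append(len(batch))

total = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
```

Batching bounds the size of each statement while the single surrounding transaction keeps the import atomic.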
Best Practices
1. Keep Transactions Short
# ✅ Good - Short transaction scope
async def process_order(order_data: dict):
    async with unit_of_work:
        order = await unit_of_work.orders.create(Order(**order_data))
        await unit_of_work.inventory.reserve_items(order.items)
    # Transaction ends here

# ❌ Bad - Long-running transaction
async def process_order_bad(order_data: dict):
    async with unit_of_work:
        order = await unit_of_work.orders.create(Order(**order_data))
        # Long-running operations inside transaction
        await send_confirmation_email(order.user.email)  # Network call
        await update_external_system(order)  # Another network call
        await generate_pdf_invoice(order)  # CPU intensive
2. Handle Errors Gracefully
async def transfer_funds_safe(from_id: int, to_id: int, amount: Decimal):
    try:
        async with unit_of_work:
            from_account = await unit_of_work.accounts.get(from_id)
            to_account = await unit_of_work.accounts.get(to_id)
            if not from_account or not to_account:
                raise ValueError("Invalid account")
            if from_account.balance < amount:
                raise ValueError("Insufficient funds")
            from_account.withdraw(amount)
            to_account.deposit(amount)
        # Automatic rollback on exception
    except ValueError:
        # Business logic errors - re-raise
        raise
    except Exception as e:
        # Unexpected errors - log and handle
        logger.error(f"Unexpected error in fund transfer: {e}")
        # Chain the original exception so the root cause stays in the traceback
        raise RuntimeError("Transfer failed due to system error") from e
3. Use Dependency Injection
# Container configuration
from dependency_injector import containers, providers

class Container(containers.DeclarativeContainer):
    # Session factory
    session_factory = providers.Factory(async_session_maker)
    # Unit of Work
    unit_of_work = providers.Factory(UnitOfWork)
    # Services
    organization_service = providers.Factory(
        OrganizationService,
        unit_of_work=unit_of_work
    )

# FastAPI dependency
def get_unit_of_work() -> IUnitOfWork:
    return Container.unit_of_work()

# Usage in endpoints
@router.post("/")
async def create_org(
    request: CreateOrgRequest,
    uow: IUnitOfWork = Depends(get_unit_of_work)
):
    async with uow:
        # Use unit of work
        pass
4. Implement Proper Logging
import time
from contextvars import ContextVar

# Correlation ID for the current request; falls back to 'unknown' when unset.
# Assumed to be set elsewhere, for example by a request middleware.
trace_id: ContextVar[str] = ContextVar('trace_id', default='unknown')

class LoggedUnitOfWork(UnitOfWork):
    """UoW with comprehensive logging."""

    async def __aenter__(self):
        correlation_id = trace_id.get()
        logger.info(f"Starting transaction [correlation_id={correlation_id}]")
        start_time = time.time()
        result = await super().__aenter__()
        setup_time = time.time() - start_time
        logger.debug(f"UoW setup completed in {setup_time:.3f}s")
        return result

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        start_time = time.time()
        try:
            await super().__aexit__(exc_type, exc_val, exc_tb)
            if exc_type:
                logger.warning(f"Transaction rolled back due to {exc_type.__name__}: {exc_val}")
            else:
                completion_time = time.time() - start_time
                logger.info(f"Transaction committed successfully in {completion_time:.3f}s")
        except Exception as e:
            logger.error(f"Error during transaction cleanup: {e}")
            raise
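One common way to carry a correlation ID into log messages is a ContextVar from the standard library, which keeps a separate value per asyncio task. The sketch below is an illustration under assumptions: the variable name trace_id, the handle_request() function and the request IDs are hypothetical.

```python
import asyncio
from contextvars import ContextVar

# Default is used whenever no request has set a value
trace_id: ContextVar[str] = ContextVar("trace_id", default="unknown")


async def handle_request(request_id: str) -> str:
    trace_id.set(request_id)   # visible only within this task's context
    await asyncio.sleep(0)     # yield so the two requests interleave
    return f"Starting transaction [correlation_id={trace_id.get()}]"


async def main():
    # gather() runs each coroutine as its own task with its own context copy
    return await asyncio.gather(
        handle_request("req-1"),
        handle_request("req-2"),
    )


messages = asyncio.run(main())
outside = trace_id.get()  # nothing was set in the outer context
```

Each task sees only the ID it set, so concurrent requests do not overwrite each other's correlation IDs.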
Testing
Unit Testing
import pytest
from unittest.mock import AsyncMock, Mock

class TestUnitOfWork:
    @pytest.fixture
    def mock_session_factory(self):
        session_mock = AsyncMock()
        session_mock.add = Mock()  # Session.add() is synchronous
        factory_mock = Mock(return_value=session_mock)
        return factory_mock, session_mock

    @pytest.fixture
    def unit_of_work(self, mock_session_factory):
        factory_mock, session_mock = mock_session_factory
        uow = UnitOfWork()
        uow.session_factory = factory_mock
        return uow, session_mock

    @pytest.mark.asyncio
    async def test_successful_commit(self, unit_of_work):
        uow, session_mock = unit_of_work
        async with uow:
            # Simulate operations
            await uow.organization_dao.create(Organization(name="Test"))
        # Verify session interactions
        session_mock.flush.assert_called_once()
        session_mock.commit.assert_called_once()
        session_mock.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_rollback_on_exception(self, unit_of_work):
        uow, session_mock = unit_of_work
        with pytest.raises(ValueError):
            async with uow:
                raise ValueError("Test error")
        # Verify rollback was called
        session_mock.rollback.assert_called_once()
        session_mock.close.assert_called_once()
        assert not session_mock.commit.called
Integration Testing
@pytest.mark.asyncio
async def test_organization_creation_integration():
    """Test actual database integration."""
    async with UnitOfWork() as uow:
        # Create organization
        org_data = {"name": "Test Org", "address": "123 Test St"}
        organization = Organization(**org_data)
        created_org = await uow.organization_dao.create(organization)
        # Create storage
        storage_data = {"name": "Main Storage", "organization_id": created_org.id}
        storage = Storage(**storage_data)
        created_storage = await uow.storage_dao.create(storage)
        # Verify objects are created
        # Primary keys are assigned when the INSERT is flushed, so these
        # checks rely on the DAO flushing inside create()
        assert created_org.id is not None
        assert created_storage.id is not None
        assert created_storage.organization_id == created_org.id

@pytest.mark.asyncio
async def test_transaction_rollback_integration():
    """Test that rollback works correctly."""
    initial_count = await get_organization_count()
    with pytest.raises(ValueError):
        async with UnitOfWork() as uow:
            # Create organization
            org = Organization(name="Test Org")
            await uow.organization_dao.create(org)
            # Force an error
            raise ValueError("Simulated error")
    # Verify no data was committed
    final_count = await get_organization_count()
    assert final_count == initial_count
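Business logic can also be tested against an in-memory fake instead of a mocked session. The sketch below is a hedged illustration: FakeUnitOfWork, FakeOrganizationDao and register_organization() are hypothetical names, and plain dicts stand in for the Organization entity.

```python
import asyncio


class FakeOrganizationDao:
    def __init__(self):
        self.items = []

    async def create(self, organization):
        self.items.append(organization)
        return organization


class FakeUnitOfWork:
    """In-memory fake exposing the same surface the service relies on."""

    def __init__(self):
        self.organization_dao = FakeOrganizationDao()
        self.committed = False
        self.rolled_back = False

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if exc_type:
            await self.rollback()
        else:
            await self.commit()

    async def commit(self):
        self.committed = True

    async def rollback(self):
        self.rolled_back = True


async def register_organization(uow, name):
    """Hypothetical service function under test."""
    async with uow:
        if not name:
            raise ValueError("name is required")
        return await uow.organization_dao.create({"name": name})


async def scenario_success():
    uow = FakeUnitOfWork()
    org = await register_organization(uow, "Test Org")
    return uow, org


async def scenario_failure():
    uow = FakeUnitOfWork()
    try:
        await register_organization(uow, "")
    except ValueError:
        pass
    return uow


ok_uow, created = asyncio.run(scenario_success())
failed_uow = asyncio.run(scenario_failure())
```

The fake keeps tests fast and lets them assert on the outcome (committed or rolled back) rather than on the exact sequence of session calls.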
Common Pitfalls
1. Long-Running Transactions
# ❌ Bad - Transaction held too long
async def bad_order_processing():
    async with unit_of_work:
        order = await unit_of_work.orders.create(order_data)
        # These operations shouldn't be in the transaction
        await send_email_notification(order)  # Network I/O
        await call_external_api(order)  # External dependency
        await generate_report(order)  # CPU intensive
        # Transaction held for too long!

# ✅ Good - Separate concerns
async def good_order_processing():
    # Quick transaction for data consistency
    async with unit_of_work:
        order = await unit_of_work.orders.create(order_data)
        order_id = order.id
    # Separate operations outside transaction
    await send_email_notification(order_id)
    await call_external_api(order_id)
    await generate_report(order_id)
2. Mixing Direct Session Usage
# ❌ Bad - Bypassing UoW
async with unit_of_work:
    organization = await unit_of_work.organization_dao.create(org_data)
    # Don't do this - bypasses UoW coordination
    session = unit_of_work._UnitOfWork__session
    session.add(Storage(organization_id=organization.id))
    await session.commit()  # Commits outside UoW control

# ✅ Good - Use UoW consistently
async with unit_of_work:
    organization = await unit_of_work.organization_dao.create(org_data)
    storage = await unit_of_work.storage_dao.create(
        Storage(organization_id=organization.id)
    )
3. Forgetting Error Handling
# ❌ Bad - No error handling
async def create_organization(org_data):
    async with unit_of_work:
        return await unit_of_work.organization_dao.create(Organization(**org_data))

# ✅ Good - Proper error handling
from sqlalchemy.exc import IntegrityError

async def create_organization(org_data):
    try:
        async with unit_of_work:
            organization = Organization(**org_data)
            return await unit_of_work.organization_dao.create(organization)
    except IntegrityError as e:
        raise ValueError("Organization with this name already exists") from e
    except Exception as e:
        logger.error(f"Failed to create organization: {e}")
        raise RuntimeError("Failed to create organization") from e
4. Nested Transaction Issues
# ❌ Bad - Uncontrolled nesting
async def problematic_nesting():
    async with unit_of_work:  # Outer transaction
        org = await unit_of_work.organization_dao.create(org_data)
        # This might create issues
        async with unit_of_work:  # Inner transaction
            storage = await unit_of_work.storage_dao.create(storage_data)

# ✅ Good - Clear transaction boundaries
async def clear_boundaries():
    # First transaction
    async with unit_of_work:
        org = await unit_of_work.organization_dao.create(org_data)
        org_id = org.id
    # Second transaction
    async with unit_of_work:
        storage_data['organization_id'] = org_id
        storage = await unit_of_work.storage_dao.create(storage_data)
Conclusion
The Unit of Work pattern is essential for managing database transactions in complex applications. It provides:
Key Benefits:
- Transaction Control - Precise management of transaction boundaries
- Performance - Optimized database operations through batching
- Consistency - Ensures data integrity across multiple operations
- Testability - Easy to mock and test business logic
When to Use:
- Complex business transactions involving multiple entities
- Performance-critical applications requiring optimized database access
- Domain-driven design where business operations span multiple aggregates
- Applications requiring strict data consistency
Remember:
- Keep transactions short and focused
- Handle errors gracefully with proper rollback
- Don't mix direct session usage with UoW
- Test thoroughly including failure scenarios
- Log appropriately for debugging and monitoring
The Unit of Work pattern, when implemented correctly, provides a solid foundation for managing data persistence in your FastAPI Clean Architecture application.
See Also:
- Database Operations - Working with databases efficiently
- Testing Guide - Testing strategies for data operations
- Clean Architecture Guide - Understanding the overall architecture
- API Development - Building robust API endpoints