Unit of work. ORM data mapper no longer needed - Noloquideus/fastapi-template GitHub Wiki

Unit of Work Pattern

A comprehensive guide to implementing the Unit of Work pattern for managing database transactions and coordinating changes in your FastAPI Clean Architecture application.

What is Unit of Work

The Unit of Work pattern is designed to track changes to objects and then coordinate saving them to the database in a single transaction. It maintains a list of objects affected by a business transaction and coordinates writing out changes.

Core Concept

Instead of immediately persisting changes to the database when they occur, Unit of Work:

  1. Registers changes (new, modified, deleted objects)
  2. Coordinates persistence when explicitly committed
  3. Manages transaction boundaries efficiently

# Without Unit of Work - a commit after every change
user = User(name="John")
session.add(user)  # Queued in the session; nothing is sent yet
session.commit()  # First transaction: flushes the INSERT and commits

order = Order(user_id=user.id)
session.add(order)  # Queued again
session.commit()  # Second transaction, committed separately

# With Unit of Work - coordinated persistence
with unit_of_work:
    user = unit_of_work.users.create(name="John")  # Registered, not persisted
    order = unit_of_work.orders.create(user=user)  # Registered, not persisted
    # Single transaction with optimized queries

Why Use Unit of Work

1. Transaction Lifetime Management

Controls exactly when database transactions begin and end, preventing long-running transactions.

async def transfer_funds(from_account_id: int, to_account_id: int, amount: Decimal):
    async with unit_of_work:
        # Short, well-defined transaction scope
        from_account = await unit_of_work.accounts.get(from_account_id)
        to_account = await unit_of_work.accounts.get(to_account_id)
        
        from_account.withdraw(amount)
        to_account.deposit(amount)
        
        # All changes committed together or not at all
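
The all-or-nothing behavior can be tried without any ORM. The following is a minimal sketch using the standard library's sqlite3 module; the accounts table and the transfer and balances helpers are assumptions made for this illustration, not part of the template. The deposit is deliberately written before the balance check so that a failed transfer has partial work to undo.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE accounts (id INTEGER PRIMARY KEY, balance INTEGER NOT NULL)")
conn.executemany("INSERT INTO accounts VALUES (?, ?)", [(1, 100), (2, 50)])
conn.commit()


def transfer(conn, from_id, to_id, amount):
    # "with conn" commits if the block succeeds and rolls back if it raises.
    with conn:
        conn.execute(
            "UPDATE accounts SET balance = balance + ? WHERE id = ?", (amount, to_id)
        )
        (balance,) = conn.execute(
            "SELECT balance FROM accounts WHERE id = ?", (from_id,)
        ).fetchone()
        if balance < amount:
            raise ValueError("Insufficient funds")
        conn.execute(
            "UPDATE accounts SET balance = balance - ? WHERE id = ?", (amount, from_id)
        )


def balances(conn):
    return dict(conn.execute("SELECT id, balance FROM accounts ORDER BY id"))


transfer(conn, 1, 2, 30)
after_success = balances(conn)

try:
    transfer(conn, 1, 2, 500)  # Fails after the deposit was already written
except ValueError:
    pass
after_failure = balances(conn)  # The partial deposit has been rolled back
```

After the failed call the balances match those recorded after the successful one, which is the guarantee the transfer_funds example relies on.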

2. Performance Optimization

Batches database operations and can optimize queries:

# Without UoW - one transaction (and one commit) per user
for user_data in user_list:
    user = User(**user_data)
    session.add(user)
    session.commit()  # One commit per user

# With UoW - Single batch operation
async with unit_of_work:
    for user_data in user_list:
        unit_of_work.users.register_new(User(**user_data))
    # Single commit with batch insert

3. Change Tracking

Automatically tracks what needs to be persisted:

async with unit_of_work:
    # UoW tracks all these changes
    user = await unit_of_work.users.get(user_id)
    user.email = "user@example.com"  # Tracked as dirty
    
    order = Order(user_id=user.id)
    unit_of_work.orders.register_new(order)  # Tracked as new
    
    old_product = await unit_of_work.products.get(product_id)
    unit_of_work.products.register_deleted(old_product)  # Tracked as deleted

How It Works

Two-Phase Process

Phase 1: Registration

Objects register their changes with the Unit of Work:

  • register_new() - For new objects to be inserted
  • register_dirty() - For modified objects to be updated
  • register_deleted() - For objects to be deleted

Phase 2: Commitment

All registered changes are persisted in a single transaction:

  • commit() - Saves all changes to database
  • rollback() - Discards all changes if error occurs

Workflow Diagram

┌─────────────────┐    ┌──────────────────┐    ┌─────────────────┐
│   Application   │    │   Unit of Work   │    │    Database     │
└─────────────────┘    └──────────────────┘    └─────────────────┘
         │                       │                       │
         │ register_new(user)    │                       │
         │──────────────────────>│                       │
         │                       │                       │
         │ register_dirty(order) │                       │
         │──────────────────────>│                       │
         │                       │                       │
         │ commit()              │                       │
         │──────────────────────>│                       │
         │                       │ BEGIN TRANSACTION     │
         │                       │──────────────────────>│
         │                       │ INSERT user           │
         │                       │──────────────────────>│
         │                       │ UPDATE order          │
         │                       │──────────────────────>│
         │                       │ COMMIT                │
         │                       │──────────────────────>│
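
The sequence in the diagram can be sketched with an in-memory stand-in for the database. InMemoryUnitOfWork and its log list are hypothetical names used only for this illustration; the log records the statements a real implementation would send.

```python
from dataclasses import dataclass, field


@dataclass
class InMemoryUnitOfWork:
    """Collects changes during registration and replays them on commit."""

    new: list = field(default_factory=list)
    dirty: list = field(default_factory=list)
    deleted: list = field(default_factory=list)
    log: list = field(default_factory=list)  # Stands in for the database

    def register_new(self, obj):
        self.new.append(obj)

    def register_dirty(self, obj):
        self.dirty.append(obj)

    def register_deleted(self, obj):
        self.deleted.append(obj)

    def commit(self):
        # Phase 2: write every registered change inside one transaction.
        self.log.append("BEGIN")
        self.log += [f"INSERT {obj}" for obj in self.new]
        self.log += [f"UPDATE {obj}" for obj in self.dirty]
        self.log += [f"DELETE {obj}" for obj in self.deleted]
        self.log.append("COMMIT")
        self._clear()

    def rollback(self):
        # Discard pending registrations; nothing has reached the log.
        self._clear()

    def _clear(self):
        self.new.clear()
        self.dirty.clear()
        self.deleted.clear()


uow = InMemoryUnitOfWork()
uow.register_new("user")
uow.register_dirty("order")
log_before_commit = list(uow.log)  # Phase 1 leaves the database untouched
uow.commit()
log_after_commit = list(uow.log)
```

Nothing is written during registration; the four log entries appear only once commit() runs, in the same order as the arrows in the diagram.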

Related Patterns

Data Mapper Pattern

Unit of Work doesn't access the database directly. Instead, it delegates to Data Mapper objects:

class UserMapper:
    def insert(self, user: User) -> None:
        ...  # Handle user insertion logic

    def update(self, user: User) -> None:
        ...  # Handle user update logic

    def delete(self, user: User) -> None:
        ...  # Handle user deletion logic

class UnitOfWork:
    def __init__(self):
        self.user_mapper = UserMapper()
        self.new_objects = []
        self.dirty_objects = []
        self.deleted_objects = []
    
    def commit(self):
        for obj in self.new_objects:
            self.user_mapper.insert(obj)
        # ... handle dirty and deleted objects
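
To make the delegation concrete, here is a runnable sketch in which the mapper issues SQL through sqlite3 and the Unit of Work only decides when it runs. The users table, SqliteUserMapper, and MapperUnitOfWork are assumptions for this example rather than code from the template.

```python
import sqlite3
from dataclasses import dataclass


@dataclass
class User:
    id: int
    name: str


class SqliteUserMapper:
    """Translates User objects into SQL statements (hypothetical schema)."""

    def __init__(self, conn):
        self.conn = conn

    def insert(self, user):
        self.conn.execute(
            "INSERT INTO users (id, name) VALUES (?, ?)", (user.id, user.name)
        )

    def update(self, user):
        self.conn.execute(
            "UPDATE users SET name = ? WHERE id = ?", (user.name, user.id)
        )

    def delete(self, user):
        self.conn.execute("DELETE FROM users WHERE id = ?", (user.id,))


class MapperUnitOfWork:
    def __init__(self, conn):
        self.conn = conn
        self.mapper = SqliteUserMapper(conn)
        self.new_objects, self.dirty_objects, self.deleted_objects = [], [], []

    def commit(self):
        with self.conn:  # One transaction around every mapper call
            for obj in self.new_objects:
                self.mapper.insert(obj)
            for obj in self.dirty_objects:
                self.mapper.update(obj)
            for obj in self.deleted_objects:
                self.mapper.delete(obj)
        for pending in (self.new_objects, self.dirty_objects, self.deleted_objects):
            pending.clear()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
uow = MapperUnitOfWork(conn)

alice, bob = User(1, "Alice"), User(2, "Bob")
uow.new_objects += [alice, bob]
uow.commit()

alice.name = "Alicia"
uow.dirty_objects.append(alice)
uow.deleted_objects.append(bob)
uow.commit()

rows = conn.execute("SELECT id, name FROM users ORDER BY id").fetchall()
```

The Unit of Work never builds SQL itself, so swapping the mapper for one that targets another store would leave commit() unchanged.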

Identity Map Pattern

Often used together with Unit of Work to ensure object identity:

class IdentityMap:
    def __init__(self):
        self._objects = {}
    
    def get(self, object_id: int, object_type: type):
        key = (object_type, object_id)
        return self._objects.get(key)
    
    def add(self, obj, object_id: int):
        key = (type(obj), object_id)
        self._objects[key] = obj
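
A short usage sketch shows what "object identity" means in practice. UserRepository and its dictionary of rows are hypothetical stand-ins for a real data source; the IdentityMap is repeated so the example runs on its own.

```python
class IdentityMap:
    def __init__(self):
        self._objects = {}

    def get(self, object_id, object_type):
        return self._objects.get((object_type, object_id))

    def add(self, obj, object_id):
        self._objects[(type(obj), object_id)] = obj


class User:
    def __init__(self, user_id, name):
        self.id = user_id
        self.name = name


class UserRepository:
    """Loads users from a fake table, consulting the identity map first."""

    def __init__(self, rows):
        self._rows = rows  # Fake table: id -> name
        self._identity_map = IdentityMap()
        self.loads = 0  # Counts simulated database reads

    def get(self, user_id):
        cached = self._identity_map.get(user_id, User)
        if cached is not None:
            return cached
        self.loads += 1
        user = User(user_id, self._rows[user_id])
        self._identity_map.add(user, user_id)
        return user


repo = UserRepository({1: "Alice"})
first = repo.get(1)
second = repo.get(1)
first.name = "Alicia"  # Visible through every reference to user 1
```

Both lookups return the same instance, so a change made through one reference cannot diverge from another copy of the same row.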

Implementation in FastAPI

Abstract Interface

from abc import ABC, abstractmethod
from sqlalchemy.ext.asyncio import AsyncSession

class IUnitOfWork(ABC):
    """Abstract Unit of Work interface."""
    
    # DAO interfaces for different entities
    organization_dao: IOrganizationDao = None
    waste_type_dao: IWasteTypeDao = None
    storage_dao: IStorageDao = None
    waste_transfer_dao: IWasteTransferDao = None

    def __init__(self):
        self.session_factory = async_session_maker

    @abstractmethod
    async def __aenter__(self):
        """Async context manager entry."""
        raise NotImplementedError

    @abstractmethod
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        raise NotImplementedError

    @abstractmethod
    async def commit(self):
        """Commit all registered changes."""
        raise NotImplementedError

    @abstractmethod
    async def rollback(self):
        """Rollback all registered changes."""
        raise NotImplementedError

Concrete Implementation

from src.infrastructure.logger import logger

class UnitOfWork(IUnitOfWork):
    """Concrete Unit of Work implementation using SQLAlchemy."""

    async def __aenter__(self):
        """Initialize session and DAOs."""
        logger.debug('Unit of work created')
        self.__session: AsyncSession = self.session_factory()
        
        # Initialize all DAOs with the same session
        self.organization_dao = OrganizationDao(self.__session)
        self.waste_type_dao = WasteTypeDao(self.__session)
        self.storage_dao = StorageDao(self.__session)
        self.waste_transfer_dao = WasteTransferDao(self.__session)
        
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Handle transaction completion and cleanup."""
        logger.debug('Unit of work closing')
        
        try:
            if exc_type:
                logger.debug(f'Exception occurred: {exc_val}')
                logger.debug('Rolling back transaction')
                await self.rollback()
            else:
                logger.debug('Committing transaction')
                await self.__session.flush()  # Flush pending changes
                logger.debug('Session flushed')
                await self.commit()
                logger.debug('Transaction committed')
        finally:
            logger.debug('Closing session')
            await self.__session.close()

    async def commit(self):
        """Commit the current transaction."""
        logger.debug('Committing transaction')
        await self.__session.commit()

    async def rollback(self):
        """Rollback the current transaction."""
        logger.debug('Rolling back transaction')
        await self.__session.rollback()
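
The commit-or-rollback decision in __aexit__ can be exercised without a database. This sketch replaces the SQLAlchemy session with a FakeSession test double that only records calls; FakeSession, SketchUnitOfWork, and run are hypothetical names, and the logging and DAO setup of the class above are left out.

```python
import asyncio


class FakeSession:
    """Records calls instead of talking to a database (test double)."""

    def __init__(self):
        self.calls = []

    async def commit(self):
        self.calls.append("commit")

    async def rollback(self):
        self.calls.append("rollback")

    async def close(self):
        self.calls.append("close")


class SketchUnitOfWork:
    def __init__(self, session_factory):
        self._session_factory = session_factory
        self.session = None

    async def __aenter__(self):
        self.session = self._session_factory()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type:
                await self.session.rollback()
            else:
                await self.session.commit()
        finally:
            await self.session.close()
        # Returning None lets any exception propagate to the caller.


async def run(fail):
    uow = SketchUnitOfWork(FakeSession)
    try:
        async with uow:
            if fail:
                raise ValueError("boom")
    except ValueError:
        pass
    return uow.session.calls


ok_calls = asyncio.run(run(fail=False))
failed_calls = asyncio.run(run(fail=True))
```

In both paths the session is closed last, which mirrors the try/finally structure of the concrete implementation.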

DAO (Data Access Object) Example

from abc import ABC, abstractmethod
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

class IOrganizationDao(ABC):
    """Abstract DAO interface for Organization operations."""
    
    @abstractmethod
    async def get_by_id(self, org_id: int) -> Optional[Organization]:
        pass
    
    @abstractmethod
    async def get_all(self) -> List[Organization]:
        pass
    
    @abstractmethod
    async def create(self, organization: Organization) -> Organization:
        pass
    
    @abstractmethod
    async def update(self, organization: Organization) -> Organization:
        pass
    
    @abstractmethod
    async def delete(self, org_id: int) -> bool:
        pass

class OrganizationDao(IOrganizationDao):
    """Concrete DAO implementation."""
    
    def __init__(self, session: AsyncSession):
        self._session = session
    
    async def get_by_id(self, org_id: int) -> Optional[Organization]:
        stmt = select(Organization).where(Organization.id == org_id)
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()
    
    async def get_all(self) -> List[Organization]:
        stmt = select(Organization)
        result = await self._session.execute(stmt)
        return list(result.scalars().all())
    
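    async def create(self, organization: Organization) -> Organization:
        self._session.add(organization)
        # Flush so the database assigns the primary key (organization.id).
        # The transaction stays open - the UoW still decides when to commit
        await self._session.flush()
        return organization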
    
    async def update(self, organization: Organization) -> Organization:
        # Object is already tracked by session
        # Changes will be persisted on UoW commit
        return organization
    
    async def delete(self, org_id: int) -> bool:
        organization = await self.get_by_id(org_id)
        if organization:
            await self._session.delete(organization)
            return True
        return False

Service Layer Integration

class OrganizationService:
    """Application service using Unit of Work."""
    
    def __init__(self, unit_of_work: IUnitOfWork):
        self._uow = unit_of_work
    
    async def create_organization_with_storage(
        self, 
        org_data: dict, 
        storage_data: dict
    ) -> dict:
        """Create organization and its storage in single transaction."""
        
        async with self._uow:
            # Create organization
            organization = Organization(**org_data)
            created_org = await self._uow.organization_dao.create(organization)
            
            # Create associated storage
            storage_data['organization_id'] = created_org.id
            storage = Storage(**storage_data)
            created_storage = await self._uow.storage_dao.create(storage)
            
            # Both operations committed together
            # If either fails, both are rolled back
            
            return {
                'organization': created_org,
                'storage': created_storage
            }
    
    async def transfer_waste(
        self, 
        from_storage_id: int, 
        to_storage_id: int, 
        waste_type_id: int, 
        quantity: float
    ) -> WasteTransfer:
        """Transfer waste between storages atomically."""
        
        async with self._uow:
            # Get entities
            from_storage = await self._uow.storage_dao.get_by_id(from_storage_id)
            to_storage = await self._uow.storage_dao.get_by_id(to_storage_id)
            waste_type = await self._uow.waste_type_dao.get_by_id(waste_type_id)
            
            if not all([from_storage, to_storage, waste_type]):
                raise ValueError("Invalid storage or waste type")
            
            # Business logic validation
            if from_storage.get_waste_quantity(waste_type_id) < quantity:
                raise ValueError("Insufficient waste quantity")
            
            # Update storages
            from_storage.remove_waste(waste_type_id, quantity)
            to_storage.add_waste(waste_type_id, quantity)
            
            await self._uow.storage_dao.update(from_storage)
            await self._uow.storage_dao.update(to_storage)
            
            # Create transfer record
            transfer = WasteTransfer(
                from_storage_id=from_storage_id,
                to_storage_id=to_storage_id,
                waste_type_id=waste_type_id,
                quantity=quantity
            )
            
            return await self._uow.waste_transfer_dao.create(transfer)

FastAPI Controller Integration

from fastapi import APIRouter, Depends, HTTPException
from dependency_injector.wiring import inject, Provide

router = APIRouter(prefix="/organizations", tags=["organizations"])

@router.post("/")
@inject
async def create_organization_with_storage(
    request: CreateOrganizationRequest,
    service: OrganizationService = Depends(Provide["organization_service"])
):
    """Create organization and storage in single transaction."""
    try:
        result = await service.create_organization_with_storage(
            org_data=request.organization.dict(),
            storage_data=request.storage.dict()
        )
        return {
            "organization_id": result['organization'].id,
            "storage_id": result['storage'].id,
            "message": "Organization and storage created successfully"
        }
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail="Internal server error")

@router.post("/waste-transfer")
@inject
async def transfer_waste(
    request: WasteTransferRequest,
    service: OrganizationService = Depends(Provide["organization_service"])
):
    """Transfer waste between storages."""
    try:
        transfer = await service.transfer_waste(
            from_storage_id=request.from_storage_id,
            to_storage_id=request.to_storage_id,
            waste_type_id=request.waste_type_id,
            quantity=request.quantity
        )
        return {
            "transfer_id": transfer.id,
            "message": "Waste transfer completed successfully"
        }
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

Advanced Usage

Custom Change Tracking

class AdvancedUnitOfWork(IUnitOfWork):
    """UoW with explicit change tracking."""
    
    def __init__(self):
        super().__init__()
        self._new_objects = set()
        self._dirty_objects = set()
        self._deleted_objects = set()
    
    def register_new(self, obj):
        """Register new object for insertion."""
        self._new_objects.add(obj)
    
    def register_dirty(self, obj):
        """Register object for update."""
        if obj not in self._new_objects:
            self._dirty_objects.add(obj)
    
    def register_deleted(self, obj):
        """Register object for deletion."""
        if obj in self._new_objects:
            self._new_objects.remove(obj)
        else:
            self._dirty_objects.discard(obj)
            self._deleted_objects.add(obj)
    
    async def commit(self):
        """Commit all tracked changes."""
        # Insert new objects
        for obj in self._new_objects:
            self._session.add(obj)
        
        # Delete objects
        for obj in self._deleted_objects:
            await self._session.delete(obj)
        
        # Dirty objects are automatically tracked by SQLAlchemy
        
        await self._session.commit()
        self._clear_tracking()
    
    def _clear_tracking(self):
        """Clear all tracking sets after commit."""
        self._new_objects.clear()
        self._dirty_objects.clear()
        self._deleted_objects.clear()
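
The registration rules above interact in ways that are easy to miss, so here they are in isolation. ChangeTracker is a hypothetical class that keeps only the three sets and drops persistence entirely.

```python
class ChangeTracker:
    """Registration rules only; persistence is left out of this sketch."""

    def __init__(self):
        self.new, self.dirty, self.deleted = set(), set(), set()

    def register_new(self, obj):
        self.new.add(obj)

    def register_dirty(self, obj):
        if obj not in self.new:  # A pending INSERT already carries the latest state
            self.dirty.add(obj)

    def register_deleted(self, obj):
        if obj in self.new:
            self.new.remove(obj)  # Never inserted, so there is nothing to delete
        else:
            self.dirty.discard(obj)
            self.deleted.add(obj)


tracker = ChangeTracker()

# Created, modified, and deleted within the same unit of work.
tracker.register_new("draft")
tracker.register_dirty("draft")
tracker.register_deleted("draft")

# Loaded earlier, then modified and deleted.
tracker.register_dirty("existing")
tracker.register_deleted("existing")

pending = (tracker.new, tracker.dirty, tracker.deleted)
```

The object that was created and deleted in the same unit of work produces no statements at all, while the existing object ends up with a single DELETE and no UPDATE.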

Nested Unit of Work

class NestedUnitOfWork:
    """Support for nested transactions using savepoints."""

    def __init__(self):
        self.session_factory = async_session_maker
        self._session = None
        self._savepoints = []  # One open savepoint per nesting level

    async def __aenter__(self):
        if self._session is not None:
            # Already inside a transaction: open a savepoint
            self._savepoints.append(await self._session.begin_nested())
        else:
            # Outermost level: create the session and main transaction.
            # _initialize_daos() is assumed to build the DAOs the same way
            # UnitOfWork.__aenter__ does
            self._session = self.session_factory()
            self._initialize_daos()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self._savepoints:
            # Leaving a nested level: resolve only its own savepoint
            savepoint = self._savepoints.pop()
            if exc_type:
                await savepoint.rollback()
            else:
                await savepoint.commit()
            return

        try:
            if exc_type:
                await self._session.rollback()
            else:
                await self._session.commit()
        finally:
            await self._session.close()
            self._session = None
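
What a savepoint does at the SQL level can be seen with sqlite3 alone. In this sketch the organizations table and the savepoint name inner_uow are made up for illustration; the statements are plain SQL issued by hand rather than through SQLAlchemy.

```python
import sqlite3

# isolation_level=None puts sqlite3 in autocommit mode, so BEGIN, SAVEPOINT,
# and COMMIT are issued explicitly instead of being managed by the driver.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE organizations (name TEXT)")

conn.execute("BEGIN")  # Outer unit of work
conn.execute("INSERT INTO organizations VALUES ('outer')")

conn.execute("SAVEPOINT inner_uow")  # Nested unit of work
conn.execute("INSERT INTO organizations VALUES ('inner')")
conn.execute("ROLLBACK TO SAVEPOINT inner_uow")  # Undo only the nested work
conn.execute("RELEASE SAVEPOINT inner_uow")

conn.execute("COMMIT")  # Outer work is still committed

names = [row[0] for row in conn.execute("SELECT name FROM organizations")]
```

Rolling back to the savepoint discards the inner insert while the outer transaction stays open and commits normally.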

Performance Optimization

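from typing import Any, List

class OptimizedUnitOfWork(UnitOfWork):
    """UoW with performance optimizations.

    Note: the methods below assume the session is reachable as self._session.
    UnitOfWork above stores it under a name-mangled attribute, so expose it
    (for example through a property) before using this class.
    """

    def __init__(self):
        super().__init__()
        self._batch_size = 1000
        self._deferred_operations = []

    async def bulk_insert(self, objects: List[Any]):
        """Perform bulk insert operation."""
        if len(objects) > self._batch_size:
            # Split into batches
            for i in range(0, len(objects), self._batch_size):
                batch = objects[i:i + self._batch_size]
                model = type(batch[0])
                # Drop SQLAlchemy's internal state entry from each mapping
                mappings = [
                    {k: v for k, v in vars(obj).items() if k != '_sa_instance_state'}
                    for obj in batch
                ]
                # AsyncSession has no bulk_* methods; run_sync() passes the
                # underlying synchronous Session to the callable
                await self._session.run_sync(
                    lambda session: session.bulk_insert_mappings(model, mappings)
                )
        else:
            self._session.add_all(objects)

    async def bulk_update(self, model_class, updates: List[dict]):
        """Perform bulk update operation."""
        await self._session.run_sync(
            lambda session: session.bulk_update_mappings(model_class, updates)
        )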
    
    def defer_operation(self, operation_func, *args, **kwargs):
        """Defer operation until commit."""
        self._deferred_operations.append((operation_func, args, kwargs))
    
    async def commit(self):
        """Execute deferred operations then commit."""
        for operation_func, args, kwargs in self._deferred_operations:
            await operation_func(*args, **kwargs)
        
        await super().commit()
        self._deferred_operations.clear()
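
The batch-splitting step can be isolated into a small helper and checked on its own. chunked is a hypothetical name for this sketch; it uses the same range-and-slice arithmetic as bulk_insert.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def chunked(items: Sequence[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of at most `size` items."""
    if size <= 0:
        raise ValueError("size must be positive")
    for start in range(0, len(items), size):
        yield list(items[start:start + size])


batch_sizes = [len(batch) for batch in chunked(list(range(2500)), 1000)]
```

With 2500 objects and a batch size of 1000 this yields two full batches and one partial batch, so the final slice needs no special handling.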

Best Practices

1. Keep Transactions Short

# ✅ Good - Short transaction scope
async def process_order(order_data: dict):
    async with unit_of_work:
        order = await unit_of_work.orders.create(Order(**order_data))
        await unit_of_work.inventory.reserve_items(order.items)
        # Transaction ends here

# ❌ Bad - Long-running transaction
async def process_order_bad(order_data: dict):
    async with unit_of_work:
        order = await unit_of_work.orders.create(Order(**order_data))
        
        # Long-running operations inside transaction
        await send_confirmation_email(order.user.email)  # Network call
        await update_external_system(order)  # Another network call
        await generate_pdf_invoice(order)  # CPU intensive

2. Handle Errors Gracefully

async def transfer_funds_safe(from_id: int, to_id: int, amount: Decimal):
    try:
        async with unit_of_work:
            from_account = await unit_of_work.accounts.get(from_id)
            to_account = await unit_of_work.accounts.get(to_id)
            
            if not from_account or not to_account:
                raise ValueError("Invalid account")
            
            if from_account.balance < amount:
                raise ValueError("Insufficient funds")
            
            from_account.withdraw(amount)
            to_account.deposit(amount)
            
            # Automatic rollback on exception
            
    except ValueError:
        # Business logic errors - re-raise
        raise
    except Exception as e:
        # Unexpected errors - log, then re-raise with the original cause attached
        logger.error(f"Unexpected error in fund transfer: {e}")
        raise RuntimeError("Transfer failed due to system error") from e

3. Use Dependency Injection

# Container configuration
from dependency_injector import containers, providers

class Container(containers.DeclarativeContainer):
    # Session factory
    session_factory = providers.Factory(async_session_maker)
    
    # Unit of Work
    unit_of_work = providers.Factory(UnitOfWork)
    
    # Services
    organization_service = providers.Factory(
        OrganizationService,
        unit_of_work=unit_of_work
    )

# FastAPI dependency
def get_unit_of_work() -> IUnitOfWork:
    return Container.unit_of_work()

# Usage in endpoints
@router.post("/")
async def create_org(
    request: CreateOrgRequest,
    uow: IUnitOfWork = Depends(get_unit_of_work)
):
    async with uow:
        # Use unit of work
        pass

4. Implement Proper Logging

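import time
from contextvars import ContextVar

# Assumed context variable holding the current request's trace ID
trace_id: ContextVar[str] = ContextVar('trace_id', default='unknown')

class LoggedUnitOfWork(UnitOfWork):
    """UoW with comprehensive logging."""

    async def __aenter__(self):
        correlation_id = trace_id.get()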
        logger.info(f"Starting transaction [correlation_id={correlation_id}]")
        
        start_time = time.time()
        result = await super().__aenter__()
        
        setup_time = time.time() - start_time
        logger.debug(f"UoW setup completed in {setup_time:.3f}s")
        
        return result
    
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        start_time = time.time()
        
        try:
            await super().__aexit__(exc_type, exc_val, exc_tb)
            
            if exc_type:
                logger.warning(f"Transaction rolled back due to {exc_type.__name__}: {exc_val}")
            else:
                completion_time = time.time() - start_time
                logger.info(f"Transaction committed successfully in {completion_time:.3f}s")
                
        except Exception as e:
            logger.error(f"Error during transaction cleanup: {e}")
            raise
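
A correlation ID of this kind is usually carried in a context variable. The sketch below shows the standard library mechanics, assuming a module-level variable named trace_id and a made-up request ID; how the ID is actually set per request depends on the application's middleware.

```python
import contextvars

# Falls back to "unknown" when no request has set a value.
trace_id = contextvars.ContextVar("trace_id", default="unknown")


def describe_transaction():
    return f"Starting transaction [correlation_id={trace_id.get()}]"


before = describe_transaction()

token = trace_id.set("req-42")  # Typically done by request middleware
during = describe_transaction()

trace_id.reset(token)  # Restore the previous value when the request ends
after = describe_transaction()
```

Because each asyncio task runs with its own copy of the context, a value set while handling one request is not seen by tasks handling other requests.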

Testing

Unit Testing

import pytest
from unittest.mock import AsyncMock, Mock

class TestUnitOfWork:
    
    @pytest.fixture
    def mock_session_factory(self):
        session_mock = AsyncMock()
        factory_mock = Mock(return_value=session_mock)
        return factory_mock, session_mock
    
    @pytest.fixture
    def unit_of_work(self, mock_session_factory):
        factory_mock, session_mock = mock_session_factory
        uow = UnitOfWork()
        uow.session_factory = factory_mock
        return uow, session_mock
    
    @pytest.mark.asyncio
    async def test_successful_commit(self, unit_of_work):
        uow, session_mock = unit_of_work
        
        async with uow:
            # Simulate operations
            await uow.organization_dao.create(Organization(name="Test"))
        
        # Verify session interactions (a DAO may flush as well, so do not
        # require exactly one flush)
        session_mock.flush.assert_called()
        session_mock.commit.assert_called_once()
        session_mock.close.assert_called_once()
    
    @pytest.mark.asyncio
    async def test_rollback_on_exception(self, unit_of_work):
        uow, session_mock = unit_of_work
        
        with pytest.raises(ValueError):
            async with uow:
                raise ValueError("Test error")
        
        # Verify rollback was called
        session_mock.rollback.assert_called_once()
        session_mock.close.assert_called_once()
        assert not session_mock.commit.called

Integration Testing

@pytest.mark.asyncio
async def test_organization_creation_integration():
    """Test actual database integration."""
    
    async with UnitOfWork() as uow:
        # Create organization
        org_data = {"name": "Test Org", "address": "123 Test St"}
        organization = Organization(**org_data)
        created_org = await uow.organization_dao.create(organization)
        
        # Create storage
        storage_data = {"name": "Main Storage", "organization_id": created_org.id}
        storage = Storage(**storage_data)
        created_storage = await uow.storage_dao.create(storage)
        
        # Verify objects are created
        assert created_org.id is not None
        assert created_storage.id is not None
        assert created_storage.organization_id == created_org.id

@pytest.mark.asyncio 
async def test_transaction_rollback_integration():
    """Test that rollback works correctly."""
    
    initial_count = await get_organization_count()
    
    with pytest.raises(ValueError):
        async with UnitOfWork() as uow:
            # Create organization
            org = Organization(name="Test Org")
            await uow.organization_dao.create(org)
            
            # Force an error
            raise ValueError("Simulated error")
    
    # Verify no data was committed
    final_count = await get_organization_count()
    assert final_count == initial_count

Common Pitfalls

1. Long-Running Transactions

# ❌ Bad - Transaction held too long
async def bad_order_processing():
    async with unit_of_work:
        order = await unit_of_work.orders.create(order_data)
        
        # These operations shouldn't be in the transaction
        await send_email_notification(order)  # Network I/O
        await call_external_api(order)        # External dependency
        await generate_report(order)          # CPU intensive
        
        # Transaction held for too long!

# ✅ Good - Separate concerns
async def good_order_processing():
    # Quick transaction for data consistency
    async with unit_of_work:
        order = await unit_of_work.orders.create(order_data)
        order_id = order.id
    
    # Separate operations outside transaction
    await send_email_notification(order_id)
    await call_external_api(order_id)
    await generate_report(order_id)

2. Mixing Direct Session Usage

# ❌ Bad - Bypassing UoW
async with unit_of_work:
    organization = await unit_of_work.organization_dao.create(org_data)
    
    # Don't do this - bypasses UoW coordination
    session = unit_of_work._UnitOfWork__session
    session.add(Storage(organization_id=organization.id))
    await session.commit()  # Commits outside UoW control

# ✅ Good - Use UoW consistently
async with unit_of_work:
    organization = await unit_of_work.organization_dao.create(org_data)
    storage = await unit_of_work.storage_dao.create(
        Storage(organization_id=organization.id)
    )

3. Forgetting Error Handling

# ❌ Bad - No error handling
async def create_organization(org_data):
    async with unit_of_work:
        return await unit_of_work.organization_dao.create(Organization(**org_data))

# ✅ Good - Proper error handling
from sqlalchemy.exc import IntegrityError

async def create_organization(org_data):
    try:
        async with unit_of_work:
            organization = Organization(**org_data)
            return await unit_of_work.organization_dao.create(organization)
    except IntegrityError as e:
        # Raised when the flush or commit hits a constraint violation
        raise ValueError("Organization with this name already exists") from e
    except Exception as e:
        logger.error(f"Failed to create organization: {e}")
        raise RuntimeError("Failed to create organization") from e

4. Nested Transaction Issues

# ❌ Bad - Uncontrolled nesting
async def problematic_nesting():
    async with unit_of_work:  # Outer transaction
        org = await unit_of_work.organization_dao.create(org_data)
        
        # Re-entering the same UoW replaces its session and DAOs, so the
        # outer session (holding org) is never committed or closed
        async with unit_of_work:  # Inner transaction
            storage = await unit_of_work.storage_dao.create(storage_data)

# ✅ Good - Clear transaction boundaries
async def clear_boundaries():
    # First transaction
    async with unit_of_work:
        org = await unit_of_work.organization_dao.create(org_data)
        org_id = org.id
    
    # Second transaction
    async with unit_of_work:
        storage_data['organization_id'] = org_id
        storage = await unit_of_work.storage_dao.create(storage_data)

Conclusion

The Unit of Work pattern gives complex applications a single place to control database transactions. It provides:

Key Benefits:

  1. Transaction Control - Precise management of transaction boundaries
  2. Performance - Optimized database operations through batching
  3. Consistency - Ensures data integrity across multiple operations
  4. Testability - Easy to mock and test business logic

When to Use:

  • Complex business transactions involving multiple entities
  • Performance-critical applications requiring optimized database access
  • Domain-driven design where business operations span multiple aggregates
  • Applications requiring strict data consistency

Remember:

  • Keep transactions short and focused
  • Handle errors gracefully with proper rollback
  • Don't mix direct session usage with UoW
  • Test thoroughly including failure scenarios
  • Log appropriately for debugging and monitoring

The Unit of Work pattern, when implemented correctly, provides a solid foundation for managing data persistence in your FastAPI Clean Architecture application.

