4. Developer Guide

All core components (algorithms, communicators, topologies) use the RequiredSetup mixin for consistent initialization - see the end of this guide for details.

Extending Core Components

Custom Algorithms

Subclass BaseAlgorithm and implement _configure_local_optimizer() and _compute_loss() (see the source for full behavior); the aggregation hooks are optional overrides. Minimal example:

from src.omnifed.algorithm.base import BaseAlgorithm
from src.omnifed.communicator import AggregationOp
from src.omnifed import utils
import torch

class FedLike(BaseAlgorithm):
  def _configure_local_optimizer(self, local_lr: float) -> torch.optim.Optimizer:
    return torch.optim.SGD(self.local_model.parameters(), lr=local_lr)

  def _compute_loss(self, batch) -> torch.Tensor:
    x, y = batch
    logits = self.local_model(x)
    return torch.nn.functional.cross_entropy(logits, y)

  def _aggregate_within_group(self, comm, weight):
    # sample-weighted averaging (default)
    utils.scale_params(self.local_model, weight, include_buffers=True)
    return comm.aggregate(self.local_model, reduction=AggregationOp.SUM)

Override _aggregate_within_group() or _aggregate_across_groups() only when you need custom weighting/logic; defaults implement sample-weighted averaging.
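
For example, here is a hedged sketch of an across-group override that averages group leaders uniformly instead of by sample count, continuing the FedLike class above. It assumes _aggregate_across_groups() takes the same (comm, weight) arguments as _aggregate_within_group(); the num_groups attribute is hypothetical (check the base class for the actual contract):

class UniformFed(FedLike):
  """Average group leaders uniformly instead of by sample count."""

  def _aggregate_across_groups(self, comm, weight):
    # `self.num_groups` is a hypothetical attribute you would set yourself;
    # the sample-based `weight` from the framework is deliberately ignored.
    uniform_weight = 1.0 / self.num_groups
    utils.scale_params(self.local_model, uniform_weight, include_buffers=True)
    return comm.aggregate(self.local_model, reduction=AggregationOp.SUM)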

See implementation: src/omnifed/algorithm/base.py

Custom Communicators

Extend BaseCommunicator to implement new transport mechanisms. The framework includes gRPC and PyTorch distributed backends.

from src.omnifed.communicator.base import BaseCommunicator, AggregationOp

class MyCustomCommunicator(BaseCommunicator):
  def _setup(self):
    # Initialize your transport (connections, process groups, etc.)
    return self.init_transport()

  def broadcast(self, obj):
    # Broadcast from rank 0 to all others
    return self.transport_broadcast(obj)

  def aggregate(self, obj, reduction=AggregationOp.SUM):
    # Collect and reduce across all ranks
    return self.transport_aggregate(obj, reduction)

  def close(self):
    # Clean up connections/resources
    self.transport_cleanup()

The required methods are broadcast(), aggregate(), and close(); _setup() comes from the RequiredSetup contract described at the end of this guide.
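
As a hedged sketch of how an algorithm-side caller might drive a communicator (the run_round helper and its arguments are illustrative, not part of the framework):

import torch
from src.omnifed.communicator import AggregationOp

def run_round(comm, local_model: torch.nn.Module):
  """Illustrative helper: one broadcast / local-train / aggregate cycle."""
  local_model = comm.broadcast(local_model)   # rank 0's object reaches every rank
  # ... local training would update local_model on each rank here ...
  # Reduce the locally updated models across all ranks.
  return comm.aggregate(local_model, reduction=AggregationOp.SUM)

Because communicators follow the RequiredSetup contract, call setup() once before the first round and close() when the run ends.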

See implementation: src/omnifed/communicator/base.py

Custom Topologies

Extend BaseTopology and implement the _setup() contract (signature and a short example below):

from src.omnifed.topology.base import BaseTopology
from src.omnifed.node import NodeConfig

class MyTopology(BaseTopology):
  def __init__(self, num_clients: int, comm_cfg):
    super().__init__()
    self.num_clients = num_clients
    self.comm_cfg = comm_cfg  # communicator config shared by this topology's nodes

  def _setup(self, default_algorithm_cfg, default_model_cfg, default_datamodule_cfg):
    """Create NodeConfig objects using Engine-provided defaults."""
    nodes = []
    for rank in range(self.num_clients):
      node = NodeConfig(
        name=f"client{rank}",
        local_comm=self.comm_cfg,  # your communicator config
        algorithm=default_algorithm_cfg,
        model=default_model_cfg,
        datamodule=default_datamodule_cfg,
        device_hint="auto",
      )
      nodes.append(node)
    return nodes

The Engine provides default configs to your _setup() method; use them and apply per-node overrides as needed.
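
For instance, here is a hedged sketch of a per-node override inside _setup(), assuming the default configs can be deep-copied and indexed like dicts (adapt to the actual config type) and using a hypothetical partition_id field:

import copy

# Inside _setup(): give each node its own data partition while reusing the
# other Engine-provided defaults. `partition_id` is a hypothetical field name.
dm_cfg = copy.deepcopy(default_datamodule_cfg)
dm_cfg["partition_id"] = rank

node = NodeConfig(
  name=f"client{rank}",
  local_comm=self.comm_cfg,
  algorithm=default_algorithm_cfg,
  model=default_model_cfg,
  datamodule=dm_cfg,
  device_hint="auto",
)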

See implementation: src/omnifed/topology/base.py

HierarchicalTopology Pattern

For topologies composed of child topologies, set up the children first, then assemble their node configs; the pattern below sketches the body of the parent topology's _setup().

for child in self.topologies:
  child.setup(default_algorithm_cfg, default_model_cfg, default_datamodule_cfg)

node_configs = []
for g_idx, child in enumerate(self.topologies):
  for n in child.node_configs:
    # name nodes as group.rank and give group leaders a global_comm
    n.name = f"{g_idx}.{n.local_comm.rank}"
    if n.local_comm.rank == 0:
      n.global_comm = global_comm_cfg  # Cross-group communication config
    node_configs.append(n)

return node_configs

See implementation: src/omnifed/topology/hierarchical.py

NodeConfig

Internal dataclass used by topologies to specify node configurations. Main fields: name, local_comm, global_comm, algorithm, model, datamodule, and device_hint.
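
For orientation only, an inferred sketch of the dataclass shape, based on the fields used in the topology examples above (types and defaults here are guesses; see src/omnifed/node.py for the real definition):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class NodeConfig:                      # inferred sketch, not the actual definition
  name: str                            # e.g. "client0" or "0.1" (group.rank)
  local_comm: Any                      # communicator config for the node's own group
  algorithm: Any                       # algorithm config (Engine default or override)
  model: Any                           # model config
  datamodule: Any                      # datamodule config
  global_comm: Optional[Any] = None    # set on group leaders for cross-group communication
  device_hint: str = "auto"            # device placement hint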

See implementation: src/omnifed/node.py

RequiredSetup mixin

Many components need to do work that can only happen at runtime (open connections, discover ranks, allocate devices). The RequiredSetup mixin gives them a small, consistent contract for deferred initialization.

What it guarantees

  • You call setup() once; the class runs _setup(...) and caches the result.
  • setup_result exposes whatever _setup(...) returned (available only after setup()).
  • is_ready tells you whether setup finished.

Tiny example

from src.omnifed.utils import RequiredSetup

class MyComponent(RequiredSetup):
  def __init__(self, config):
    super().__init__()
    self.config = config

  def _setup(self):
    # Do expensive initialization here
    return self.config.get("resource")

# Usage
component = MyComponent({"resource": "db-handle"})  # any dict-like config works here
component.setup()
result = component.setup_result  # -> "db-handle"

Notes

  • Don't access setup_result before calling setup(); it raises a RuntimeError (see the sketch below).
  • Keep constructors light; do real work in _setup().
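
A minimal sketch of the contract in action, reusing the MyComponent class from above:

component = MyComponent({"resource": "db-handle"})

assert not component.is_ready       # nothing has run yet
# component.setup_result            # accessing it here would raise RuntimeError

component.setup()                   # runs _setup() and caches whatever it returns
assert component.is_ready
assert component.setup_result == "db-handle"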

See implementation: src/omnifed/utils/setup_mixin.py