
CS 7545: Machine Learning Theory -- Spring 2024

Instructor: Guanghui Wang

Notes for Lecture 20: Federated Learning

April 2, 2024

Scribes: Yash Arora, Amit Singh

Goal: Federated Learning: Privacy-preserving model training in heterogeneous and distributed networks. Distributed optimization + large-scale ML, privacy sensitive data.

The first section of these lecture notes is derived from the slides presented during the first half of class. The presentation is linked in the Canvas files.

The server sends the current model to each local device and receives local updates in return. The local data itself must stay on the device because it is private.

Application of Federated Learning: Federated learning can be used in finance for money laundering detection and financial fraud detection. Each bank's data is sensitive and cannot be shared across banks, but a central server can receive local updates from each bank. Federated learning therefore allows a collaborative effort to prevent financial crime.

Question: How are Clients selected virtually?

Answer: The basic method is to select clients at random. Weighted selection is also possible, with the weights chosen either heuristically or in a more algorithmic way (e.g., based on the loss of each local device).
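The two selection strategies in the answer above can be sketched as follows. This is a minimal illustration, not the method from the slides: the function name `select_clients`, the `losses` dictionary, and the choice to sample proportionally to loss are all assumptions made for the example, and losses are assumed to be strictly positive.

```python
import random

def select_clients(client_ids, num_selected, losses=None, seed=None):
    """Pick clients for one round (hypothetical helper).

    losses=None: sample uniformly at random without replacement.
    Otherwise: sample without replacement, with probability
    proportional to each client's reported local loss.
    """
    rng = random.Random(seed)
    if losses is None:
        return rng.sample(list(client_ids), num_selected)
    pool = list(client_ids)
    weights = [losses[c] for c in pool]
    chosen = []
    for _ in range(num_selected):
        # Draw one client with probability proportional to its weight,
        # then remove it so it cannot be picked twice.
        idx = rng.choices(range(len(pool)), weights=weights, k=1)[0]
        chosen.append(pool.pop(idx))
        weights.pop(idx)
    return chosen
```

For example, `select_clients(["a", "b", "c", "d"], 2, seed=0)` gives a uniform draw, while passing `losses={"a": 0.9, "b": 0.1, "c": 0.5, "d": 0.2}` favors clients whose current loss is high.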

The linked slideshow gives an introduction to federated learning. We cover the FedAvg algorithm, the first and simplest federated learning algorithm, which was introduced by Google in a research paper.

Now let's speak on the challenges of Federated Learning:

  • Expensive Communication: communication is much more costly than in the classical setting
  • Statistical Heterogeneity: Non-i.i.d. data; the amount of data can also vary a lot across devices
  • System Heterogeneity: Differences between devices such as iPhone, iPad, and laptop result in different network connectivity, etc.
  • Privacy Concerns: Client data can be partially reconstructed from model gradients, which creates privacy concerns

Communication Efficiency

  1. Multiple local updates: local SGD (the FedAvg algorithm). This technique reduces the number of communication rounds.

  2. Compression techniques: sparsification and quantization. These reduce the number of bits transmitted between the server and the clients.

The focus for this lecture is on 1.
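Although the lecture does not go into compression, the two ideas can be sketched briefly. The code below assumes a model update is a flat list of floats. The names `top_k_sparsify` and `uniform_quantize`, and the specific schemes (keeping the largest-magnitude entries, rounding to evenly spaced levels), are illustrative choices and not necessarily the ones on the slides.

```python
def top_k_sparsify(update, k):
    """Sparsification: keep the k largest-magnitude entries, zero the rest."""
    if k <= 0:
        return [0.0] * len(update)
    order = sorted(range(len(update)), key=lambda i: abs(update[i]), reverse=True)
    keep = set(order[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(update)]

def uniform_quantize(update, num_levels):
    """Quantization: round each entry to the nearest of num_levels
    evenly spaced values between min(update) and max(update)."""
    if num_levels < 2:
        raise ValueError("num_levels must be at least 2")
    lo, hi = min(update), max(update)
    if hi == lo:
        return list(update)
    step = (hi - lo) / (num_levels - 1)
    return [lo + round((v - lo) / step) * step for v in update]
```

After sparsification only the nonzero entries and their positions need to be sent. After quantization each entry can be encoded as a level index using about $\log_2$(`num_levels`) bits instead of a full float.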

Minibatch SGD vs Local SGD

Minibatch SGD: each device performs 1 local update per communication round. Local SGD: each device performs $H$ local updates per communication round.

Assume there are $K$ devices. Let $N_{c, m}$ = the number of communication rounds for minibatch SGD.

Assume the objective is strongly convex and smooth. Then SGD converges at rate $O(1/T)$ with 1 local device and at rate $O(1/(KT))$ with $K$ local devices. This is known as linear speedup. For additional context, the whole purpose of local SGD is to reduce the number of communication rounds while achieving the same convergence rate.
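To make the linear speedup concrete, here is a small worked example. The constant $C$ and the target accuracy $\epsilon$ are introduced only for illustration and are not from the lecture.

$$
\frac{C}{T} \leq \epsilon \iff T \geq \frac{C}{\epsilon}, \qquad \frac{C}{KT} \leq \epsilon \iff T \geq \frac{C}{K\epsilon}
$$

So, under these assumed rates, $K = 10$ devices each need roughly one tenth of the iterations that a single device would need to reach the same accuracy.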

Local SGD: $N_{c, l}$ = the number of communication rounds for local SGD. Let $T$ = the total number of iterations.

Then, $N_{c, l} = T/H$.

Local SGD: $H \cdot N_{c, l} = T$. Minibatch SGD: $1 \cdot N_{c, m} = T$, so $N_{c, m} = T$.

The goal is to reduce the number of communication rounds needed to achieve the desired convergence rate. Successive results improved the bound over time: $N_{c, l} = O(T^{1/2}K^{1/2}) \rightarrow \tilde{O}(K) \rightarrow \tilde{O}(K^{1/2})$ or $\tilde{O}(K^{1/3})$
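A quick numerical check of the round counts. This sketch assumes $H$ is chosen on the order of $\sqrt{T/K}$, which is the choice that makes $T/H$ match the $T^{1/2}K^{1/2}$ figure; constants are ignored, and the function names and the values of `T` and `K` are made up for the example.

```python
import math

def rounds_minibatch(T):
    # Minibatch SGD communicates after every step.
    return T

def rounds_local(T, H):
    # Local SGD communicates once every H local steps.
    return math.ceil(T / H)

T, K = 10_000, 16
H = max(1, math.isqrt(T // K))  # H on the order of sqrt(T / K); here H = 25

print(rounds_minibatch(T))  # 10000
print(rounds_local(T, H))   # 400, which equals sqrt(T * K) for these values
```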

Question: What is $\tilde{O}$?

Answer: $\tilde{O}$ is big-O notation that hides polylogarithmic factors. Here we use it to drop the logarithmic dependence on $T$.

Question: Do we have to load models onto local clients?

Answer: Each local device receives the server model, updates it locally, and then sends its local update back. If the server model is too big for local clients, compression could possibly be used to reduce the model size.

The paper the remainder of the lecture is derived from is "Local SGD converges fast and communicates little" (Sebastian Stich, 2019).

Problem Setup (convex)

  1. General Form: $\min_{x \in \mathbb{R}^{d}} f(x) \triangleq \frac{1}{K} \sum\limits_{i=1}^{K} f_{i}(x)$, where $f_{i}(x) = E_{z \sim D_i}[l_{i}(x, z)]$

Homogeneous Setting: $D_{1} = \cdots = D_{K}$, $l_{1} = \cdots = l_{K}$

  1. Assume IID. $\min_{x \in \mathbb{R}^{d}} f(x) = E_{z \sim D} [l(x, z)]$.

  2. Full Participation

Local SGD Algorithm

$K$: Number of Devices, $T$ = total number of steps

$L_{T} \subseteq [T]$, $T \in L_{T}$, where $L_{T}$ is the set of communication rounds.

$x_{t+1}^{k} := x_{t}^{k} - \eta_{t} \nabla f_{i_{t}^{k}}(x_{t}^{k})$ if $t+1 \notin L_{T}$

$x_{t+1}^{k} := \frac{1}{K} \sum_{j=1}^{K}(x_{t}^{j} - \eta_{t} \nabla f_{i_{t}^{j}} (x_{t}^{j}))$ if $t+1 \in L_{T}$

Here $i_{t}^{k}$ is the index of the sample drawn by device $k$ at step $t$.
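The update rule can be sketched in a few lines of Python. This is a toy illustration under several assumptions that are not in the lecture: the problem is one-dimensional, the loss is $l(x, z) = \frac{1}{2}(x - z)^2$ (so the stochastic gradient is $x - z$), all devices sample from the same `data` list (homogeneous setting), and the step size `eta` is constant rather than $\eta_t$.

```python
import random

def local_sgd(data, K, T, sync_steps, eta, x0=0.0, seed=0):
    """Local SGD on f(x) = mean over z in data of 0.5 * (x - z)**2."""
    rng = random.Random(seed)
    x = [x0] * K  # one iterate per device
    for t in range(T):
        # Each device takes one stochastic gradient step on a random sample.
        stepped = [x[k] - eta * (x[k] - rng.choice(data)) for k in range(K)]
        if (t + 1) in sync_steps:
            # Communication round: average the stepped iterates and broadcast.
            avg = sum(stepped) / K
            x = [avg] * K
        else:
            x = stepped
    return x

T, H = 200, 10
sync_steps = set(range(H, T + 1, H))  # contains T, as the setup requires
final = local_sgd(data=[1.0, 2.0, 3.0], K=4, T=T, sync_steps=sync_steps, eta=0.05)
```

Because `T` is in `sync_steps`, all entries of `final` are identical. Setting `H = 1` recovers one averaging step per iteration, which corresponds to the minibatch SGD communication pattern.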

For a set $P := \lbrace p_{0}, \cdots, p_{t} \rbrace$ of $t+1$ integers with $p_{i} \leq p_{i+1}$, define

$gap(P) := \max_{1 \leq i \leq t} (p_{i} - p_{i-1})$

If local SGD communicates at least once every $H$ steps, then $gap(L_{T}) \leq H$.
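As a small illustration, the gap is the largest difference between consecutive elements of the sorted set. The function below is a sketch with a hypothetical name and assumes at least two points.

```python
def gap(points):
    """Largest difference between consecutive elements of a set of integers."""
    p = sorted(points)
    return max(later - earlier for earlier, later in zip(p, p[1:]))

print(gap({4, 8, 12, 16}))  # 4: communicating every 4 steps
print(gap({0, 5, 10, 12}))  # 5
```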

Assumptions

  • (a) $f$ is $\mu$-strongly convex: $f(x) \geq f(y) + \langle \nabla f(y), x-y \rangle + \frac{1}{2} \mu ||x-y||^{2}$

  • (b) $f$ is $L$-smooth: $f(x) \leq f(y) + \langle \nabla f(y), x-y \rangle + \frac{1}{2} L ||x-y||^{2}$. Together with convexity, this implies $f(x) \geq f(y) + \langle x-y, \nabla f(y) \rangle + \frac{1}{2L} ||\nabla f(x) - \nabla f(y)||^{2}$

  • (c) The variance of stochastic gradients is uniformly bounded. $E_{i} || \nabla f_{i} (x_{t}^{k}) - \nabla f(x_{t}^{k})||^{2} \leq \sigma^{2}$

  • (d) The expected squared norm of stochastic gradients is uniformly bounded. $E_{i} || \nabla f_{i} (x_{t}^{k})||^{2} \leq G^{2}$
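Assumptions (a) and (b) can be checked numerically on a simple example. The quadratic below is an assumption chosen for illustration, not a function from the lecture: $f(x) = \frac{1}{2}(x_1^2 + 4x_2^2)$, for which $\mu = 1$ and $L = 4$ (the smallest and largest curvature).

```python
import random

A = (1.0, 4.0)  # f(x) = 0.5 * (1 * x1^2 + 4 * x2^2), so mu = 1 and L = 4

def f(x):
    return 0.5 * sum(a * xi * xi for a, xi in zip(A, x))

def grad(x):
    return [a * xi for a, xi in zip(A, x)]

def bounds_hold(x, y, mu=1.0, L=4.0, tol=1e-9):
    """Check the strong-convexity lower bound and the smoothness upper bound."""
    g = grad(y)
    linear = f(y) + sum(gi * (xi - yi) for gi, xi, yi in zip(g, x, y))
    dist2 = sum((xi - yi) ** 2 for xi, yi in zip(x, y))
    lower = linear + 0.5 * mu * dist2
    upper = linear + 0.5 * L * dist2
    return lower - tol <= f(x) <= upper + tol

rng = random.Random(0)
pairs = [([rng.uniform(-10, 10) for _ in range(2)],
          [rng.uniform(-10, 10) for _ in range(2)]) for _ in range(100)]
print(all(bounds_hold(x, y) for x, y in pairs))  # True
```

Passing a larger value such as `mu=2.0` makes the lower bound fail for some pairs, which shows that $\mu$ cannot exceed the smallest curvature.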

Question: Is this bound saying we have an unbiased value?

Answer: Essentially, yes.

Convergence Analysis: Virtual Sequence: Average of models at time t

Perturbed iterate analysis (Mania et al., 2017)

We define a virtual sequence $\lbrace\overline{x_t}\rbrace_{t \geq 0}$

$\overline{x_0} = x_0$, $\overline{x_t} = \frac{1}{K} \sum\limits_{k=1}^K x^k_t$. The goal is to show that $f(\overline{x_t}) \rightarrow f^*$.

$g_t := \frac{1}{K} \sum\limits_{k=1}^K \nabla f_{i_t^k} (x^k_t)$, $\overline{g_t} := \frac{1}{K} \sum\limits_{k=1}^K \nabla f (x^k_t)$

$\overline{x_{t+1}} = \bar{x_{t}} - \eta_{t}g_{t}$ and $E[g_{t}] = \bar{g_{t}}$
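A short derivation of why the virtual sequence follows this recursion whether or not step $t+1$ is a communication round. Here $g_t$ denotes the average of the stochastic gradients and $\overline{g_t}$ the average of the full gradients, and the sketch assumes each sampled gradient is unbiased, i.e. $E_{i}[\nabla f_{i}(x)] = \nabla f(x)$.

$$
\overline{x_{t+1}} = \frac{1}{K} \sum_{k=1}^{K} x_{t+1}^{k} = \frac{1}{K} \sum_{k=1}^{K} \left( x_{t}^{k} - \eta_{t} \nabla f_{i_{t}^{k}}(x_{t}^{k}) \right) = \overline{x_{t}} - \eta_{t} g_{t}
$$

If $t+1 \notin L_{T}$, the second equality is just the local update applied on each device. If $t+1 \in L_{T}$, every $x_{t+1}^{k}$ is already equal to that same average, so averaging again changes nothing. Taking the expectation over the sampled indices term by term then gives $E[g_{t}] = \overline{g_{t}}$.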
