CS7545_Sp23_Lecture_27: Generalization Bounds are Vacuous in Deep Learning
A few reasons why it's important for us to study generalization bounds & complexity measures:
- Provide insight into & prediction of model behavior (explain surprising phenomena like double descent, weaker performance when using distributed training, etc.)
- Inspire new algorithms and regularization techniques
- Provable guarantees on our algorithms before training or testing on real-life data (for instance, models for airplanes/nuclear reactors)
- Use a complexity measure as a regularizer (some of the oldest ML models directly minimize quantities like the Rademacher generalization bound; if we can find the right complexity measure for NNs, then theoretically speaking, we should get a better set of optimizers)
- Encourage a rigorous mindset in industry
Let $W$ denote the number of parameters of the network, $m$ the number of training samples, and $D$ the depth. Recall the two generalization bounds we have seen: the VC-dimension bound and the Rademacher complexity bound.
However, both are insufficient for modern Neural Networks, for the following reasons:
- For an overparametrized network (i.e., $W \gg m$), the VC-dim complexity term $\sqrt{\frac{VC}{m}} \gg 1$ (see the back-of-the-envelope check below)
- Modern NNs can be very deep, and the Rademacher complexity bound is exponential in $D$ $(\gg 1)$
- Gradient descent tends to increase $||w_i||$, so $\prod ||w_i||_F$ could be large. We could use early stopping $\rightarrow$ terminate model training before SGD convergence.
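Concretely, here is the back-of-the-envelope check for the first point; the parameter and sample counts below are illustrative orders of magnitude assumed for this sketch, not figures from the lecture:

```python
import math

# Assumed, illustrative orders of magnitude: an Inception-style network on CIFAR-10.
W = 25_000_000   # number of trainable parameters
m = 50_000       # number of training samples

# Crudely treat the VC dimension as the parameter count; for ReLU networks it is
# actually on the order of W * D * log(W), so this already understates the bound.
vc_dim = W

gap_bound = math.sqrt(vc_dim / m)
print(f"sqrt(VC/m) = {gap_bound:.1f}")   # about 22.4 -- far larger than 1, i.e., vacuous
```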
Emphasis: these bounds are not wrong, they are just too loose for practical models.
Example: the Support Vector Machine (SVM), which tries to maximize the margin, i.e., to minimize $||w||_2$ subject to the margin constraints.
We showed that the Rademacher Complexity for the norm-bounded linear class $\{x \mapsto \langle w, x \rangle : ||w||_2 \le B\}$
is at most $\frac{B \max_i ||x_i||_2}{\sqrt{m}}$, so minimizing the norm directly shrinks the generalization bound.
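As a reminder, that bound follows from the standard one-line argument (Cauchy-Schwarz plus Jensen's inequality); the constants may differ slightly from the version derived in lecture:

$$
\hat{\mathcal{R}}_m(\mathcal{F}_B)
= \frac{1}{m}\,\mathbb{E}_\sigma\Big[\sup_{\|w\|_2 \le B}\Big\langle w, \sum_{i=1}^m \sigma_i x_i\Big\rangle\Big]
= \frac{B}{m}\,\mathbb{E}_\sigma\Big\|\sum_{i=1}^m \sigma_i x_i\Big\|_2
\le \frac{B}{m}\sqrt{\sum_{i=1}^m \|x_i\|_2^2}
\le \frac{B \max_i \|x_i\|_2}{\sqrt{m}}.
$$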
However, it seems that overparameterized Neural Networks may not need regularization to perform well.
Example: although adding regularization can improve performance, NNs still get acceptable results without regularization (Zhang et al., 2017); see the training sketch after the numbers below.
- CIFAR-10 dataset, Inception: with weight decay (equivalent to an $\ell_2$ regularizer on the weights) 86.03% accuracy, without weight decay 85.75% accuracy
- ImageNet dataset, Inception: with weight decay 67.18% accuracy, without weight decay 59.80% accuracy
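A minimal sketch of how such a with/without weight-decay comparison can be run, assuming PyTorch and torchvision are available; the small CNN, hyperparameters, and training loop below are illustrative placeholders, not the Inception setup used by Zhang et al. (2017):

```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

def make_loader():
    # CIFAR-10 training set with basic tensor conversion.
    train = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=T.ToTensor())
    return DataLoader(train, batch_size=128, shuffle=True)

def make_model():
    # Placeholder small CNN (not Inception).
    return nn.Sequential(
        nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Flatten(), nn.Linear(64 * 8 * 8, 10))

def train(weight_decay, epochs=10):
    loader, model = make_loader(), make_model()
    # weight_decay > 0 adds the L2 penalty on the weights; 0 removes it.
    opt = torch.optim.SGD(model.parameters(), lr=0.01,
                          momentum=0.9, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
    return model

model_with_wd = train(weight_decay=5e-4)    # regularized run
model_without_wd = train(weight_decay=0.0)  # unregularized run
```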
Neural Network performance depends heavily on optimization.
- SVM (obtains the same solution with any of the following methods)
- Solve QP exactly
- GD algorithm
- Coordinate descent
- NNs trained using gradient descent
- No unique solution
- The way you train actually matters
- GD finds a low-complexity solution all by itself (implicit regularization)
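One concrete, well-known instance of such implicit regularization (a toy sketch for illustration, not taken from the lecture): on an underdetermined least-squares problem, gradient descent started from zero converges to the minimum-$\ell_2$-norm interpolating solution, even though nothing in the loss asks for a small norm.

```python
import numpy as np

rng = np.random.default_rng(0)
m, d = 20, 100                       # fewer samples than parameters (overparametrized)
X = rng.normal(size=(m, d))
y = rng.normal(size=m)

# Plain gradient descent on the squared loss, initialized at w = 0.
w = np.zeros(d)
lr = 0.01
for _ in range(50_000):
    w -= lr * X.T @ (X @ w - y) / m

# Minimum-L2-norm interpolating solution, computed via the pseudoinverse.
w_min_norm = np.linalg.pinv(X) @ y

print("training residual:        ", np.linalg.norm(X @ w - y))       # ~0 (interpolates)
print("distance to min-norm sol.:", np.linalg.norm(w - w_min_norm))  # ~0 (implicit bias)
```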
Classical Bias-Variance Tradeoff Curve
Double Descent Risk Curve (Belkin et al., 2019)
Note: this naturally leaves us with a question for further investigation - from the "modern" perspective, "What should be on the x-axis?", i.e., what is an appropriate way to reason about the generalization of (deep) neural networks?