Scaling laws - choderalab/modelforge GitHub Wiki
Introduction
Understanding how inference in neural network potentials (NNPs) scales is essential for designing efficient architectures, particularly as system size and model complexity increase. Computational cost and memory requirements typically scale with the number of atoms $N$ in the system. These scaling factors can become critical bottlenecks for large-scale molecular dynamics simulations on GPUs.
[!NOTE] For the sake of simplicity, we assume 64 bits for both integers and floating-point numbers. In practice we use 32 bits for floats and integers, except for integers used as indices, for which PyTorch requires 64 bits.
(GPU) Memory consumption
Because GPU memory is far more limited than CPU memory, optimizing memory consumption during inference is crucial. GPU memory is often the limiting factor for large systems and complex models. Here, we break down the memory consumption for inference in NNPs into three main categories: model memory, force calculation memory, and system-dependent memory.
Model memory consumption
The model itself must be transferred to GPU memory before inference. For a typical NNP with 2 million float64 parameters, the model's memory footprint is straightforward to calculate. Each float64 parameter requires 8 bytes, resulting in: $2 \times 10^6 \times 8~\text{bytes} = 16~\text{MB}$
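As a rough check of this arithmetic, the sketch below computes the parameter footprint for a given parameter count and precision. The function name `parameter_memory_bytes` is illustrative and not part of modelforge, and the estimate covers only the parameters themselves, not buffers or activations.

```python
def parameter_memory_bytes(num_parameters: int, bytes_per_parameter: int = 8) -> int:
    """Return the memory needed to store the parameters alone."""
    return num_parameters * bytes_per_parameter


# 2 million parameters stored as float64 (8 bytes) and float32 (4 bytes)
float64_mb = parameter_memory_bytes(2_000_000, 8) / 1e6
float32_mb = parameter_memory_bytes(2_000_000, 4) / 1e6
print(f"float64: {float64_mb:.0f} MB, float32: {float32_mb:.0f} MB")
# prints "float64: 16 MB, float32: 8 MB"
```

Switching from float64 to float32 halves this baseline.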
This is the baseline memory allocation purely for the model's parameters. During backpropagation (when derivatives are computed for forces), memory requirements can increase substantially because intermediate activations and gradients must be stored; how much depends on the model architecture.
Force calculation memory consumption
When forces are calculated (i.e., backpropagation with respect to atomic coordinates), memory consumption increases because the derivatives of the energy with respect to the coordinates must be computed and stored. This involves storing gradients for each of the $3N$ atomic coordinates. For a system with $N$ atoms this requires $3 \times N \times 8~\text{bytes}$.
This accounts only for the gradients with respect to the atomic positions. Total memory consumption is considerably higher because it also scales with the number of layers ($L$) and the number of hidden units ($H$) per layer. The activations of each layer must be stored during the forward pass because they are reused in the backward pass to compute gradients. The memory required for these activations scales with $L$, $H$, and the number of interacting atom pairs $N \times M$, where $M$ is the number of interaction partners per atom:
$$ O(L \times N \times M \times H) $$
In practice, the memory consumption depends on the architecture of the model. In general, however, this shows that memory consumption can grow significantly as the depth of the model and the number of interactions increase. Automatic mixed precision (lower-precision arithmetic) and gradient checkpointing are viable strategies to mitigate some of these issues.
As an example, for a fully connected graph ($N \times M = N^2$) that passes through 4 linear layers, each with output dimension 128, the total memory scales with $N^2 \times 4 \times 128$.
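To make the $O(L \times N \times M \times H)$ estimate concrete, the following sketch evaluates it for a fully connected graph and for a cutoff-limited neighborhood. It assumes exactly one stored activation of width $H$ per layer and per atom pair, which real architectures will only approximate; the function name and the choices of 1,000 atoms and 20 neighbors per atom are illustrative.

```python
def activation_memory_bytes(
    num_layers: int,
    num_atoms: int,
    neighbors_per_atom: int,
    hidden_units: int,
    bytes_per_value: int = 8,
) -> int:
    """Estimate activation storage as L * N * M * H values."""
    num_pairs = num_atoms * neighbors_per_atom
    return num_layers * num_pairs * hidden_units * bytes_per_value


num_atoms = 1_000
# Fully connected graph: every atom interacts with every atom (N x M = N^2)
fully_connected = activation_memory_bytes(4, num_atoms, num_atoms, 128)
# Cutoff-limited graph: assume 20 interaction partners per atom
with_cutoff = activation_memory_bytes(4, num_atoms, 20, 128)

print(f"fully connected: {fully_connected / 1e6:.0f} MB")
print(f"with cutoff:     {with_cutoff / 1e6:.2f} MB")
```

Under these assumptions the fully connected case needs about 4.1 GB while the cutoff-limited case needs about 82 MB, which is why the number of interacting pairs dominates the estimate.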
System-dependent memory consumption
Neighborlist calculation
Memory consumption also scales with the number of atoms $N$ through the construction of the neighbor list, which identifies the nearby atoms that contribute to local interactions. Current efficient neighborlist implementations involve an $O(N^2)$ step in which each atom is compared with every other atom to test whether the pair lies within a specified cutoff distance. For a system with $N$ atoms, the memory consumption of the neighbor list calculation scales as $N^2 \times (3 \times 3 \times 8~\text{bytes} + 1 \times 8~\text{bytes} + 2 \times 8~\text{bytes}) = N^2 \times 96~\text{bytes}$.
In this expression, the terms correspond to storing 3D positions (9 floats), atom pair distances (1 float), and atom pair indices (2 ints). For $3 \times 10^3$ atoms, this leads to:
$$ (3 \times 10^3)^2 \times 96~\text{bytes} = 864~\text{MB} $$
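The per-pair byte count can be reproduced with a small helper. This is a sketch of the estimate only, not of the modelforge neighborlist; the function name is hypothetical, and 8 bytes per value follows the 64-bit assumption made at the top of this page.

```python
BYTES_PER_VALUE = 8  # 64-bit floats and integers, as assumed on this page


def brute_force_pair_memory_bytes(num_atoms: int) -> int:
    """Estimate memory for the all-pairs step of a neighbor list."""
    position_bytes = 3 * 3 * BYTES_PER_VALUE  # 3D positions (9 floats)
    distance_bytes = 1 * BYTES_PER_VALUE      # pair distance (1 float)
    index_bytes = 2 * BYTES_PER_VALUE         # pair indices (2 ints)
    bytes_per_pair = position_bytes + distance_bytes + index_bytes  # 96 bytes
    return num_atoms**2 * bytes_per_pair


for num_atoms in (1_000, 3_000, 10_000):
    megabytes = brute_force_pair_memory_bytes(num_atoms) / 1e6
    print(f"{num_atoms:>6} atoms: {megabytes:,.0f} MB")
```

Because the estimate is quadratic in $N$, going from 3,000 to 10,000 atoms raises it from 864 MB to 9,600 MB.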
Featurization of atom pair interaction
The neighborlist identifies the $N \times M$ atom pairs that contribute to local interactions. Each of these pairs is typically featurized, for example with radial symmetry functions. With 32 radial symmetry functions this requires $N \times M \times 32 \times 8~\text{bytes}$. For $N = 1 \times 10^3$ and $M = 20$ the memory consumption is about 5 MB.
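The featurization estimate can be evaluated the same way. The sketch below is illustrative (the function name is not part of modelforge) and compares 64-bit storage with the 32-bit floats used in practice.

```python
def featurization_memory_bytes(
    num_atoms: int,
    neighbors_per_atom: int,
    num_radial_functions: int,
    bytes_per_value: int = 8,
) -> int:
    """Estimate memory for per-pair radial features: N * M * features * bytes."""
    num_pairs = num_atoms * neighbors_per_atom
    return num_pairs * num_radial_functions * bytes_per_value


float64_bytes = featurization_memory_bytes(1_000, 20, 32, bytes_per_value=8)
float32_bytes = featurization_memory_bytes(1_000, 20, 32, bytes_per_value=4)
print(f"float64: {float64_bytes / 1e6:.2f} MB")  # prints "float64: 5.12 MB"
print(f"float32: {float32_bytes / 1e6:.2f} MB")  # prints "float32: 2.56 MB"
```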
Experiment to evaluate memory footprint
The following plots were generated using a realistic set of hyperparameters and torch.float32.
Timings
Only forward pass:
Forward and backward pass:
import time  # Wall-clock timing of the forward and backward pass
from typing import List

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from openmm import unit  # `simtk.unit` is the deprecated pre-7.6 OpenMM namespace
from openmmtools.testsystems import WaterBox
def measure_performance_for_edge_sizes(
    edge_sizes: List[float],
    potential_names: List[str],
):
    """
    Measures GPU memory utilization and computation time for force calculations
    for water boxes of different edge sizes across multiple potentials.

    Parameters
    ----------
    edge_sizes : List[float]
        A list of edge sizes (in nanometers) for the water boxes.
    potential_names : List[str]
        A list of potential names to use in the model setup.

    Returns
    -------
    List[dict]
        A list of dictionaries containing edge size, number of water molecules,
        potential name, memory usage in bytes, and computation time in seconds.
    """
    # modelforge imports. The location of NNPInput is an assumption and may
    # differ between modelforge versions; adjust it to your installation.
    from modelforge.dataset.dataset import NNPInput
    from modelforge.tests.helper_functions import setup_potential_for_test

    results = []
    # Note: the torch.cuda memory statistics used below require a CUDA device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    precision = torch.float32

    for potential_name in potential_names:
        for edge_size in edge_sizes:
            # Generate water box with the given edge size
            test_system = WaterBox(box_edge=edge_size * unit.nanometer)
            positions = test_system.positions  # Positions in nanometers
            topology = test_system.topology

            # Extract atomic numbers and residue indices
            atomic_numbers = []
            residue_indices = []
            for residue_index, residue in enumerate(topology.residues()):
                for atom in residue.atoms():
                    atomic_numbers.append(atom.element.atomic_number)
                    residue_indices.append(residue_index)
            num_waters = len(list(topology.residues()))
            positions_in_nanometers = positions.value_in_unit(unit.nanometer)

            # Convert to torch tensors and move to GPU
            torch_atomic_numbers = torch.tensor(atomic_numbers, dtype=torch.long, device=device)
            torch_positions = torch.tensor(positions_in_nanometers, dtype=torch.float32, device=device, requires_grad=True)
            torch_atomic_subsystem_indices = torch.zeros_like(torch_atomic_numbers, dtype=torch.long, device=device)
            torch_total_charge = torch.zeros(num_waters, dtype=torch.float32, device=device)
            nnp_input = NNPInput(
                atomic_numbers=torch_atomic_numbers,
                positions=torch_positions,
                atomic_subsystem_indices=torch_atomic_subsystem_indices,
                total_charge=torch_total_charge,
            ).to(dtype=precision)

            # Setup model
            model = setup_potential_for_test(
                potential_name,
                "inference",
                potential_seed=42,
                use_training_mode_neighborlist=False,
                simulation_environment='PyTorch',
            )
            model.to(device)
            model.to(precision)
            total_params = sum(p.numel() for p in model.parameters())

            # Measure GPU memory usage and computation time
            torch.cuda.reset_peak_memory_stats(device=device)
            torch.cuda.synchronize()

            # Run forward and backward pass and time both together
            start_time = time.perf_counter()
            try:
                output = model(nnp_input.as_namedtuple())["per_molecule_energy"]
            except torch.cuda.OutOfMemoryError:
                # Catch only out-of-memory errors so other failures stay visible
                print("Out of memory error during forward pass")
                continue
            try:
                F_training = -torch.autograd.grad(
                    output.sum(), nnp_input.positions, create_graph=True, retain_graph=True
                )[0]
            except torch.cuda.OutOfMemoryError:
                print("Out of memory error during backward pass")
                continue
            torch.cuda.synchronize()
            end_time = time.perf_counter()

            max_memory_allocated = torch.cuda.max_memory_allocated(device=device)
            computation_time = end_time - start_time
            results.append({
                'potential_name': f"{potential_name}: {total_params:.1e} params",
                'edge_size_nm': edge_size,
                'num_waters': num_waters,
                'memory_usage_bytes': max_memory_allocated,
                'computation_time_s': computation_time
            })

            # Clean up before the next system is built
            del nnp_input, output, model, F_training
            torch.cuda.empty_cache()

    return results
def plot_computation_time(results):
    """
    Plots computation time against the number of water molecules for multiple potentials.

    Parameters
    ----------
    results : List[dict]
        A list of dictionaries containing edge size, number of water molecules,
        potential name, memory usage in bytes, and computation time in seconds.
    """
    # Create a DataFrame for plotting
    df = pd.DataFrame(results)
    df['computation_time_ms'] = df['computation_time_s'] * 1000  # Convert seconds to milliseconds

    # Plot using seaborn
    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 6))
    sns.lineplot(
        data=df,
        x='num_waters',
        y='computation_time_ms',
        hue='potential_name',
        units='potential_name',
        estimator=None,  # Do not aggregate data
        marker='o',
        linewidth=2,
        markersize=8
    )
    plt.title('Computation Time vs Number of Water Molecules for Different Potentials')
    plt.xlabel('Number of Water Molecules')
    plt.ylabel('Computation Time (ms)')
    plt.xticks(sorted(df['num_waters'].unique()))
    plt.legend(title='Potential Name')
    plt.tight_layout()
    plt.show()
# Example usage:
edge_sizes = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5] # Edge sizes in nanometers
potential_names = ['schnet', 'painn', 'physnet', 'ani2x', 'aimnet2', 'sake']
results = measure_performance_for_edge_sizes(
    edge_sizes=edge_sizes,
    potential_names=potential_names,
)
def plot_gpu_memory_usage(results):
    """
    Plots GPU memory usage against the number of water molecules for multiple potentials.

    Parameters
    ----------
    results : List[dict]
        A list of dictionaries containing edge size, number of water molecules,
        potential name, and memory usage in bytes.
    """
    # Create a DataFrame for plotting
    df = pd.DataFrame(results)
    df['memory_usage_mb'] = df['memory_usage_bytes'] / 1e6  # Convert bytes to megabytes

    # Plot using seaborn
    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 6))
    sns.lineplot(
        data=df,
        x='num_waters',
        y='memory_usage_mb',
        units='potential_name',
        estimator=None,  # Do not aggregate data
        hue='potential_name',
        marker='o',
        linewidth=2,
        markersize=8,
    )
    plt.title('Backward pass: GPU Memory Usage vs Number of Water Molecules for Different Potentials')
    plt.xlabel('Number of Water Molecules')
    plt.ylabel('GPU Memory Usage (MB)')
    plt.xticks(sorted(df['num_waters'].unique()))
    plt.legend(title='Potential Name')
    plt.tight_layout()
    plt.show()
# Print the results
for result in results:
    print(f"Potential: {result['potential_name']}, "
          f"Edge Size: {result['edge_size_nm']} nm, "
          f"Number of Waters: {result['num_waters']}, "
          f"Memory Usage: {result['memory_usage_bytes']/1e6:.2f} MB, "
          f"Computation Time: {result['computation_time_s']*1000:.2f} ms")
# Plot the computation time
plot_computation_time(results)
plot_gpu_memory_usage(results)
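One way to turn such measurements into an empirical scaling law is to fit the slope of log(memory) against log(system size). The sketch below is not part of the original script: `fit_scaling_exponent` is a hypothetical helper that expects the dictionary keys produced by `measure_performance_for_edge_sizes`, and the data in the example are synthetic values generated from an exact quadratic, not measurements.

```python
import math
from typing import Dict, List


def fit_scaling_exponent(
    results: List[Dict],
    x_key: str = "num_waters",
    y_key: str = "memory_usage_bytes",
) -> float:
    """Least-squares slope of log(y) versus log(x), i.e. the exponent k in y ~ x**k."""
    log_x = [math.log(entry[x_key]) for entry in results]
    log_y = [math.log(entry[y_key]) for entry in results]
    mean_x = sum(log_x) / len(log_x)
    mean_y = sum(log_y) / len(log_y)
    covariance = sum((x - mean_x) * (y - mean_y) for x, y in zip(log_x, log_y))
    variance = sum((x - mean_x) ** 2 for x in log_x)
    return covariance / variance


# Synthetic example only: values follow y = 100 * x**2 exactly
synthetic_results = [
    {"num_waters": n, "memory_usage_bytes": 100 * n**2} for n in (50, 100, 200, 400)
]
exponent = fit_scaling_exponent(synthetic_results)
print(f"fitted exponent: {exponent:.2f}")  # prints "fitted exponent: 2.00"
```

With real measurements, the fit should be applied to one potential at a time by filtering `results` on `potential_name`, since each potential has its own prefactor and possibly its own exponent.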