configuration_diffusion - dingdongdengdong/astra_ws GitHub Wiki
# configuration_diffusion.py: Code Flow Analysis and Diagram
## Mermaid Diagram

```mermaid
graph TD
    A["DiffusionConfig (dataclass, PreTrainedConfig)"]
    A -->|__post_init__| B[Input validation]
    A -->|get_optimizer_preset| C[Returns AdamConfig]
    A -->|get_scheduler_preset| D[Returns DiffuserSchedulerConfig]
    A -->|validate_features| E[Feature validation]
    A -->|observation_delta_indices| F[Property: list]
    A -->|action_delta_indices| G[Property: list]
    A -->|reward_delta_indices| H[Property: None]
    B --> B1[vision_backbone check]
    B --> B2[prediction_type check]
    B --> B3[noise_scheduler_type check]
    B --> B4[horizon/downsampling compatibility]
    E --> E1[image/env state presence]
    E --> E2[crop_shape fits images]
    E --> E3[all images same shape]
```
## Code Flow Explanation
### DiffusionConfig Class
- Inherits from `PreTrainedConfig` and is registered as a "diffusion" policy configuration.
- Contains all hyperparameters and architectural settings for a diffusion-based policy, including input/output shapes, normalization, vision backbone, U-Net parameters, noise scheduler, and training presets.
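To make the dataclass-plus-validation pattern concrete, here is a minimal sketch of how such a configuration class can be structured. The field names follow those described on this page, but the defaults and the error message are illustrative assumptions, not the file's exact code.

```python
from dataclasses import dataclass

@dataclass
class DiffusionConfigSketch:
    # A few of the fields described on this page; defaults are illustrative.
    n_obs_steps: int = 2
    horizon: int = 16
    n_action_steps: int = 8
    vision_backbone: str = "resnet18"
    noise_scheduler_type: str = "DDPM"
    prediction_type: str = "epsilon"

    def __post_init__(self):
        # Dataclasses call __post_init__ automatically after __init__,
        # so invalid configurations fail at construction time.
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(f"Unsupported vision backbone: {self.vision_backbone}")

cfg = DiffusionConfigSketch()
```

Because validation lives in `__post_init__`, an invalid setting such as `vision_backbone="vit_b_16"` raises immediately rather than surfacing later inside a model.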
### Key Methods and Properties
- `__post_init__`: Validates the configuration after initialization. Checks that:
  - the vision backbone is a ResNet variant,
  - the prediction type is supported (`"epsilon"` or `"sample"`),
  - the noise scheduler type is supported (`"DDPM"` or `"DDIM"`),
  - the horizon is compatible with the U-Net downsampling factor.
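The four checks above can be sketched as plain guard clauses. The constant names, function signature, and error messages below are assumptions for illustration; only the conditions themselves come from the description above.

```python
SUPPORTED_PREDICTION_TYPES = ("epsilon", "sample")
SUPPORTED_SCHEDULERS = ("DDPM", "DDIM")

def check_config(vision_backbone, prediction_type, noise_scheduler_type,
                 horizon, down_dims):
    if not vision_backbone.startswith("resnet"):
        raise ValueError("vision_backbone must be a ResNet variant")
    if prediction_type not in SUPPORTED_PREDICTION_TYPES:
        raise ValueError("prediction_type must be 'epsilon' or 'sample'")
    if noise_scheduler_type not in SUPPORTED_SCHEDULERS:
        raise ValueError("noise_scheduler_type must be 'DDPM' or 'DDIM'")
    # The U-Net halves the temporal resolution once per down_dims entry,
    # so the horizon must be divisible by 2 ** len(down_dims).
    if horizon % (2 ** len(down_dims)) != 0:
        raise ValueError("horizon is incompatible with U-Net downsampling")
```

For example, with three downsampling stages the factor is 8, so a horizon of 16 passes while a horizon of 15 fails.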
- `get_optimizer_preset`: Returns an `AdamConfig` object with optimizer settings.
- `get_scheduler_preset`: Returns a `DiffuserSchedulerConfig` object with scheduler settings.
- `validate_features`: Ensures required image/environment-state features are present, the crop shape fits within the images, and all images have the same shape.
- `observation_delta_indices`, `action_delta_indices`, `reward_delta_indices`: Properties for index calculations used in temporal data handling.
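A plausible sketch of the delta-index properties, written as free functions for clarity (the exact formulas are an assumption based on the behavior described above; `reward_delta_indices` simply returns `None`):

```python
def observation_delta_indices(n_obs_steps):
    # Offsets of the stacked observations relative to the current step,
    # e.g. n_obs_steps=2 -> [-1, 0].
    return list(range(1 - n_obs_steps, 1))

def action_delta_indices(n_obs_steps, horizon):
    # The action horizon starts alongside the earliest stacked observation,
    # e.g. n_obs_steps=2, horizon=16 -> [-1, 0, 1, ..., 14].
    return list(range(1 - n_obs_steps, 1 - n_obs_steps + horizon))
```

A dataset loader can use these offsets to gather the right temporal slices around each sampled frame.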
### Data Structures
- Uses Python dataclasses for field definitions and default values.
- Relies on dictionaries for input/output shapes and normalization modes.
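The shape and normalization dictionaries might look like the following. The keys and values here are illustrative assumptions, not the file's actual defaults:

```python
# Hypothetical per-feature shape and normalization dictionaries.
input_shapes = {
    "observation.image": [3, 96, 96],  # channels, height, width
    "observation.state": [2],
}
output_shapes = {
    "action": [2],
}
normalization_mapping = {
    "observation.image": "mean_std",
    "observation.state": "min_max",
    "action": "min_max",
}
```

Keeping shapes and normalization modes keyed by feature name lets validation iterate over all features uniformly.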
### Control Flow
- Initialization and validation are tightly coupled to ensure configuration correctness before use in downstream models.
- Properties provide index ranges for temporal stacking of observations and actions.
## Summary
This configuration file defines the structure and validation logic for all parameters required by a diffusion policy, ensuring that downstream models receive consistent and valid settings for training and inference.
## Additional Diagrams and Explanations
### Class Structure Diagram
```mermaid
classDiagram
    class PreTrainedConfig
    class DiffusionConfig {
        +int n_obs_steps
        +int horizon
        +int n_action_steps
        +dict normalization_mapping
        +str vision_backbone
        +tuple crop_shape
        +bool crop_is_random
        +str pretrained_backbone_weights
        +bool use_group_norm
        +int spatial_softmax_num_keypoints
        +bool use_separate_rgb_encoder_per_camera
        +tuple down_dims
        +int kernel_size
        +int n_groups
        +int diffusion_step_embed_dim
        +bool use_film_scale_modulation
        +str noise_scheduler_type
        +int num_train_timesteps
        +str beta_schedule
        +float beta_start
        +float beta_end
        +str prediction_type
        +bool clip_sample
        +float clip_sample_range
        +int num_inference_steps
        +bool do_mask_loss_for_padding
        +float optimizer_lr
        +tuple optimizer_betas
        +float optimizer_eps
        +float optimizer_weight_decay
        +str scheduler_name
        +int scheduler_warmup_steps
        +__post_init__()
        +get_optimizer_preset()
        +get_scheduler_preset()
        +validate_features()
        +observation_delta_indices
        +action_delta_indices
        +reward_delta_indices
    }
    PreTrainedConfig <|-- DiffusionConfig
```
### Validation and Error Handling Flows
```mermaid
flowchart TD
    A[__post_init__] --> B{vision_backbone valid?}
    B -- No --> C[Raise ValueError]
    B -- Yes --> D{prediction_type valid?}
    D -- No --> E[Raise ValueError]
    D -- Yes --> F{noise_scheduler_type valid?}
    F -- No --> G[Raise ValueError]
    F -- Yes --> H{horizon divisible by downsampling factor?}
    H -- No --> I[Raise ValueError]
    H -- Yes --> J[Configuration valid]
```
### Normalization/Unnormalization Data Flow
```mermaid
flowchart LR
    A[Raw batch] --> B[Normalize inputs]
    B --> C[Model forward / action selection]
    C --> D[Unnormalize outputs]
    D --> E[Environment / training]
```
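The normalize/unnormalize pair for a min-max mode can be sketched as an exact round trip. The function names and the [-1, 1] target range are assumptions (mean/std normalization would follow the same inverse-pair pattern):

```python
def normalize_minmax(x, lo, hi):
    # Map raw environment values into [-1, 1] for the model.
    return (x - lo) / (hi - lo) * 2.0 - 1.0

def unnormalize_minmax(y, lo, hi):
    # Exact inverse: map model outputs back to environment units.
    return (y + 1.0) / 2.0 * (hi - lo) + lo
```

Because the two functions are exact inverses, actions predicted in normalized space can be returned to the environment without systematic bias.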
### Configuration Propagation
```mermaid
flowchart TD
    A[User script] --> B[DiffusionConfig]
    B --> C["DiffusionPolicy(config)"]
    C --> D["DiffusionModel(config)"]
    D --> E["DiffusionConditionalUnet1d(config)"]
    D --> F["NoiseScheduler(config)"]
    D --> G["DiffusionRgbEncoder(config)"]
```
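A toy skeleton of this propagation, with stand-in class bodies (the single-argument constructors mirror the diagram; the internals are assumptions):

```python
class DiffusionRgbEncoder:
    def __init__(self, config):
        # The encoder reads its settings straight off the shared config.
        self.backbone = config.vision_backbone

class DiffusionModel:
    def __init__(self, config):
        self.config = config
        self.rgb_encoder = DiffusionRgbEncoder(config)

class DiffusionPolicy:
    def __init__(self, config):
        self.model = DiffusionModel(config)

class Cfg:
    # Stand-in for a validated DiffusionConfig instance.
    vision_backbone = "resnet18"

policy = DiffusionPolicy(Cfg())
```

Threading one config object through every constructor keeps all components consistent: a single validated source of truth, with no per-component duplication of settings.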
### Edge Case: Image Feature Validation
```mermaid
flowchart TD
    A[validate_features] --> B{image or env state feature present?}
    B -- No --> C[Raise ValueError]
    B -- Yes --> D{crop_shape fits all images?}
    D -- No --> E[Raise ValueError]
    D -- Yes --> F{all images same shape?}
    F -- No --> G[Raise ValueError]
    F -- Yes --> H[Features valid]
```
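The image-side checks from this flowchart can be sketched as follows. The signature and error messages are assumptions, and the env-state branch is simplified to a presence check on the image dict:

```python
def validate_image_features(image_shapes, crop_shape):
    # image_shapes maps camera key -> (channels, height, width).
    if not image_shapes:
        raise ValueError("At least one image (or env state) feature is required")
    shapes = set(image_shapes.values())
    if len(shapes) > 1:
        raise ValueError("All image features must share the same shape")
    _, h, w = next(iter(shapes))
    if crop_shape is not None and (crop_shape[0] > h or crop_shape[1] > w):
        raise ValueError(f"crop_shape {crop_shape} does not fit inside ({h}, {w})")
```

Checking shapes up front means a mismatched camera or an oversized crop fails at configuration time instead of mid-training.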