modeling_diffusion - dingdongdengdong/astra_ws GitHub Wiki
# modeling_diffusion.py: Code Flow Analysis and Diagram
## Mermaid Diagram

```mermaid
graph TD
    A[DiffusionPolicy_PreTrainedPolicy]
    A -->|init| B[DiffusionModel]
    A -->|normalize_inputs_targets| C[Normalize_Unnormalize]
    A -->|reset| D[Queue_Initialization]
    A -->|select_action| E[Action_Selection_Flow]
    A -->|forward| F[Training_Loss_Flow]
    B -->|init| G[DiffusionConditionalUnet1d]
    B -->|init| H[NoiseScheduler_DDPM_DDIM]
    B -->|generate_actions| I[ActionGeneration]
    B -->|compute_loss| J[LossComputation]
    G -->|forward| K[UNetForwardPass]
    subgraph Encoders
        L[DiffusionRgbEncoder]
        M[SpatialSoftmax]
    end
    G --> L
    L --> M
    subgraph Blocks
        N[DiffusionConditionalResidualBlock1d]
        O[DiffusionConv1dBlock]
        P[DiffusionSinusoidalPosEmb]
    end
    G --> N
    N --> O
    G --> P
```
## Code Flow Explanation
### DiffusionPolicy Class

- Inherits from `PreTrainedPolicy` and serves as the main interface for the diffusion policy.
- Handles normalization, queue management, and action selection.
- Key methods:
  - `__init__`: Initializes normalization, the diffusion model, and the observation/action queues.
  - `reset`: Resets the observation/action queues.
  - `select_action`: Maintains observation/action history, invokes the diffusion model to generate actions, and unnormalizes the outputs.
  - `forward`: Computes the training loss from normalized inputs and targets.
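The queue bookkeeping in `select_action` is the part that most often confuses readers: the policy generates a whole chunk of actions at once, caches it, and only re-invokes the diffusion model when the cache runs dry. A minimal sketch of that logic, with hypothetical sizes (`N_OBS_STEPS`, `N_ACTION_STEPS`) standing in for the real config values and a stub callable in place of the diffusion model:

```python
from collections import deque

import torch

# Hypothetical sizes standing in for the real config values.
N_OBS_STEPS = 2      # observation history length
N_ACTION_STEPS = 8   # actions executed per generated chunk

class ActionQueueSketch:
    """Mirrors the queue bookkeeping in DiffusionPolicy.select_action."""

    def __init__(self, generate_actions):
        # generate_actions stands in for the diffusion model: it maps the
        # stacked observation history to (n_action_steps, action_dim) actions.
        self.generate_actions = generate_actions
        self.reset()

    def reset(self):
        self.obs_queue = deque(maxlen=N_OBS_STEPS)
        self.action_queue = deque(maxlen=N_ACTION_STEPS)

    def select_action(self, obs):
        self.obs_queue.append(obs)
        if len(self.action_queue) == 0:
            # Cache exhausted: stack the latest observations and generate
            # a fresh chunk of actions in one diffusion pass.
            stacked = torch.stack(list(self.obs_queue))
            self.action_queue.extend(self.generate_actions(stacked))
        return self.action_queue.popleft()
```

Each call returns one action; the model runs only once every `N_ACTION_STEPS` calls.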
### DiffusionModel Class

- Core model for action generation and training.
- Initializes the observation encoders (state, image, environment), the U-Net, and the noise scheduler.
- Key methods:
  - `conditional_sample`: Iteratively denoises a sample using the U-Net and noise scheduler.
  - `_prepare_global_conditioning`: Encodes and concatenates all conditioning features.
  - `generate_actions`: Prepares conditioning, samples actions, and extracts the relevant action window.
  - `compute_loss`: Adds noise to the actions, runs the U-Net, and computes the MSE loss (optionally masking padded actions).
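The `compute_loss` recipe above (sample noise, pick a random timestep, corrupt the trajectory, predict, MSE) can be sketched as follows. This is a simplified sketch assuming epsilon prediction; the `unet` and `noise_scheduler_add_noise` callables are hypothetical stand-ins for the real modules:

```python
import torch
import torch.nn.functional as F

def compute_loss_sketch(unet, noise_scheduler_add_noise, trajectory, global_cond,
                        num_train_timesteps=100, action_is_pad=None):
    """Epsilon-prediction diffusion loss, sketching DiffusionModel.compute_loss."""
    eps = torch.randn_like(trajectory)                 # sample noise
    timesteps = torch.randint(0, num_train_timesteps,
                              (trajectory.shape[0],))  # one random timestep per sample
    noisy = noise_scheduler_add_noise(trajectory, eps, timesteps)
    pred = unet(noisy, timesteps, global_cond)         # U-Net predicts the noise
    loss = F.mse_loss(pred, eps, reduction="none")
    if action_is_pad is not None:
        # Zero out the loss on padded action steps before averaging.
        loss = loss * (~action_is_pad).unsqueeze(-1)
    return loss.mean()
```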
### DiffusionConditionalUnet1d Class

- Implements a 1D U-Net with FiLM conditioning for the diffusion process.
- Encoder and decoder blocks are built from `DiffusionConditionalResidualBlock1d` and `DiffusionConv1dBlock`.
- Uses sinusoidal positional embeddings to encode the diffusion timestep.
- The forward pass encodes, processes, and decodes the input trajectory, applying FiLM conditioning at each block.
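FiLM conditioning means the global conditioning vector is projected to a per-channel scale and bias that modulate the features between convolutions. A minimal sketch of such a residual block (layer sizes and the single GroupNorm group are illustrative, not the real hyperparameters):

```python
import torch
from torch import nn

class FiLMResidualBlock1dSketch(nn.Module):
    """Sketch of a FiLM-modulated residual block in the spirit of
    DiffusionConditionalResidualBlock1d."""

    def __init__(self, in_ch, out_ch, cond_dim):
        super().__init__()
        self.conv1 = nn.Sequential(nn.Conv1d(in_ch, out_ch, 3, padding=1),
                                   nn.GroupNorm(1, out_ch), nn.Mish())
        self.film = nn.Linear(cond_dim, out_ch * 2)   # -> per-channel scale, bias
        self.conv2 = nn.Sequential(nn.Conv1d(out_ch, out_ch, 3, padding=1),
                                   nn.GroupNorm(1, out_ch), nn.Mish())
        self.residual = (nn.Conv1d(in_ch, out_ch, 1)
                         if in_ch != out_ch else nn.Identity())

    def forward(self, x, cond):
        # x: (B, in_ch, T), cond: (B, cond_dim)
        h = self.conv1(x)
        scale, bias = self.film(cond).chunk(2, dim=-1)
        h = scale.unsqueeze(-1) * h + bias.unsqueeze(-1)  # FiLM modulation
        h = self.conv2(h)
        return h + self.residual(x)
```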
### DiffusionRgbEncoder and SpatialSoftmax

- `DiffusionRgbEncoder`: Encodes images using a ResNet backbone, optional cropping, and spatial softmax pooling.
- `SpatialSoftmax`: Extracts keypoints from feature maps as a form of spatial attention.
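Spatial softmax turns each feature-map channel into a probability distribution over pixel locations and returns the expected (x, y) coordinate, i.e. one soft keypoint per channel. A functional sketch of the idea (the real module also handles a configurable number of keypoints):

```python
import torch
import torch.nn.functional as F

def spatial_softmax_sketch(features):
    """Expected (x, y) keypoint per channel via a softmax over
    spatial locations -- a sketch of the SpatialSoftmax idea."""
    b, c, h, w = features.shape
    # Softmax over all H*W locations of each channel.
    attention = F.softmax(features.view(b, c, h * w), dim=-1)
    # Normalized pixel coordinate grids in [-1, 1].
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, h),
                            torch.linspace(-1, 1, w), indexing="ij")
    xs = xs.reshape(1, 1, h * w)
    ys = ys.reshape(1, 1, h * w)
    # Expected coordinates under the attention distribution.
    kp_x = (attention * xs).sum(-1)
    kp_y = (attention * ys).sum(-1)
    return torch.stack([kp_x, kp_y], dim=-1)  # (B, C, 2)
```

A sharp activation peak at the center of a feature map yields a keypoint near (0, 0) in normalized coordinates.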
### Supporting Blocks

- `DiffusionConditionalResidualBlock1d`: Residual block with FiLM modulation for global conditioning.
- `DiffusionConv1dBlock`: Conv1d + GroupNorm + Mish activation.
- `DiffusionSinusoidalPosEmb`: Provides positional embeddings for diffusion timesteps.
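The sinusoidal timestep embedding follows the standard transformer recipe: sin/cos features at geometrically spaced frequencies, so the U-Net can distinguish timesteps at multiple scales. A sketch:

```python
import math

import torch

def sinusoidal_pos_emb_sketch(timesteps, dim):
    """Sketch of DiffusionSinusoidalPosEmb: embeds integer diffusion
    timesteps as concatenated sin/cos features at geometric frequencies."""
    half = dim // 2
    freqs = torch.exp(torch.arange(half) * (-math.log(10000.0) / (half - 1)))
    args = timesteps.float().unsqueeze(-1) * freqs.unsqueeze(0)  # (B, half)
    return torch.cat([args.sin(), args.cos()], dim=-1)           # (B, dim)
```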
### Data Flow and Dependencies
- Data flows from environment observations through normalization, encoding, and the diffusion model, with actions generated via iterative denoising.
- The architecture is modular, with clear separation between configuration, encoding, diffusion modeling, and training/inference logic.
## Summary
This file implements the full diffusion policy pipeline, from observation encoding and normalization to action generation and training, using a modular and extensible architecture suitable for complex visuomotor policy learning tasks.
## Additional Diagrams and Explanations
### Class Structure Diagram

```mermaid
classDiagram
    class PreTrainedPolicy
    class DiffusionPolicy {
        -DiffusionModel diffusion
        -Normalize normalize_inputs
        -Normalize normalize_targets
        -Unnormalize unnormalize_outputs
        -dict _queues
        +select_action(batch)
        +forward(batch)
        +reset()
    }
    class DiffusionModel {
        -DiffusionConditionalUnet1d unet
        -NoiseScheduler noise_scheduler
        +generate_actions(batch)
        +compute_loss(batch)
        +conditional_sample(batch_size, global_cond)
    }
    class DiffusionConditionalUnet1d {
        -DiffusionSinusoidalPosEmb diffusion_step_encoder
        -ModuleList down_modules
        -ModuleList mid_modules
        -ModuleList up_modules
        -Sequential final_conv
        +forward(x, timestep, global_cond)
    }
    class DiffusionRgbEncoder
    class SpatialSoftmax
    class DiffusionConditionalResidualBlock1d
    class DiffusionConv1dBlock
    class DiffusionSinusoidalPosEmb
    PreTrainedPolicy <|-- DiffusionPolicy
    DiffusionPolicy o-- DiffusionModel
    DiffusionModel o-- DiffusionConditionalUnet1d
    DiffusionConditionalUnet1d o-- DiffusionConditionalResidualBlock1d
    DiffusionConditionalUnet1d o-- DiffusionConv1dBlock
    DiffusionConditionalUnet1d o-- DiffusionSinusoidalPosEmb
    DiffusionModel o-- DiffusionRgbEncoder
    DiffusionRgbEncoder o-- SpatialSoftmax
```
### select_action Method Flowchart

```mermaid
flowchart TD
    A[select_action_batch] --> B[Normalize_inputs]
    B --> C{Has_image_features}
    C -- Yes --> D[Stack_images_into_batch_observation_images]
    D --> E[Populate_queues_with_batch]
    C -- No --> E
    E --> F{Action_queue_empty}
    F -- Yes --> G[Stack_latest_observations]
    G --> H[Generate_actions_with_diffusion_model]
    H --> I[Unnormalize_actions]
    I --> J[Extend_action_queue]
    J --> K[Popleft_action_from_queue]
    F -- No --> K
    K --> L[Return_action]
```
### compute_loss Method Flowchart

```mermaid
flowchart TD
    A[compute_loss_batch] --> B[Validate_batch_keys]
    B --> C[Encode_global_conditioning]
    C --> D[Sample_noise]
    D --> E[Sample_random_timestep]
    E --> F[Add_noise_to_trajectory]
    F --> G[Run_UNet_to_predict]
    G --> H{prediction_type}
    H -- epsilon --> I[Target_is_noise]
    H -- sample --> J[Target_is_action]
    I --> K[Compute_MSE_loss]
    J --> K
    K --> L{do_mask_loss_for_padding}
    L -- Yes --> M[Mask_loss_with_action_is_pad]
    L -- No --> N[No_masking]
    M --> O[Return_mean_loss]
    N --> O
```
### DiffusionConditionalUnet1d Encoder/Decoder Skip Connections

```mermaid
graph LR
    subgraph Encoder
        E1[Input_x]
        E2[ResidualBlock1]
        E3[ResidualBlock2]
        E4[Downsample]
        E5[ResidualBlock3]
        E6[ResidualBlock4]
        E7[Downsample]
    end
    subgraph Middle
        M1[ResidualBlockMid1]
        M2[ResidualBlockMid2]
    end
    subgraph Decoder
        D1[Concat_skip]
        D2[ResidualBlockUp1]
        D3[ResidualBlockUp2]
        D4[Upsample]
        D5[Concat_skip]
        D6[ResidualBlockUp3]
        D7[ResidualBlockUp4]
        D8[Upsample]
    end
    E1-->E2-->E3-->E4-->E5-->E6-->E7-->M1-->M2
    M2-->D1
    D1-->D2-->D3-->D4-->D5-->D6-->D7-->D8
```
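The skip-connection pattern in the diagram above can be sketched as a tiny 1D U-Net skeleton: the encoder stashes its intermediate features, and the decoder concatenates them channel-wise on the way back up. Conditioning and residual blocks are omitted here; layer sizes are illustrative, not the real architecture:

```python
import torch
from torch import nn

class TinyUNet1dSketch(nn.Module):
    """Skip-connection skeleton of a 1D U-Net: encoder features are
    stashed and concatenated channel-wise in the decoder."""

    def __init__(self, ch=8):
        super().__init__()
        self.down1 = nn.Conv1d(ch, ch * 2, 3, stride=2, padding=1)
        self.down2 = nn.Conv1d(ch * 2, ch * 4, 3, stride=2, padding=1)
        self.mid = nn.Conv1d(ch * 4, ch * 4, 3, padding=1)
        # Decoder inputs are doubled in channels by the skip concatenation.
        self.up2 = nn.ConvTranspose1d(ch * 4 * 2, ch * 2, 4, stride=2, padding=1)
        self.up1 = nn.ConvTranspose1d(ch * 2 * 2, ch, 4, stride=2, padding=1)

    def forward(self, x):
        skips = []
        h = self.down1(x); skips.append(h)
        h = self.down2(h); skips.append(h)
        h = self.mid(h)
        h = self.up2(torch.cat([h, skips.pop()], dim=1))  # concat skip
        h = self.up1(torch.cat([h, skips.pop()], dim=1))  # concat skip
        return h
```

The output has the same shape as the input, which is essential for a denoiser that maps noisy trajectories to predictions of the same length.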
### DiffusionRgbEncoder Image Feature Extraction

```mermaid
flowchart TD
    A[Input_Image_Tensor] --> B{Crop}
    B -- Yes --> C[Random_or_Center_Crop]
    C --> D[ResNet_Backbone]
    B -- No --> D
    D --> E[SpatialSoftmax]
    E --> F[Linear_Layer_ReLU]
    F --> G[Output_Feature_Vector]
```
### Data Structure Relationships

```mermaid
erDiagram
    DIFFUSIONCONFIG ||--o{ DIFFUSIONPOLICY : uses
    DIFFUSIONPOLICY ||--o{ DIFFUSIONMODEL : owns
    DIFFUSIONMODEL ||--o{ DIFFUSIONCONDITIONALUNET1D : owns
    DIFFUSIONMODEL ||--o{ DIFFUSIONRGBENCODER : uses
    DIFFUSIONCONDITIONALUNET1D ||--o{ DIFFUSIONCONDITIONALRESIDUALBLOCK1D : contains
    DIFFUSIONCONDITIONALUNET1D ||--o{ DIFFUSIONCONV1DBLOCK : contains
    DIFFUSIONCONDITIONALUNET1D ||--o{ DIFFUSIONSINUSOIDALPOSEMB : uses
    DIFFUSIONRGBENCODER ||--o{ SPATIALSOFTMAX : uses
```
### DiffusionConditionalResidualBlock1d FiLM Modulation

```mermaid
flowchart TD
    A[Input_x_cond] --> B[Conv1d_GroupNorm_Mish]
    B --> C[FiLM_Linear_cond_to_bias_scale]
    C --> D[Apply_bias_scale_to_features]
    D --> E[Conv1d_GroupNorm_Mish]
    E --> F[Residual_connection_Conv1d_if_needed]
    F --> G[Output]
```
### End-to-End Policy Inference Pipeline

```mermaid
flowchart LR
    A[Raw_Observations] --> B[DiffusionPolicy_normalize_inputs]
    B --> C[DiffusionModel_prepare_global_conditioning]
    C --> D[DiffusionConditionalUnet1d_conditional_sample]
    D --> E[DiffusionPolicy_unnormalize_outputs]
    E --> F[Action_Output]
```
### Validation and Error Propagation in select_action

```mermaid
flowchart TD
    A[select_action] --> B{image_features_present}
    B -- Yes --> C[Stack_images]
    C --> D[Populate_queues]
    B -- No --> D
    D --> E{Action_queue_empty}
    E -- Yes --> F[Generate_actions]
    F --> G{unnormalize_outputs_raises}
    G -- Yes --> H[Error_dataset_stats_missing]
    G -- No --> I[Continue]
    E -- No --> I
    I --> J[Pop_and_return_action]
```
### Queue Management in DiffusionPolicy

```mermaid
flowchart TD
    A[env_reset] --> B[DiffusionPolicy_reset]
    B --> C[Clear_observation_queue]
    B --> D[Clear_action_queue]
    E[select_action] --> F[Populate_queues_with_new_batch]
    F --> G{Action_queue_empty}
    G -- Yes --> H[Generate_actions_fill_action_queue]
    H --> I[Pop_action_from_queue]
    G -- No --> I
    I --> J[Return_action]
```
### Noise Scheduler and Timestep Handling

```mermaid
sequenceDiagram
    participant DM as DiffusionModel
    participant NS as NoiseScheduler
    DM->>NS: set_timesteps(num_inference_steps)
    loop For each timestep t
        DM->>NS: step(model_output, t, sample)
        NS-->>DM: prev_sample
    end
```
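The sequence above is the usual scheduler contract (as in Hugging Face diffusers-style schedulers): call `set_timesteps` once, then call `step` per iteration and read `prev_sample` off the result. A sketch of the denoising loop against a stubbed scheduler; the stub's update rule is a placeholder, not real DDPM/DDIM math:

```python
from types import SimpleNamespace

import torch

class TrivialSchedulerSketch:
    """Stub exposing the scheduler interface the sequence diagram shows."""

    def set_timesteps(self, num_inference_steps):
        self.timesteps = torch.arange(num_inference_steps - 1, -1, -1)

    def step(self, model_output, t, sample):
        # A real scheduler combines sample and model_output according to
        # the noise schedule at timestep t; this just nudges the sample.
        return SimpleNamespace(prev_sample=sample - 0.1 * model_output)

def conditional_sample_sketch(unet, scheduler, shape, global_cond, steps=10):
    sample = torch.randn(shape)    # start from prior noise
    scheduler.set_timesteps(steps)
    for t in scheduler.timesteps:  # iterative denoising loop
        model_output = unet(sample, t, global_cond)
        sample = scheduler.step(model_output, t, sample).prev_sample
    return sample
```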
### Extensibility and Modularity

```mermaid
graph TD
    A[PreTrainedPolicy] -->|inherits| B[DiffusionPolicy]
    B -->|has| C[DiffusionModel]
    C -->|has| D[DiffusionConditionalUnet1d]
    C -->|has| E[NoiseScheduler]
    C -->|has| F[DiffusionRgbEncoder]
    D -->|composed_of| G[ResidualBlocks_ConvBlocks_PosEmb]
    F -->|uses| H[SpatialSoftmax]
    style D fill:#f9f,stroke:#333,stroke-width:2px
    style F fill:#bbf,stroke:#333,stroke-width:2px
```
### select_action Full Sequence

```mermaid
sequenceDiagram
    participant Env as Environment
    participant DP as DiffusionPolicy
    participant Q as Queues
    participant DM as DiffusionModel
    participant UN as Unnormalize
    Env->>DP: Provide observation batch
    DP->>DP: Normalize inputs
    DP->>Q: Update observation queue
    alt Action queue empty
        DP->>DM: generate_actions(batch)
        DM-->>DP: Return actions
        DP->>UN: Unnormalize actions
        UN->>Q: Fill action queue
    end
    Q->>DP: Pop action
    DP->>Env: Return action
```
### compute_loss Error and Data Flow

```mermaid
flowchart TD
    A[compute_loss_batch] --> B[Check_required_keys]
    B -- Missing_keys --> C[Raise_ValueError]
    B -- All_present --> D[Prepare_global_conditioning]
    D --> E[Sample_noise]
    E --> F[Sample_random_timestep]
    F --> G[Add_noise_to_trajectory]
    G --> H[Run_UNet]
    H --> I{prediction_type}
    I -- epsilon --> J[Target_is_noise]
    I -- sample --> K[Target_is_action]
    J --> L[Compute_MSE_loss]
    K --> L
    L --> M{do_mask_loss_for_padding}
    M -- Yes --> N[Mask_loss]
    M -- No --> O[No_masking]
    N --> P[Return_mean_loss]
    O --> P
```
### conditional_sample Denoising Loop

```mermaid
flowchart TD
    A[Sample_prior_noise] --> B[Set_scheduler_timesteps]
    B --> C{For_each_timestep}
    C --> D[UNet_predicts_model_output]
    D --> E[Scheduler_steps_xt_to_xtminus1]
    E --> C
    C -- All_steps_done --> F[Return_denoised_sample]
```
### DiffusionConditionalUnet1d forward Block Flow

```mermaid
flowchart TD
    A[Input_x_timestep_global_cond] --> B[Encode_timestep]
    B --> C[Concat_with_global_cond]
    C --> D[Encoder_blocks_down_modules]
    D --> E[Mid_blocks]
    E --> F[Decoder_blocks_up_modules_with_skip_connections]
    F --> G[Final_Conv]
    G --> H[Output]
```
### Normalization/Unnormalization Pipeline

```mermaid
flowchart LR
    A[Raw_Batch] --> B[Normalize_Inputs]
    B --> C[Model_Forward_Action_Selection]
    C --> D[Unnormalize_Outputs]
    D --> E[Environment_Training]
```
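The Normalize/Unnormalize pair is an exact inverse built from dataset statistics. A minimal mean/std sketch (the real modules also support min/max normalization, and the statistics come from the training dataset; the values below are hypothetical):

```python
import torch

# Hypothetical dataset statistics; real ones come from the training dataset.
stats = {"mean": torch.tensor([0.5, -1.0]), "std": torch.tensor([2.0, 0.5])}

def normalize(x, stats):
    """Map raw values to zero-mean, unit-variance features."""
    return (x - stats["mean"]) / stats["std"]

def unnormalize(x, stats):
    """Inverse map: model outputs back to the environment's units."""
    return x * stats["std"] + stats["mean"]
```

Round-tripping through both functions returns the original values, which is why missing dataset stats surface as errors at the unnormalize step.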
### Scheduler and UNet Interaction

```mermaid
sequenceDiagram
    participant DM as DiffusionModel
    participant NS as NoiseScheduler
    participant UN as UNet
    DM->>NS: set_timesteps(num_inference_steps)
    loop For each timestep t
        DM->>UN: Predict model output
        UN-->>DM: Output
        DM->>NS: step(model_output, t, sample)
        NS-->>DM: prev_sample
    end
```
### Batch Validation in compute_loss

```mermaid
flowchart TD
    A[compute_loss_batch] --> B{Keys_present}
    B -- No --> C[Raise_ValueError]
    B -- Yes --> D[Continue]
```
### Extensibility Points

```mermaid
flowchart TD
    A[DiffusionPolicy] --> B[Custom_Normalize_Unnormalize]
    A --> C[Custom_DiffusionModel]
    C --> D[Custom_UNet]
    C --> E[Custom_NoiseScheduler]
    C --> F[Custom_Encoder]
```
### Full Object Graph

```mermaid
graph TD
    A[DiffusionPolicy]
    A --> B[DiffusionModel]
    B --> C[DiffusionConditionalUnet1d]
    B --> D[NoiseScheduler]
    B --> E[DiffusionRgbEncoder]
    C --> F[ResidualBlocks]
    C --> G[ConvBlocks]
    C --> H[SinusoidalPosEmb]
    E --> I[SpatialSoftmax]
```
### Data Flow: Environment to Action

```mermaid
flowchart LR
    A[Environment_Observations] --> B[DiffusionPolicy_normalize_inputs]
    B --> C[Feature_Encoding_State_Image_Env]
    C --> D[DiffusionModel_UNet_Scheduler]
    D --> E[Action_Trajectory]
    E --> F[DiffusionPolicy_unnormalize_outputs]
    F --> G[Action_Output]
```
### Error Path: Image Shape Mismatch

```mermaid
flowchart TD
    A[validate_features] --> B{All_images_same_shape}
    B -- No --> C[Raise_ValueError]
    B -- Yes --> D[Continue]
```
### Error Path: Missing Dataset Stats

```mermaid
flowchart TD
    A[select_action] --> B{unnormalize_outputs_needs_stats}
    B -- Yes --> C[Raise_Error_dataset_stats_missing]
    B -- No --> D[Continue]
```