modeling_diffusion.py: Code Flow Analysis and Diagram

Mermaid Diagram

graph TD
    A[DiffusionPolicy_PreTrainedPolicy]
    A -->|init| B[DiffusionModel]
    A -->|normalize_inputs_targets| C[Normalize_Unnormalize]
    A -->|reset| D[Queue_Initialization]
    A -->|select_action| E[Action_Selection_Flow]
    A -->|forward| F[Training_Loss_Flow]
    B -->|init| G[DiffusionConditionalUnet1d]
    B -->|init| H[NoiseScheduler_DDPM_or_DDIM]
    B -->|generate_actions| I[ActionGeneration]
    B -->|compute_loss| J[LossComputation]
    G -->|forward| K[UNetForwardPass]
    subgraph Encoders
        L[DiffusionRgbEncoder]
        M[SpatialSoftmax]
    end
    G --> L
    L --> M
    subgraph Blocks
        N[DiffusionConditionalResidualBlock1d]
        O[DiffusionConv1dBlock]
        P[DiffusionSinusoidalPosEmb]
    end
    G --> N
    N --> O
    G --> P

Code Flow Explanation

DiffusionPolicy Class

  • Inherits from PreTrainedPolicy and serves as the main interface for the diffusion policy.
  • Handles normalization, queue management, and action selection.
  • Key methods:
    • __init__: Initializes normalization, diffusion model, and queues.
    • reset: Resets observation/action queues.
    • select_action: Maintains observation/action history, invokes the diffusion model to generate actions, and unnormalizes outputs (see the sketch after this list).
    • forward: Computes training loss using normalized inputs and targets.
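
A minimal sketch of the queue-based select_action flow described above, assuming a lerobot-style interface. The names n_obs_steps, n_action_steps, and generate_actions are taken from this description and may not match the actual code exactly.

```python
from collections import deque

import torch


class QueuedActionSelector:
    """Illustrative queue logic for select_action (not the actual DiffusionPolicy)."""

    def __init__(self, model, n_obs_steps: int, n_action_steps: int):
        self.model = model                    # object exposing generate_actions(obs_seq)
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.reset()

    def reset(self):
        # Mirrors DiffusionPolicy.reset: clear observation/action history.
        self._obs_queue = deque(maxlen=self.n_obs_steps)
        self._action_queue = deque(maxlen=self.n_action_steps)

    @torch.no_grad()
    def select_action(self, obs: torch.Tensor) -> torch.Tensor:
        # 1. Push the newest (already normalized) observation; pad history on the first call.
        self._obs_queue.append(obs)
        while len(self._obs_queue) < self.n_obs_steps:
            self._obs_queue.append(obs)

        # 2. Run the (expensive) diffusion sampler only when the action queue is empty.
        if len(self._action_queue) == 0:
            obs_seq = torch.stack(list(self._obs_queue), dim=1)  # (B, n_obs_steps, obs_dim)
            actions = self.model.generate_actions(obs_seq)       # (B, n_action_steps, act_dim)
            # The real policy applies unnormalize_outputs here before queuing.
            self._action_queue.extend(actions.transpose(0, 1))   # one entry per future step

        # 3. Pop exactly one action per environment step.
        return self._action_queue.popleft()
```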

DiffusionModel Class

  • Core model for action generation and training.
  • Initializes observation encoders (state, image, environment), U-Net, and noise scheduler.
  • Key methods:
    • conditional_sample: Iteratively denoises a sample using the U-Net and noise scheduler (sketched after this list).
    • _prepare_global_conditioning: Encodes and concatenates all conditioning features.
    • generate_actions: Prepares conditioning, samples actions, and extracts the relevant action window.
    • compute_loss: Adds noise to actions, runs the U-Net, and computes MSE loss (optionally masking padded actions).
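
The denoising loop inside conditional_sample can be sketched as follows. This assumes a diffusers-style scheduler (DDPM or DDIM) and a U-Net that accepts (sample, timestep, global_cond); the exact signatures in modeling_diffusion.py may differ.

```python
import torch


@torch.no_grad()
def conditional_sample(unet, scheduler, batch_size, horizon, action_dim,
                       global_cond, num_inference_steps=100, device="cpu"):
    # Start from pure Gaussian noise over the action horizon.
    sample = torch.randn(batch_size, horizon, action_dim, device=device)

    # Configure how many denoising steps the scheduler will take at inference time.
    scheduler.set_timesteps(num_inference_steps)

    for t in scheduler.timesteps:
        # The U-Net predicts either the added noise (epsilon) or the clean sample,
        # conditioned on the encoded observations via FiLM.
        model_output = unet(sample, t, global_cond=global_cond)

        # The scheduler turns x_t into x_{t-1} using that prediction.
        sample = scheduler.step(model_output, t, sample).prev_sample

    return sample  # denoised action trajectory, shape (batch_size, horizon, action_dim)
```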

DiffusionConditionalUnet1d Class

  • Implements a 1D U-Net with FiLM conditioning for the diffusion process.
  • Encoder and decoder blocks are built from DiffusionConditionalResidualBlock1d and DiffusionConv1dBlock.
  • Uses sinusoidal positional embeddings for timestep encoding.
  • Forward pass encodes, processes, and decodes the input trajectory, applying FiLM conditioning at each block (see the sketch below).
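
A structural sketch of that forward pass is shown below. It assumes the down/up module lists hold (residual, residual, resample) triples, as is typical for this style of 1D diffusion U-Net; argument and module names are illustrative rather than copied from the file.

```python
import torch


def unet_forward_sketch(x, timestep, global_cond, diffusion_step_encoder,
                        down_modules, mid_modules, up_modules, final_conv):
    """Structural sketch of DiffusionConditionalUnet1d.forward (names assumed)."""
    # (B, T, C) -> (B, C, T): the 1D convolutions operate over the time axis.
    x = x.transpose(1, 2)

    # Encode the diffusion timestep and concatenate the global conditioning vector;
    # this combined vector drives the FiLM modulation in every residual block.
    cond = torch.cat([diffusion_step_encoder(timestep), global_cond], dim=-1)

    skips = []
    for resnet1, resnet2, downsample in down_modules:   # encoder path
        x = resnet2(resnet1(x, cond), cond)
        skips.append(x)
        x = downsample(x)

    for mid_block in mid_modules:                       # bottleneck
        x = mid_block(x, cond)

    for resnet1, resnet2, upsample in up_modules:       # decoder path with skip connections
        x = torch.cat([x, skips.pop()], dim=1)
        x = resnet2(resnet1(x, cond), cond)
        x = upsample(x)

    return final_conv(x).transpose(1, 2)                # back to (B, T, C)
```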

DiffusionRgbEncoder and SpatialSoftmax

  • DiffusionRgbEncoder: Encodes images using a ResNet backbone, optional cropping, and spatial softmax pooling.
  • SpatialSoftmax: Extracts keypoints from feature maps as a form of spatial attention (see the minimal sketch below).
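
A minimal, self-contained version of the spatial-softmax idea (one expected (x, y) keypoint per feature channel) is sketched below; the real module may add details such as a 1x1 projection to a fixed number of keypoints or a learnable temperature.

```python
import torch
import torch.nn.functional as F


def spatial_softmax(features: torch.Tensor) -> torch.Tensor:
    """features: (B, C, H, W) feature map -> (B, C, 2) keypoints in [-1, 1]."""
    b, c, h, w = features.shape

    # Softmax over all spatial locations, independently for each channel.
    attention = F.softmax(features.reshape(b, c, h * w), dim=-1)

    # Normalized pixel-coordinate grid in [-1, 1].
    ys, xs = torch.meshgrid(
        torch.linspace(-1.0, 1.0, h), torch.linspace(-1.0, 1.0, w), indexing="ij"
    )
    grid = torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=-1)  # (H*W, 2)

    # Expected coordinate under each channel's attention distribution.
    return attention @ grid                                        # (B, C, 2)
```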

Supporting Blocks

  • DiffusionConditionalResidualBlock1d: Residual block with FiLM modulation for global conditioning.
  • DiffusionConv1dBlock: Conv1d + GroupNorm + Mish activation.
  • DiffusionSinusoidalPosEmb: Provides positional embeddings for diffusion timesteps (see the sketch below).
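
The sinusoidal timestep embedding can be sketched in its standard Transformer/diffusion form; the actual module may differ slightly (e.g. in how frequencies are spaced), and an even embedding dimension is assumed here.

```python
import math

import torch


def sinusoidal_pos_emb(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """timesteps: (B,) integer diffusion steps -> (B, dim) embeddings (dim assumed even)."""
    half_dim = dim // 2
    # Geometric progression of frequencies, as in the original Transformer encoding.
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32) * (-math.log(10000.0) / (half_dim - 1))
    )
    angles = timesteps.float()[:, None] * freqs[None, :]          # (B, half_dim)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)        # (B, dim)
```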

Data Flow and Dependencies

  • Data flows from environment observations through normalization, encoding, and the diffusion model, with actions generated via iterative denoising.
  • The architecture is modular, with clear separation between configuration, encoding, diffusion modeling, and training/inference logic.

Summary

This file implements the full diffusion policy pipeline, from observation encoding and normalization to action generation and training, using a modular and extensible architecture suitable for complex visuomotor policy learning tasks.

Additional Diagrams and Explanations

Class Structure Diagram

classDiagram
    class PreTrainedPolicy
    class DiffusionPolicy {
        -DiffusionModel diffusion
        -Normalize normalize_inputs
        -Normalize normalize_targets
        -Unnormalize unnormalize_outputs
        -dict _queues
        +select_action(batch)
        +forward(batch)
        +reset()
    }
    class DiffusionModel {
        -DiffusionConditionalUnet1d unet
        -NoiseScheduler noise_scheduler
        +generate_actions(batch)
        +compute_loss(batch)
        +conditional_sample(batch_size, global_cond)
    }
    class DiffusionConditionalUnet1d {
        -DiffusionSinusoidalPosEmb diffusion_step_encoder
        -ModuleList down_modules
        -ModuleList mid_modules
        -ModuleList up_modules
        -Sequential final_conv
        +forward(x, timestep, global_cond)
    }
    class DiffusionRgbEncoder
    class SpatialSoftmax
    class DiffusionConditionalResidualBlock1d
    class DiffusionConv1dBlock
    class DiffusionSinusoidalPosEmb

    PreTrainedPolicy <|-- DiffusionPolicy
    DiffusionPolicy o-- DiffusionModel
    DiffusionModel o-- DiffusionConditionalUnet1d
    DiffusionConditionalUnet1d o-- DiffusionConditionalResidualBlock1d
    DiffusionConditionalUnet1d o-- DiffusionConv1dBlock
    DiffusionConditionalUnet1d o-- DiffusionSinusoidalPosEmb
    DiffusionModel o-- DiffusionRgbEncoder
    DiffusionRgbEncoder o-- SpatialSoftmax

select_action Method Flowchart

flowchart TD
    A[select_action_batch] --> B[Normalize_inputs]
    B --> C{Has_image_features}
    C -- Yes --> D[Stack_images_into_batch_observation_images]
    D --> E[Populate_queues_with_batch]
    C -- No --> E
    E --> F{Action_queue_empty}
    F -- Yes --> G[Stack_latest_observations]
    G --> H[Generate_actions_with_diffusion_model]
    H --> I[Unnormalize_actions]
    I --> J[Extend_action_queue]
    J --> K[Popleft_action_from_queue]
    F -- No --> K
    K --> L[Return_action]

compute_loss Method Flowchart

flowchart TD
    A[compute_loss_batch] --> B[Validate_batch_keys]
    B --> C[Encode_global_conditioning]
    C --> D[Sample_noise]
    D --> E[Sample_random_timestep]
    E --> F[Add_noise_to_trajectory]
    F --> G[Run_UNet_to_predict]
    G --> H{prediction_type}
    H -- epsilon --> I[Target_is_noise]
    H -- sample --> J[Target_is_action]
    I --> K[Compute_MSE_loss]
    J --> K
    K --> L{do_mask_loss_for_padding}
    L -- Yes --> M[Mask_loss_with_action_is_pad]
    L -- No --> N[No_masking]
    M --> O[Return_mean_loss]
    N --> O
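
The flow above can be sketched roughly as follows, assuming a diffusers-style scheduler with add_noise and a prediction_type of either "epsilon" or "sample"; the tensor names and the action_is_pad mask shape are illustrative.

```python
import torch
import torch.nn.functional as F


def compute_loss_sketch(unet, noise_scheduler, trajectory, global_cond,
                        prediction_type="epsilon", action_is_pad=None):
    """trajectory: (B, horizon, action_dim) normalized ground-truth actions."""
    b = trajectory.shape[0]

    # Sample noise and a random diffusion timestep per trajectory.
    noise = torch.randn_like(trajectory)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (b,), device=trajectory.device
    )

    # Forward diffusion: corrupt the clean trajectory.
    noisy_trajectory = noise_scheduler.add_noise(trajectory, noise, timesteps)

    # The U-Net predicts either the added noise or the clean trajectory.
    pred = unet(noisy_trajectory, timesteps, global_cond=global_cond)
    target = noise if prediction_type == "epsilon" else trajectory

    loss = F.mse_loss(pred, target, reduction="none")

    # Optionally ignore padded action steps (e.g. near episode boundaries).
    if action_is_pad is not None:                      # (B, horizon) boolean mask
        loss = loss * (~action_is_pad).unsqueeze(-1).float()

    return loss.mean()
```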

DiffusionConditionalUnet1d Encoder/Decoder Skip Connections

graph LR
    subgraph Encoder
        E1[Input_x]
        E2[ResidualBlock1]
        E3[ResidualBlock2]
        E4[Downsample]
        E5[ResidualBlock3]
        E6[ResidualBlock4]
        E7[Downsample]
    end
    subgraph Middle
        M1[ResidualBlockMid1]
        M2[ResidualBlockMid2]
    end
    subgraph Decoder
        D1[Concat_skip]
        D2[ResidualBlockUp1]
        D3[ResidualBlockUp2]
        D4[Upsample]
        D5[Concat_skip]
        D6[ResidualBlockUp3]
        D7[ResidualBlockUp4]
        D8[Upsample]
    end
    E1-->E2-->E3-->E4-->E5-->E6-->E7-->M1-->M2
    M2-->D1
    D1-->D2-->D3-->D4-->D5-->D6-->D7-->D8

DiffusionRgbEncoder Image Feature Extraction

flowchart TD
    A[Input_Image_Tensor] --> B{Crop}
    B -- Yes --> C[Random_Center_Crop]
    C --> D[ResNet_Backbone]
    B -- No --> D
    D --> E[SpatialSoftmax]
    E --> F[Linear_Layer_ReLU]
    F --> G[Output_Feature_Vector]

Data Structure Relationships

erDiagram
    DIFFUSIONCONFIG ||--o{ DIFFUSIONPOLICY : uses
    DIFFUSIONPOLICY ||--o{ DIFFUSIONMODEL : owns
    DIFFUSIONMODEL ||--o{ DIFFUSIONCONDITIONALUNET1D : owns
    DIFFUSIONMODEL ||--o{ DIFFUSIONRGBENCODER : uses
    DIFFUSIONCONDITIONALUNET1D ||--o{ DIFFUSIONCONDITIONALRESIDUALBLOCK1D : contains
    DIFFUSIONCONDITIONALUNET1D ||--o{ DIFFUSIONCONV1DBLOCK : contains
    DIFFUSIONCONDITIONALUNET1D ||--o{ DIFFUSIONSINUSOIDALPOSEMB : uses
    DIFFUSIONRGBENCODER ||--o{ SPATIALSOFTMAX : uses

DiffusionConditionalResidualBlock1d FiLM Modulation

flowchart TD
    A[Input_x_cond] --> B[Conv1d_GroupNorm_Mish]
    B --> C[FiLM_Linear_cond_to_bias_scale]
    C --> D[Apply_bias_scale_to_features]
    D --> E[Conv1d_GroupNorm_Mish]
    E --> F[Residual_connection_Conv1d_if_needed]
    F --> G[Output]
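
A minimal sketch of this block pair, with the FiLM scale and bias predicted from the conditioning vector; channel counts, kernel sizes, and group counts are illustrative.

```python
import torch
from torch import nn


class Conv1dBlockSketch(nn.Module):
    """Conv1d -> GroupNorm -> Mish, as in DiffusionConv1dBlock."""

    def __init__(self, in_ch, out_ch, kernel_size=3, n_groups=8):
        super().__init__()
        # out_ch must be divisible by n_groups for GroupNorm.
        self.block = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_ch),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class FiLMResidualBlockSketch(nn.Module):
    """Residual block whose features are scaled and shifted by the conditioning vector."""

    def __init__(self, in_ch, out_ch, cond_dim):
        super().__init__()
        self.conv1 = Conv1dBlockSketch(in_ch, out_ch)
        self.conv2 = Conv1dBlockSketch(out_ch, out_ch)
        # FiLM: one scale and one bias per output channel, predicted from the condition.
        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, 2 * out_ch))
        self.residual = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, cond):
        # x: (B, in_ch, T), cond: (B, cond_dim)
        out = self.conv1(x)
        scale, bias = self.cond_encoder(cond).chunk(2, dim=-1)
        out = scale.unsqueeze(-1) * out + bias.unsqueeze(-1)   # FiLM modulation
        out = self.conv2(out)
        return out + self.residual(x)
```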

End-to-End Policy Inference Pipeline

flowchart LR
    A[Raw_Observations] --> B[DiffusionPolicy_normalize_inputs]
    B --> C[DiffusionModel_prepare_global_conditioning]
    C --> D[DiffusionConditionalUnet1d_conditional_sample]
    D --> E[DiffusionPolicy_unnormalize_outputs]
    E --> F[Action_Output]
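
From the caller's point of view, the pipeline above reduces to a simple control loop. The sketch below assumes a gymnasium-style environment returning dict observations and a policy exposing reset/select_action; the batch keys and shapes are illustrative, not taken from the repo.

```python
import torch


def rollout(env, policy, max_steps=300, device="cpu"):
    """Illustrative control loop: one observation in, one action out per step."""
    policy.reset()                            # clears the observation/action queues
    obs, _ = env.reset()

    for _ in range(max_steps):
        # Assemble a batch of size 1; key names depend on the configured input features.
        batch = {k: torch.as_tensor(v, device=device).unsqueeze(0) for k, v in obs.items()}
        with torch.no_grad():
            action = policy.select_action(batch)   # normalized in, unnormalized out
        obs, reward, terminated, truncated, _ = env.step(action.squeeze(0).cpu().numpy())
        if terminated or truncated:
            break
```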

Validation and Error Propagation in select_action

flowchart TD
    A[select_action] --> B{image_features_present}
    B -- Yes --> C[Stack_images]
    B -- No --> D[Populate_queues]
    C --> D
    D --> E{Action_queue_empty}
    E -- Yes --> F[Generate_actions]
    F --> G{unnormalize_outputs_raises}
    G -- Yes --> H[Error_dataset_stats_missing]
    G -- No --> I[Continue]
    E -- No --> I
    I --> J[Pop_and_return_action]

Queue Management in DiffusionPolicy

flowchart TD
    A[env_reset] --> B[DiffusionPolicy_reset]
    B --> C[Clear_observation_queue]
    B --> D[Clear_action_queue]
    E[select_action] --> F[Populate_queues_with_new_batch]
    F --> G{Action_queue_empty}
    G -- Yes --> H[Generate_actions_fill_action_queue]
    H --> I[Pop_action_from_queue]
    G -- No --> I
    I --> J[Return_action]

Noise Scheduler and Timestep Handling

sequenceDiagram
    participant DM as DiffusionModel
    participant NS as NoiseScheduler
    DM->>NS: set_timesteps(num_inference_steps)
    loop For_each_timestep_t
        DM->>NS: step(model_output, t, sample)
        NS-->>DM: prev_sample
    end
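
A hedged sketch of how such a noise scheduler could be configured with the diffusers library; the parameter values below are illustrative, not the repository's actual configuration, which lives in the policy config.

```python
from diffusers import DDIMScheduler, DDPMScheduler


def make_noise_scheduler(name: str = "DDPM"):
    """Illustrative scheduler factory for the DDPM/DDIM choice shown above."""
    kwargs = dict(
        num_train_timesteps=100,          # diffusion steps used during training
        beta_schedule="squaredcos_cap_v2",
        clip_sample=True,                 # clamp predictions to the normalized range
        prediction_type="epsilon",        # predict the added noise (vs. "sample")
    )
    return DDPMScheduler(**kwargs) if name == "DDPM" else DDIMScheduler(**kwargs)


scheduler = make_noise_scheduler("DDIM")
scheduler.set_timesteps(10)               # DDIM allows far fewer inference steps
print(scheduler.timesteps)                # decreasing sequence of timesteps for the loop
```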

Extensibility and Modularity

graph TD
    A[PreTrainedPolicy] -->|inherits| B[DiffusionPolicy]
    B -->|has| C[DiffusionModel]
    C -->|has| D[DiffusionConditionalUnet1d]
    C -->|has| E[NoiseScheduler]
    C -->|has| F[DiffusionRgbEncoder]
    D -->|composed_of| G[ResidualBlocks_ConvBlocks_PosEmb]
    F -->|uses| H[SpatialSoftmax]
    style D fill:#f9f,stroke:#333,stroke-width:2px
    style F fill:#bbf,stroke:#333,stroke-width:2px

select_action Full Sequence

sequenceDiagram
    participant Env as Environment
    participant DP as DiffusionPolicy
    participant Q as Queues
    participant DM as DiffusionModel
    participant UN as Unnormalize

    Env->>DP: Provide_observation_batch
    DP->>DP: Normalize_inputs
    DP->>Q: Update_observation_queue
    alt Action_queue_empty
        DP->>DM: generate_actions(batch)
        DM->>DP: Return_actions
        DP->>UN: Unnormalize_actions
        UN->>Q: Fill_action_queue
    end
    Q->>DP: Pop_action
    DP->>Env: Return_action

compute_loss Error and Data Flow

flowchart TD
    A[compute_loss_batch] --> B[Check_required_keys]
    B -- Missing_keys --> C[Raise_ValueError]
    B -- All_present --> D[Prepare_global_conditioning]
    D --> E[Sample_noise]
    E --> F[Sample_random_timestep]
    F --> G[Add_noise_to_trajectory]
    G --> H[Run_UNet]
    H --> I{prediction_type}
    I -- epsilon --> J[Target_is_noise]
    I -- sample --> K[Target_is_action]
    J --> L[Compute_MSE_loss]
    K --> L
    L --> M{do_mask_loss_for_padding}
    M -- Yes --> N[Mask_loss]
    M -- No --> O[No_masking]
    N --> P[Return_mean_loss]
    O --> P

conditional_sample Denoising Loop

flowchart TD
    A[Sample_prior_noise] --> B[Set_scheduler_timesteps]
    B --> C{For_each_timestep}
    C --> D[UNet_predicts_model_output]
    D --> E[Scheduler_steps_xt_to_xtminus1]
    E --> C
    C -- All_steps_done --> F[Return_denoised_sample]

DiffusionConditionalUnet1d.forward Block Flow

flowchart TD
    A[Input_x_timestep_global_cond] --> B[Encode_timestep]
    B --> C[Concat_with_global_cond]
    C --> D[Encoder_blocks_down_modules]
    D --> E[Mid_blocks]
    E --> F[Decoder_blocks_up_modules_with_skip_connections]
    F --> G[Final_Conv]
    G --> H[Output]

Normalization/Unnormalization Pipeline

flowchart LR
    A[Raw_Batch] --> B[Normalize_Inputs]
    B --> C[Model_Forward_Action_Selection]
    C --> D[Unnormalize_Outputs]
    D --> E[Environment_Training]
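
A minimal mean/std version of this round trip is sketched below; the actual Normalize/Unnormalize modules are driven by dataset statistics and support other modes (e.g. min-max scaling), so this is only the core idea.

```python
import torch


def normalize(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Map raw features to roughly zero-mean, unit-variance values."""
    return (x - mean) / (std + 1e-8)


def unnormalize(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Invert normalize() so predicted actions are back in environment units."""
    return x * (std + 1e-8) + mean


stats = {"mean": torch.tensor([0.1, -0.2]), "std": torch.tensor([0.5, 0.3])}
raw = torch.tensor([0.6, 0.1])
assert torch.allclose(unnormalize(normalize(raw, **stats), **stats), raw)
```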

Scheduler and U-Net Interaction

sequenceDiagram
    participant DM as DiffusionModel
    participant NS as NoiseScheduler
    participant UN as UNet
    DM->>NS: set_timesteps(num_inference_steps)
    loop For_each_timestep_t
        DM->>UN: Predict_model_output
        UN->>DM: Output
        DM->>NS: step(model_output, t, sample)
        NS-->>DM: prev_sample
    end

Batch Validation in compute_loss

flowchart TD
    A[compute_loss_batch] --> B{Keys_present}
    B -- No --> C[Raise_ValueError]
    B -- Yes --> D[Continue]
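
The check itself is a few lines; the required key names below are illustrative placeholders, not the exact keys used by the policy.

```python
def validate_batch_keys(batch: dict, required=("observation.state", "action")) -> None:
    """Fail early with a readable message when expected tensors are missing."""
    missing = [key for key in required if key not in batch]
    if missing:
        raise ValueError(f"Batch is missing required keys: {missing}")


validate_batch_keys({"observation.state": 0, "action": 0})   # passes silently
```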

Extensibility Points

flowchart TD
    A[DiffusionPolicy] --> B[Custom_Normalize_Unnormalize]
    A --> C[Custom_DiffusionModel]
    C --> D[Custom_UNet]
    C --> E[Custom_NoiseScheduler]
    C --> F[Custom_Encoder]

Full Object Graph

graph TD
    A[DiffusionPolicy]
    A --> B[DiffusionModel]
    B --> C[DiffusionConditionalUnet1d]
    B --> D[NoiseScheduler]
    B --> E[DiffusionRgbEncoder]
    C --> F[ResidualBlocks]
    C --> G[ConvBlocks]
    C --> H[SinusoidalPosEmb]
    E --> I[SpatialSoftmax]

Data Flow: Environment to Action

flowchart LR
    A[Environment_Observations] --> B[DiffusionPolicy_normalize_inputs]
    B --> C[Feature_Encoding_State_Image_Env]
    C --> D[DiffusionModel_UNet_Scheduler]
    D --> E[Action_Trajectory]
    E --> F[DiffusionPolicy_unnormalize_outputs]
    F --> G[Action_Output]

Error Path: Image Shape Mismatch

flowchart TD
    A[validate_features] --> B{All_images_same_shape}
    B -- No --> C[Raise_ValueError]
    B -- Yes --> D[Continue]

Error Path: Missing Dataset Stats

flowchart TD
    A[select_action] --> B{unnormalize_outputs_needs_stats}
    B -- Yes --> C[Raise_Error_dataset_stats_missing]
    B -- No --> D[Continue]