src package¶
Subpackages¶
- src.attention_maps package
- src.ood package
- src.sae package
- Submodules
- src.sae.builder module
- src.sae.config module
SAETrainingConfig
SAETrainingConfig.activation_fn
SAETrainingConfig.batch_size
SAETrainingConfig.block_size
SAETrainingConfig.device
SAETrainingConfig.early_stopping_min_delta
SAETrainingConfig.early_stopping_patience
SAETrainingConfig.expansion_factor
SAETrainingConfig.feature_dim
SAETrainingConfig.fixed_tokens
SAETrainingConfig.gradient_clip_norm
SAETrainingConfig.infer_model_params_from_cache()
SAETrainingConfig.l1_penalty
SAETrainingConfig.learning_rate
SAETrainingConfig.log_every
SAETrainingConfig.lr_schedule
SAETrainingConfig.max_sampled_tokens
SAETrainingConfig.num_epochs
SAETrainingConfig.num_tokens
SAETrainingConfig.optimizer
SAETrainingConfig.sampling_strategy
SAETrainingConfig.sampling_stride
SAETrainingConfig.save_every
SAETrainingConfig.token_dim
SAETrainingConfig.use_token_sampling
SAETrainingConfig.validate_every
SAETrainingConfig.validation_split
SAETrainingConfig.warmup_epochs
SAETrainingConfig.weight_decay
- src.sae.sae module
- src.sae.token_sampler module
- src.sae.trainer module
SAETrainer
SAETrainer.collect_activations()
SAETrainer.create_model()
SAETrainer.create_optimizer_and_scheduler()
SAETrainer.generate_model_card()
SAETrainer.load_checkpoint()
SAETrainer.push_model_to_hub()
SAETrainer.save_checkpoint()
SAETrainer.save_complete_model()
SAETrainer.train()
SAETrainer.train_step()
load_sae_from_hub()
load_sae_model()
- Module contents
MultiModalSAE
SAEBuilder
SAETrainer
SAETrainer.collect_activations()
SAETrainer.create_model()
SAETrainer.create_optimizer_and_scheduler()
SAETrainer.generate_model_card()
SAETrainer.load_checkpoint()
SAETrainer.push_model_to_hub()
SAETrainer.save_checkpoint()
SAETrainer.save_complete_model()
SAETrainer.train()
SAETrainer.train_step()
SAETrainingConfig
SAETrainingConfig.activation_fn
SAETrainingConfig.batch_size
SAETrainingConfig.block_size
SAETrainingConfig.device
SAETrainingConfig.early_stopping_min_delta
SAETrainingConfig.early_stopping_patience
SAETrainingConfig.expansion_factor
SAETrainingConfig.feature_dim
SAETrainingConfig.fixed_tokens
SAETrainingConfig.gradient_clip_norm
SAETrainingConfig.infer_model_params_from_cache()
SAETrainingConfig.l1_penalty
SAETrainingConfig.learning_rate
SAETrainingConfig.log_every
SAETrainingConfig.lr_schedule
SAETrainingConfig.max_sampled_tokens
SAETrainingConfig.num_epochs
SAETrainingConfig.num_tokens
SAETrainingConfig.optimizer
SAETrainingConfig.sampling_strategy
SAETrainingConfig.sampling_stride
SAETrainingConfig.save_every
SAETrainingConfig.token_dim
SAETrainingConfig.use_token_sampling
SAETrainingConfig.validate_every
SAETrainingConfig.validation_split
SAETrainingConfig.warmup_epochs
SAETrainingConfig.weight_decay
TokenSampler
TokenSamplerConfig
load_original_num_tokens_from_cache()
load_sae_model()
load_sae_model_simple()
- src.utils package
Module contents¶
Physical AI Attention Mapper
A package for interpretability analysis of physical AI models, including:
- Sparse Autoencoders (SAE) for feature extraction
- Attention mapping for transformer-based policies
- Token sampling and activation collection utilities
- class src.ACTPolicyWithAttention(policy, image_shapes=None, specific_decoder_token_index: int | None = None)[source]¶
Bases: object
Wrapper for ACTPolicy that provides transformer attention visualizations.
- select_action(observation: Dict[str, Tensor]) → Tuple[Tensor, Tensor, List[ndarray]] [source]¶
Extends policy.select_action to also compute attention maps.
- Args:
observation: Dictionary of observations
- Returns:
action: The predicted action tensor
attention_maps: List of attention maps, one for each image
- visualize_attention(images: List[Tensor] | None = None, attention_maps: List[ndarray] | None = None, observation: Dict[str, Tensor] | None = None, use_rgb: bool = False, overlay_alpha: float = 0.5, show_proprio_border: bool = True, proprio_border_width: int = 15) → List[ndarray] [source]¶
Create visualizations by overlaying attention maps on images.
- Args:
images: List of image tensors (optional)
attention_maps: List of attention maps (optional)
observation: Observation dict (optional, used if images not provided)
use_rgb: Whether to use RGB for visualization
overlay_alpha: Alpha value for attention overlay
- Returns:
List of visualization images as numpy arrays
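A minimal usage sketch for the wrapper, assuming a LeRobot-style ACT policy and observation dict; the import path, checkpoint path, camera key, and tensor shapes below are illustrative placeholders, not values taken from this package:

```python
import torch
from lerobot.common.policies.act.modeling_act import ACTPolicy  # assumed import path
from src import ACTPolicyWithAttention

policy = ACTPolicy.from_pretrained("outputs/train/act/checkpoints/last")  # hypothetical checkpoint
wrapped = ACTPolicyWithAttention(policy)

observation = {
    "observation.images.top": torch.rand(1, 3, 480, 640),  # dummy camera frame
    "observation.state": torch.rand(1, 6),                 # dummy proprioceptive state
}

# The middle element of the returned tuple is not documented above, so it is
# discarded here.
action, _, attention_maps = wrapped.select_action(observation)

overlays = wrapped.visualize_attention(
    observation=observation,
    attention_maps=attention_maps,
    overlay_alpha=0.5,
)
```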
- class src.MultiModalSAE(num_tokens: int, token_dim: int = 512, feature_dim: int = 4096, use_bias: bool = True, activation_fn: str = 'leaky_relu', dropout_rate: float = 0.0, device: str = 'cuda', use_bfloat16: bool = False)[source]¶
Bases: Module
Sparse Autoencoder that processes all tokens from ACT model simultaneously.
num_tokens (int): should match the number of tokens of your ACT model, or the number of sampled tokens when token sampling is used
token_dim (int): should match the dim_model hyperparameter of your ACT model
feature_dim (int): the number of features the SAE should learn to represent; usually num_tokens * token_dim * expansion_factor
- compute_loss(x: Tensor, l1_penalty: float = 0.0) → Dict[str, Tensor] [source]¶
Compute loss with optional L1 regularization.
Note: loss computation often benefits from float32 precision for stability.
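A minimal sketch of constructing the autoencoder and computing its loss on a dummy batch; the flattened input layout (batch, num_tokens * token_dim) and the loss-dict keys are assumptions, since the reference above does not spell them out:

```python
import torch
from src import MultiModalSAE

num_tokens, token_dim = 200, 512
sae = MultiModalSAE(
    num_tokens=num_tokens,
    token_dim=token_dim,
    feature_dim=num_tokens * token_dim,  # expansion_factor of 1
    activation_fn="leaky_relu",
    device="cpu",
)

x = torch.randn(8, num_tokens * token_dim)  # dummy batch of cached activations
losses = sae.compute_loss(x, l1_penalty=0.3)
for name, value in losses.items():  # key names depend on the implementation
    print(name, value.item())
```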
- class src.OODDetector(policy: ACTPolicy, sae_experiment_path: str | None = None, sae_hub_repo_id: str | None = None, ood_params_path: Path | None = None, force_ood_refresh: bool = False, device: str = 'cuda')[source]¶
Bases: object
Out-of-distribution detector based on detecting unusual attention patterns.
Builds on an SAE model trained to extract features from a given ACT policy, using the SAE's reconstruction error as a signal for flagging scenarios as out of distribution.
- fit_ood_threshold_to_validation_dataset(dataset: LeRobotDataset, std_threshold: float = 2.5, batch_size: int = 16, max_samples: int | None = None, save_path: str | None = None) → Dict[str, float] [source]¶
Calibrate the out-of-distribution detector on an unseen validation dataset.
This method runs the OOD detector on each frame in the dataset and fits a Gaussian distribution to the resulting reconstruction errors. Anything above the specified standard-deviation threshold (default σ = 2.5) is deemed out of distribution. While the default works for many datasets, we recommend tuning it to the value that works best for your own data.
- Args:
dataset: Validation dataset to fit on
std_threshold: Number of standard deviations from the mean to use as the threshold
batch_size: Batch size for processing
max_samples: Maximum number of samples to process (None for all)
save_path: Path to save OOD parameters (if None, uses self.ood_params_path)
- Returns:
Dictionary with fitted parameters
- get_reconstruction_error(observation: dict) → float [source]¶
Get the reconstruction error of the observation using the SAE model.
- Args:
observation: Input observation dictionary (same format as policy input)
- Returns:
Reconstruction error (MSE loss)
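Putting the detector together, a calibration-then-inference sketch; the experiment path, dataset repo id, and the "threshold" key are placeholders or assumptions, and `policy` and `observation` are as in the wrapper example above:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # assumed import path
from src import OODDetector

detector = OODDetector(
    policy=policy,  # a loaded ACTPolicy, as in the wrapper example above
    sae_experiment_path="output/sae_my_experiment",  # hypothetical path
    device="cuda",
)

# Calibrate on held-out validation data (fits a Gaussian to reconstruction errors).
val_dataset = LeRobotDataset("user/my_validation_set")  # hypothetical repo id
params = detector.fit_ood_threshold_to_validation_dataset(
    val_dataset, std_threshold=2.5, batch_size=16
)

# At runtime, flag observations whose reconstruction error exceeds the threshold.
error = detector.get_reconstruction_error(observation)
is_ood = error > params["threshold"]  # key name is an assumption
```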
- class src.SAEBuilder(device: str = 'cuda')[source]¶
Bases: object
Builder class for loading SAE models with their configurations. Provides convenient methods to load trained SAE models from standard paths.
- classmethod from_default_path(experiment_name: str, base_output_dir: str = 'output', device: str = 'cuda') → SAEBuilder [source]¶
Create SAEBuilder and load model from default output structure.
- Args:
experiment_name: Name of experiment (e.g., “sae_drop_footbag_into_di_838a8c8b”)
base_output_dir: Base output directory
device: Device to load model on
- Returns:
SAEBuilder instance with loaded model
- load_from_experiment(experiment_path: str, checkpoint: str = 'latest', config_filename: str = 'config.json') → MultiModalSAE [source]¶
Load SAE model from experiment directory.
- Args:
experiment_path: Path to experiment directory (e.g., “output/sae_drop_footbag_into_di_838a8c8b”)
checkpoint: Which checkpoint to load - ‘best’, ‘latest’, or a specific epoch number
config_filename: Name of config file
- Returns:
Loaded SAE model
- load_from_files(model_path: str, config_path: str) → MultiModalSAE [source]¶
Load SAE model from specific model and config files.
- Args:
model_path: Path to safetensors model file
config_path: Path to config.json file
- Returns:
Loaded SAE model
- load_from_hub(repo_id: str, filename: str = 'model.safetensors', config_filename: str = 'config.json', revision: str = 'main', cache_dir: str | None = None, force_download: bool = False, token: str | None = None) → MultiModalSAE [source]¶
Load SAE model from Hugging Face Hub.
- Args:
repo_id: Repository ID on Hugging Face Hub
filename: Model filename to download
config_filename: Config filename to download
revision: Git revision (branch/tag/commit)
cache_dir: Local cache directory
force_download: Force re-download even if cached
token: Hugging Face token for private repos
- Returns:
Loaded SAE model
- load_with_auto_config(model_path: str, config_path: str | None = None) → MultiModalSAE [source]¶
Load SAE model with automatic config discovery.
- Args:
model_path: Path to safetensors model file
config_path: Optional path to config.json. If None, searches automatically
- Returns:
Loaded SAE model
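The three loading paths side by side, as a sketch; the experiment names and hub repo id are placeholders:

```python
from src import SAEBuilder

# 1. From the default output layout (output/<experiment_name>/...):
builder = SAEBuilder.from_default_path("sae_my_experiment", device="cuda")

# 2. From an explicit experiment directory, choosing the best checkpoint:
sae = SAEBuilder(device="cuda").load_from_experiment(
    "output/sae_my_experiment", checkpoint="best"
)

# 3. From the Hugging Face Hub:
sae = SAEBuilder(device="cuda").load_from_hub("user/my-sae-repo")
```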
- class src.SAETrainer(repo_id: str, policy_path: Path, batch_size: int = 16, num_workers: int = 4, output_directory: Path = 'output', resume_checkpoint: Path | None = None, activation_cache_path: str = '~/.cache/physical_ai_interpretability/sae_activations', force_cache_refresh: bool = False, use_wandb: bool = False, wandb_project_name: str = 'physical_ai_interpretability', sae_config: SAETrainingConfig | None = None, upload_to_hub: bool = False, hub_repo_id: str | None = None, hub_private: bool = True, hub_license: str = 'mit', hub_tags: list | None = None)[source]¶
Bases: object
- collect_activations()[source]¶
Collect activations and return cached dataloader with resumption support
- create_optimizer_and_scheduler(model: Module, train_loader: DataLoader)[source]¶
Create optimizer and learning rate scheduler
- load_checkpoint(model: Module, optimizer: Optimizer | None = None, scheduler=None, checkpoint_path: str | None = None, load_best: bool = False)[source]¶
Load model checkpoint from safetensors format
- save_checkpoint(model: Module, optimizer: Optimizer, scheduler, epoch: int, is_best: bool = False)[source]¶
Save model checkpoint using safetensors for model weights
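A minimal end-to-end training sketch; the dataset repo id and policy checkpoint path are placeholders, and `train()` is assumed to drive activation collection and the optimization loop (see SAETrainingConfig below for the hyperparameters used):

```python
from pathlib import Path
from src import SAETrainer, SAETrainingConfig

config = SAETrainingConfig(expansion_factor=2, l1_penalty=0.3, num_epochs=20)
trainer = SAETrainer(
    repo_id="user/my_dataset",                    # hypothetical dataset
    policy_path=Path("outputs/train/act_so100"),  # hypothetical ACT checkpoint
    output_directory=Path("output"),
    sae_config=config,
)
trainer.train()  # assumed to collect activations first, then train the SAE
```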
- class src.SAETrainingConfig(num_tokens: int | None = None, token_dim: int | None = None, expansion_factor: float = 1, activation_fn: str = 'relu', use_token_sampling: bool = True, fixed_tokens: list = <factory>, sampling_strategy: str = 'block_average', sampling_stride: int = 8, max_sampled_tokens: int = 200, block_size: int = 8, batch_size: int = 128, learning_rate: float = 0.0001, num_epochs: int = 20, validation_split: float = 0.1, l1_penalty: float = 0.3, optimizer: str = 'adam', weight_decay: float = 1e-05, lr_schedule: str = 'constant', warmup_epochs: int = 2, gradient_clip_norm: float = 1.0, early_stopping_patience: int = 10, early_stopping_min_delta: float = 1e-05, log_every: int = 5, save_every: int = 1000, validate_every: int = 500, device: str = 'cpu')[source]¶
Bases: object
Configuration for SAE training
- activation_fn: str = 'relu'¶
- batch_size: int = 128¶
- block_size: int = 8¶
- device: str = 'cpu'¶
- early_stopping_min_delta: float = 1e-05¶
- early_stopping_patience: int = 10¶
- expansion_factor: float = 1¶
- property feature_dim: int | None¶
Calculate feature dimension based on expansion factor
- fixed_tokens: list¶
- gradient_clip_norm: float = 1.0¶
- infer_model_params_from_cache(cache_path: str, token_sampler_config=None)[source]¶
Infer model parameters from cached activation data.
- Args:
cache_path: Path to cached activation data
token_sampler_config: TokenSamplerConfig object for token sampling
- Returns:
Self for method chaining
- l1_penalty: float = 0.3¶
- learning_rate: float = 0.0001¶
- log_every: int = 5¶
- lr_schedule: str = 'constant'¶
- max_sampled_tokens: int = 200¶
- num_epochs: int = 20¶
- num_tokens: int | None = None¶
- optimizer: str = 'adam'¶
- sampling_strategy: str = 'block_average'¶
- sampling_stride: int = 8¶
- save_every: int = 1000¶
- token_dim: int | None = None¶
- use_token_sampling: bool = True¶
- validate_every: int = 500¶
- validation_split: float = 0.1¶
- warmup_epochs: int = 2¶
- weight_decay: float = 1e-05¶
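A sketch of configuring training and inferring the model dimensions from a cached activation run; the cache path is a placeholder. Since infer_model_params_from_cache() returns self, the calls chain:

```python
from src import SAETrainingConfig, TokenSamplerConfig

config = SAETrainingConfig(expansion_factor=2, l1_penalty=0.3).infer_model_params_from_cache(
    "~/.cache/physical_ai_interpretability/sae_activations",  # hypothetical cache path
    token_sampler_config=TokenSamplerConfig(max_sampled_tokens=200),
)

# feature_dim is derived from the inferred num_tokens/token_dim and the
# expansion factor (roughly num_tokens * token_dim * expansion_factor).
print(config.feature_dim)
```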
- class src.TokenSampler(config: TokenSamplerConfig, total_tokens: int)[source]¶
Bases: object
Handles consistent token sampling strategies
- class src.TokenSamplerConfig(fixed_tokens: List[int] = None, sampling_strategy: str = 'block_average', sampling_stride: int = 8, max_sampled_tokens: int = 100, random_seed: int = 42, block_size: int = 8)[source]¶
Bases: object
- block_size: int = 8¶
- fixed_tokens: List[int] = None¶
- max_sampled_tokens: int = 100¶
- random_seed: int = 42¶
- sampling_strategy: str = 'block_average'¶
- sampling_stride: int = 8¶
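A configuration sketch; no sampling method is called because the reference above documents only the sampler's constructor, and the total token count is a placeholder:

```python
from src import TokenSampler, TokenSamplerConfig

sampler_config = TokenSamplerConfig(
    sampling_strategy="block_average",  # average tokens within fixed-size blocks
    block_size=8,
    max_sampled_tokens=100,
    random_seed=42,  # fixed seed keeps sampling consistent across runs
)
sampler = TokenSampler(sampler_config, total_tokens=1200)  # hypothetical ACT token count
```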
- src.create_default_ood_params_path(experiment_name: str, base_dir: str = 'output') → str [source]¶
Create standard path for OOD parameters based on experiment name.
- Args:
experiment_name: SAE experiment name
base_dir: Base output directory
- Returns:
Path string for OOD parameters file
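Usage sketch (the experiment name is a placeholder):

```python
from src import create_default_ood_params_path

ood_params_path = create_default_ood_params_path("sae_my_experiment", base_dir="output")
```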
- src.load_sae_model(model_path: str, config_path: str | None = None, device: str = 'cuda')[source]¶
Standalone function to load a trained SAE model from a safetensors checkpoint
- Args:
model_path: Path to the safetensors model file
config_path: Optional path to config.json file. If None, tries to find it automatically
device: Device to load the model on
- Returns:
Loaded SAE model
- src.load_sae_model_simple(experiment_path: str, checkpoint: str = 'best', device: str = 'cuda') → MultiModalSAE [source]¶
Simple convenience function to load SAE model from experiment directory.
- Args:
experiment_path: Path to experiment directory
checkpoint: Which checkpoint to load
device: Device to load on
- Returns:
Loaded SAE model
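The two standalone loaders in one sketch; paths are placeholders:

```python
from src import load_sae_model, load_sae_model_simple

# Explicit safetensors file, config discovered automatically:
sae = load_sae_model("output/sae_my_experiment/model.safetensors")

# Convenience form: point at the experiment directory and pick a checkpoint:
sae = load_sae_model_simple("output/sae_my_experiment", checkpoint="best")
```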
- src.make_dataset_without_config(repo_id: str, action_delta_indices: List, observation_delta_indices: List | None = None, root: str | None = None, video_backend: str = 'pyav', episodes: list[int] | None = None, revision: str | None = None, use_imagenet_stats: bool = True) → LeRobotDataset | MultiLeRobotDataset [source]¶
Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
- Args:
repo_id: Dataset repository ID
action_delta_indices: Delta timestamp indices for actions
observation_delta_indices: Optional delta timestamp indices for observations
root: Optional local root directory for the dataset
video_backend: Video decoding backend
episodes: Optional list of episode indices to load
revision: Optional dataset revision
use_imagenet_stats: Whether to apply ImageNet normalization statistics
- Raises:
NotImplementedError: The MultiLeRobotDataset is currently deactivated.
- Returns:
LeRobotDataset | MultiLeRobotDataset
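A creation sketch, assuming LeRobot's delta-index convention of frame offsets relative to the current step; the repo id and chunk length are placeholders:

```python
from src import make_dataset_without_config

dataset = make_dataset_without_config(
    repo_id="user/my_dataset",              # hypothetical dataset
    action_delta_indices=list(range(100)),  # e.g. a 100-step action chunk
    video_backend="pyav",
)
```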