src.sae package¶
Submodules¶
src.sae.builder module¶
- class src.sae.builder.SAEBuilder(device: str = 'cuda')[source]¶
Bases:
object
Builder class for loading SAE models with their configurations. Provides convenient methods to load trained SAE models from standard paths.
- classmethod from_default_path(experiment_name: str, base_output_dir: str = 'output', device: str = 'cuda') SAEBuilder [source]¶
Create SAEBuilder and load model from default output structure.
- Args:
experiment_name: Name of experiment (e.g., "sae_drop_footbag_into_di_838a8c8b")
base_output_dir: Base output directory
device: Device to load model on
- Returns:
SAEBuilder instance with loaded model
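A minimal usage sketch for the default output layout, reusing the example experiment name from above::

    from src.sae import SAEBuilder

    # Loads the model from output/sae_drop_footbag_into_di_838a8c8b
    builder = SAEBuilder.from_default_path(
        "sae_drop_footbag_into_di_838a8c8b",
        base_output_dir="output",
        device="cpu",
    )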
- load_from_experiment(experiment_path: str, checkpoint: str = 'latest', config_filename: str = 'config.json') MultiModalSAE [source]¶
Load SAE model from experiment directory.
- Args:
experiment_path: Path to experiment directory (e.g., "output/sae_drop_footbag_into_di_838a8c8b")
checkpoint: Which checkpoint to load - 'best', 'latest', or a specific epoch number
config_filename: Name of config file
- Returns:
Loaded SAE model
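A sketch of instance-based loading from an experiment directory, using the example path above::

    from src.sae import SAEBuilder

    builder = SAEBuilder(device="cpu")
    sae = builder.load_from_experiment(
        "output/sae_drop_footbag_into_di_838a8c8b",
        checkpoint="best",   # or "latest", or a specific epoch number
    )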
- load_from_files(model_path: str, config_path: str) MultiModalSAE [source]¶
Load SAE model from specific model and config files.
- Args:
model_path: Path to safetensors model file
config_path: Path to config.json file
- Returns:
Loaded SAE model
- load_from_hub(repo_id: str, filename: str = 'model.safetensors', config_filename: str = 'config.json', revision: str = 'main', cache_dir: str | None = None, force_download: bool = False, token: str | None = None) MultiModalSAE [source]¶
Load SAE model from Hugging Face Hub.
- Args:
repo_id: Repository ID on Hugging Face Hub
filename: Model filename to download
config_filename: Config filename to download
revision: Git revision (branch/tag/commit)
cache_dir: Local cache directory
force_download: Force re-download even if cached
token: Hugging Face token for private repos
- Returns:
Loaded SAE model
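A sketch for loading from the Hub; the repository ID below is a placeholder, and a token is only needed for private repos::

    from src.sae import SAEBuilder

    sae = SAEBuilder(device="cpu").load_from_hub(
        "your-username/your-sae-repo",   # placeholder repo ID
        revision="main",
        token=None,                      # set a Hugging Face token for private repos
    )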
- load_with_auto_config(model_path: str, config_path: str | None = None) MultiModalSAE [source]¶
Load SAE model with automatic config discovery.
- Args:
model_path: Path to safetensors model file
config_path: Optional path to config.json. If None, searches automatically
- Returns:
Loaded SAE model
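A sketch contrasting explicit file paths with automatic config discovery; both paths are placeholders::

    from src.sae import SAEBuilder

    builder = SAEBuilder(device="cpu")

    # Explicit model + config paths
    sae = builder.load_from_files(
        "output/sae_example/model.safetensors",
        "output/sae_example/config.json",
    )

    # Let the builder search for config.json itself
    sae = builder.load_with_auto_config("output/sae_example/model.safetensors")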
- src.sae.builder.load_sae_model_simple(experiment_path: str, checkpoint: str = 'best', device: str = 'cuda') MultiModalSAE [source]¶
Simple convenience function to load SAE model from experiment directory.
- Args:
experiment_path: Path to experiment directory
checkpoint: Which checkpoint to load
device: Device to load on
- Returns:
Loaded SAE model
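A one-call sketch using the convenience function, with the example experiment path from above::

    from src.sae import load_sae_model_simple

    sae = load_sae_model_simple(
        "output/sae_drop_footbag_into_di_838a8c8b",
        checkpoint="best",
        device="cpu",
    )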
src.sae.config module¶
- class src.sae.config.SAETrainingConfig(num_tokens: int | None = None, token_dim: int | None = None, expansion_factor: float = 1, activation_fn: str = 'relu', use_token_sampling: bool = True, fixed_tokens: list = <factory>, sampling_strategy: str = 'block_average', sampling_stride: int = 8, max_sampled_tokens: int = 200, block_size: int = 8, batch_size: int = 128, learning_rate: float = 0.0001, num_epochs: int = 20, validation_split: float = 0.1, l1_penalty: float = 0.3, optimizer: str = 'adam', weight_decay: float = 1e-05, lr_schedule: str = 'constant', warmup_epochs: int = 2, gradient_clip_norm: float = 1.0, early_stopping_patience: int = 10, early_stopping_min_delta: float = 1e-05, log_every: int = 5, save_every: int = 1000, validate_every: int = 500, device: str = 'cpu')[source]¶
Bases:
object
Configuration for SAE training
- activation_fn: str = 'relu'¶
- batch_size: int = 128¶
- block_size: int = 8¶
- device: str = 'cpu'¶
- early_stopping_min_delta: float = 1e-05¶
- early_stopping_patience: int = 10¶
- expansion_factor: float = 1¶
- property feature_dim: int | None¶
Calculate feature dimension based on expansion factor
- fixed_tokens: list¶
- gradient_clip_norm: float = 1.0¶
- infer_model_params_from_cache(cache_path: str, token_sampler_config=None)[source]¶
Infer model parameters from cached activation data.
- Args:
cache_path: Path to cached activation data
token_sampler_config: TokenSamplerConfig object for token sampling
- Returns:
Self for method chaining
- l1_penalty: float = 0.3¶
- learning_rate: float = 0.0001¶
- log_every: int = 5¶
- lr_schedule: str = 'constant'¶
- max_sampled_tokens: int = 200¶
- num_epochs: int = 20¶
- num_tokens: int | None = None¶
- optimizer: str = 'adam'¶
- sampling_strategy: str = 'block_average'¶
- sampling_stride: int = 8¶
- save_every: int = 1000¶
- token_dim: int | None = None¶
- use_token_sampling: bool = True¶
- validate_every: int = 500¶
- validation_split: float = 0.1¶
- warmup_epochs: int = 2¶
- weight_decay: float = 1e-05¶
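A sketch of building a training config by hand and of inferring model dimensions from cached activations; the cache path is a placeholder, and the exact feature_dim arithmetic is an assumption based on the MultiModalSAE notes below::

    from src.sae import SAETrainingConfig

    config = SAETrainingConfig(
        num_tokens=200,
        token_dim=512,
        expansion_factor=2,
        l1_penalty=0.3,
        device="cpu",
    )
    # Derived from expansion_factor; presumably num_tokens * token_dim * expansion_factor
    print(config.feature_dim)

    # Or infer num_tokens / token_dim from a cache (the method returns self for chaining)
    config = SAETrainingConfig().infer_model_params_from_cache(
        "/path/to/sae_activations"   # placeholder cache path
    )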
src.sae.sae module¶
- class src.sae.sae.MultiModalSAE(num_tokens: int, token_dim: int = 512, feature_dim: int = 4096, use_bias: bool = True, activation_fn: str = 'leaky_relu', dropout_rate: float = 0.0, device: str = 'cuda', use_bfloat16: bool = False)[source]¶
Bases:
Module
Sparse Autoencoder that processes all tokens from ACT model simultaneously.
num_tokens (int): should match the number of tokens in your ACT model, or the number of sampled tokens when token sampling is used
token_dim (int): should match the dim_model hyperparameter of your ACT model
feature_dim (int): the number of features the SAE should learn to represent; usually num_tokens * token_dim * expansion_factor
- compute_loss(x: Tensor, l1_penalty: float = 0.0) Dict[str, Tensor] [source]¶
Compute loss with optional L1 regularization. Note: loss computation often benefits from float32 precision for stability.
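A minimal construction and loss sketch; the (batch, num_tokens, token_dim) input layout is an assumption, so check it against your cached activations::

    import torch
    from src.sae import MultiModalSAE

    sae = MultiModalSAE(num_tokens=200, token_dim=512, feature_dim=4096, device="cpu")

    x = torch.randn(8, 200, 512)                   # assumed layout: (batch, num_tokens, token_dim)
    losses = sae.compute_loss(x, l1_penalty=0.3)   # dict of loss tensors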
src.sae.token_sampler module¶
- class src.sae.token_sampler.TokenSampler(config: TokenSamplerConfig, total_tokens: int)[source]¶
Bases:
object
Handles consistent token sampling strategies
- class src.sae.token_sampler.TokenSamplerConfig(fixed_tokens: List[int] = None, sampling_strategy: str = 'block_average', sampling_stride: int = 8, max_sampled_tokens: int = 100, random_seed: int = 42, block_size: int = 8)[source]¶
Bases:
object
- block_size: int = 8¶
- fixed_tokens: List[int] = None¶
- max_sampled_tokens: int = 100¶
- random_seed: int = 42¶
- sampling_strategy: str = 'block_average'¶
- sampling_stride: int = 8¶
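A sketch of wiring a sampling config to a sampler; total_tokens would be the ACT model's full token count, and the value below is illustrative::

    from src.sae import TokenSampler, TokenSamplerConfig

    config = TokenSamplerConfig(
        sampling_strategy="block_average",
        block_size=8,
        max_sampled_tokens=100,
    )
    sampler = TokenSampler(config, total_tokens=512)   # illustrative token count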
src.sae.trainer module¶
- class src.sae.trainer.SAETrainer(repo_id: str, policy_path: Path, batch_size: int = 16, num_workers: int = 4, output_directory: Path = 'output', resume_checkpoint: Path | None = None, activation_cache_path: str = '/home/runner/.cache/physical_ai_interpretability/sae_activations', force_cache_refresh: bool = False, use_wandb: bool = False, wandb_project_name: str = 'physical_ai_interpretability', sae_config: SAETrainingConfig | None = None, upload_to_hub: bool = False, hub_repo_id: str | None = None, hub_private: bool = True, hub_license: str = 'mit', hub_tags: list | None = None)[source]¶
Bases:
object
- collect_activations()[source]¶
Collect activations and return cached dataloader with resumption support
- create_optimizer_and_scheduler(model: Module, train_loader: DataLoader)[source]¶
Create optimizer and learning rate scheduler
- load_checkpoint(model: Module, optimizer: Optimizer | None = None, scheduler=None, checkpoint_path: str | None = None, load_best: bool = False)[source]¶
Load model checkpoint from safetensors format
- save_checkpoint(model: Module, optimizer: Optimizer, scheduler, epoch: int, is_best: bool = False)[source]¶
Save model checkpoint using safetensors for model weights
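A construction sketch; the dataset repo ID and policy path are placeholders and should be replaced with your own::

    from pathlib import Path
    from src.sae import SAETrainer, SAETrainingConfig

    trainer = SAETrainer(
        repo_id="your-username/your-dataset",     # placeholder
        policy_path=Path("path/to/act_policy"),   # placeholder
        output_directory=Path("output"),
        sae_config=SAETrainingConfig(device="cpu"),
    )

    # Collect (or reuse cached) activations before training
    dataloader = trainer.collect_activations()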
- src.sae.trainer.load_sae_from_hub(repo_id: str, filename: str = 'model.safetensors', config_filename: str = 'config.json', revision: str = 'main', cache_dir: str | None = None, force_download: bool = False, token: str | None = None, device: str = 'cuda')[source]¶
Load SAE model from Hugging Face Hub
- Args:
repo_id: Repository ID on Hugging Face Hub
filename: Model filename to download
config_filename: Config filename to download
revision: Git revision (branch/tag/commit)
cache_dir: Local cache directory
force_download: Force re-download even if cached
token: Hugging Face token for private repos
device: Device to load model on
- Returns:
Loaded SAE model
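A sketch with a placeholder repository ID::

    from src.sae.trainer import load_sae_from_hub

    sae = load_sae_from_hub(
        "your-username/your-sae-repo",   # placeholder repo ID
        device="cpu",
    )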
- src.sae.trainer.load_sae_model(model_path: str, config_path: str | None = None, device: str = 'cuda')[source]¶
Standalone function to load a trained SAE model from safetensors checkpoint
- Args:
model_path: Path to the safetensors model file
config_path: Optional path to config.json file. If None, tries to find it automatically
device: Device to load the model on
- Returns:
Loaded SAE model
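A sketch with a placeholder checkpoint path; when config_path is omitted, the loader tries to find config.json automatically::

    from src.sae import load_sae_model

    sae = load_sae_model(
        "output/sae_example/model.safetensors",   # placeholder path
        device="cpu",
    )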
Module contents¶
- class src.sae.MultiModalSAE(num_tokens: int, token_dim: int = 512, feature_dim: int = 4096, use_bias: bool = True, activation_fn: str = 'leaky_relu', dropout_rate: float = 0.0, device: str = 'cuda', use_bfloat16: bool = False)[source]¶
Bases:
Module
Sparse Autoencoder that processes all tokens from ACT model simultaneously.
num_tokens (int): should match the number of tokens in your ACT model, or the number of sampled tokens when token sampling is used
token_dim (int): should match the dim_model hyperparameter of your ACT model
feature_dim (int): the number of features the SAE should learn to represent; usually num_tokens * token_dim * expansion_factor
- compute_loss(x: Tensor, l1_penalty: float = 0.0) Dict[str, Tensor] [source]¶
Compute loss with optional L1 regularization. Note: loss computation often benefits from float32 precision for stability.
- class src.sae.SAEBuilder(device: str = 'cuda')[source]¶
Bases:
object
Builder class for loading SAE models with their configurations. Provides convenient methods to load trained SAE models from standard paths.
- classmethod from_default_path(experiment_name: str, base_output_dir: str = 'output', device: str = 'cuda') SAEBuilder [source]¶
Create SAEBuilder and load model from default output structure.
- Args:
experiment_name: Name of experiment (e.g., "sae_drop_footbag_into_di_838a8c8b")
base_output_dir: Base output directory
device: Device to load model on
- Returns:
SAEBuilder instance with loaded model
- load_from_experiment(experiment_path: str, checkpoint: str = 'latest', config_filename: str = 'config.json') MultiModalSAE [source]¶
Load SAE model from experiment directory.
- Args:
experiment_path: Path to experiment directory (e.g., "output/sae_drop_footbag_into_di_838a8c8b")
checkpoint: Which checkpoint to load - 'best', 'latest', or a specific epoch number
config_filename: Name of config file
- Returns:
Loaded SAE model
- load_from_files(model_path: str, config_path: str) MultiModalSAE [source]¶
Load SAE model from specific model and config files.
- Args:
model_path: Path to safetensors model file
config_path: Path to config.json file
- Returns:
Loaded SAE model
- load_from_hub(repo_id: str, filename: str = 'model.safetensors', config_filename: str = 'config.json', revision: str = 'main', cache_dir: str | None = None, force_download: bool = False, token: str | None = None) MultiModalSAE [source]¶
Load SAE model from Hugging Face Hub.
- Args:
repo_id: Repository ID on Hugging Face Hub
filename: Model filename to download
config_filename: Config filename to download
revision: Git revision (branch/tag/commit)
cache_dir: Local cache directory
force_download: Force re-download even if cached
token: Hugging Face token for private repos
- Returns:
Loaded SAE model
- load_with_auto_config(model_path: str, config_path: str | None = None) MultiModalSAE [source]¶
Load SAE model with automatic config discovery.
- Args:
model_path: Path to safetensors model file
config_path: Optional path to config.json. If None, searches automatically
- Returns:
Loaded SAE model
- class src.sae.SAETrainer(repo_id: str, policy_path: Path, batch_size: int = 16, num_workers: int = 4, output_directory: Path = 'output', resume_checkpoint: Path | None = None, activation_cache_path: str = '/home/runner/.cache/physical_ai_interpretability/sae_activations', force_cache_refresh: bool = False, use_wandb: bool = False, wandb_project_name: str = 'physical_ai_interpretability', sae_config: SAETrainingConfig | None = None, upload_to_hub: bool = False, hub_repo_id: str | None = None, hub_private: bool = True, hub_license: str = 'mit', hub_tags: list | None = None)[source]¶
Bases:
object
- collect_activations()[source]¶
Collect activations and return cached dataloader with resumption support
- create_optimizer_and_scheduler(model: Module, train_loader: DataLoader)[source]¶
Create optimizer and learning rate scheduler
- load_checkpoint(model: Module, optimizer: Optimizer | None = None, scheduler=None, checkpoint_path: str | None = None, load_best: bool = False)[source]¶
Load model checkpoint from safetensors format
- save_checkpoint(model: Module, optimizer: Optimizer, scheduler, epoch: int, is_best: bool = False)[source]¶
Save model checkpoint using safetensors for model weights
- class src.sae.SAETrainingConfig(num_tokens: int | None = None, token_dim: int | None = None, expansion_factor: float = 1, activation_fn: str = 'relu', use_token_sampling: bool = True, fixed_tokens: list = <factory>, sampling_strategy: str = 'block_average', sampling_stride: int = 8, max_sampled_tokens: int = 200, block_size: int = 8, batch_size: int = 128, learning_rate: float = 0.0001, num_epochs: int = 20, validation_split: float = 0.1, l1_penalty: float = 0.3, optimizer: str = 'adam', weight_decay: float = 1e-05, lr_schedule: str = 'constant', warmup_epochs: int = 2, gradient_clip_norm: float = 1.0, early_stopping_patience: int = 10, early_stopping_min_delta: float = 1e-05, log_every: int = 5, save_every: int = 1000, validate_every: int = 500, device: str = 'cpu')[source]¶
Bases:
object
Configuration for SAE training
- activation_fn: str = 'relu'¶
- batch_size: int = 128¶
- block_size: int = 8¶
- device: str = 'cpu'¶
- early_stopping_min_delta: float = 1e-05¶
- early_stopping_patience: int = 10¶
- expansion_factor: float = 1¶
- property feature_dim: int | None¶
Calculate feature dimension based on expansion factor
- fixed_tokens: list¶
- gradient_clip_norm: float = 1.0¶
- infer_model_params_from_cache(cache_path: str, token_sampler_config=None)[source]¶
Infer model parameters from cached activation data.
- Args:
cache_path: Path to cached activation data
token_sampler_config: TokenSamplerConfig object for token sampling
- Returns:
Self for method chaining
- l1_penalty: float = 0.3¶
- learning_rate: float = 0.0001¶
- log_every: int = 5¶
- lr_schedule: str = 'constant'¶
- max_sampled_tokens: int = 200¶
- num_epochs: int = 20¶
- num_tokens: int | None = None¶
- optimizer: str = 'adam'¶
- sampling_strategy: str = 'block_average'¶
- sampling_stride: int = 8¶
- save_every: int = 1000¶
- token_dim: int | None = None¶
- use_token_sampling: bool = True¶
- validate_every: int = 500¶
- validation_split: float = 0.1¶
- warmup_epochs: int = 2¶
- weight_decay: float = 1e-05¶
- class src.sae.TokenSampler(config: TokenSamplerConfig, total_tokens: int)[source]¶
Bases:
object
Handles consistent token sampling strategies
- class src.sae.TokenSamplerConfig(fixed_tokens: List[int] = None, sampling_strategy: str = 'block_average', sampling_stride: int = 8, max_sampled_tokens: int = 100, random_seed: int = 42, block_size: int = 8)[source]¶
Bases:
object
- block_size: int = 8¶
- fixed_tokens: List[int] = None¶
- max_sampled_tokens: int = 100¶
- random_seed: int = 42¶
- sampling_strategy: str = 'block_average'¶
- sampling_stride: int = 8¶
- src.sae.load_original_num_tokens_from_cache(cache_path: str) int | None [source]¶
Load the original number of tokens from cached activation metadata.
- Args:
cache_path: Path to the activation cache directory
- Returns:
Original number of tokens if found, None otherwise
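A sketch with a placeholder cache directory::

    from src.sae import load_original_num_tokens_from_cache

    num_tokens = load_original_num_tokens_from_cache("/path/to/sae_activations")
    if num_tokens is None:
        print("No token-count metadata found in the cache")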
- src.sae.load_sae_model(model_path: str, config_path: str | None = None, device: str = 'cuda')[source]¶
Standalone function to load a trained SAE model from safetensors checkpoint
- Args:
model_path: Path to the safetensors model file
config_path: Optional path to config.json file. If None, tries to find it automatically
device: Device to load the model on
- Returns:
Loaded SAE model
- src.sae.load_sae_model_simple(experiment_path: str, checkpoint: str = 'best', device: str = 'cuda') MultiModalSAE [source]¶
Simple convenience function to load SAE model from experiment directory.
- Args:
experiment_path: Path to experiment directory
checkpoint: Which checkpoint to load
device: Device to load on
- Returns:
Loaded SAE model