src.sae package

Submodules

src.sae.builder module

class src.sae.builder.SAEBuilder(device: str = 'cuda')[source]

Bases: object

Builder class for loading SAE models with their configurations. Provides convenient methods to load trained SAE models from standard paths.

classmethod from_default_path(experiment_name: str, base_output_dir: str = 'output', device: str = 'cuda') SAEBuilder[source]

Create SAEBuilder and load model from default output structure.

Args:

experiment_name: Name of experiment (e.g., “sae_drop_footbag_into_di_838a8c8b”)
base_output_dir: Base output directory
device: Device to load model on

Returns:

SAEBuilder instance with loaded model
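
A minimal usage sketch (the experiment name is the example from the docstring; substitute your own run):

from src.sae import SAEBuilder

builder = SAEBuilder.from_default_path(
    experiment_name="sae_drop_footbag_into_di_838a8c8b",
    base_output_dir="output",
    device="cuda",
)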

load_from_experiment(experiment_path: str, checkpoint: str = 'latest', config_filename: str = 'config.json') MultiModalSAE[source]

Load SAE model from experiment directory.

Args:

experiment_path: Path to experiment directory (e.g., “output/sae_drop_footbag_into_di_838a8c8b”)
checkpoint: Which checkpoint to load - ‘best’, ‘latest’, or specific epoch number
config_filename: Name of config file

Returns:

Loaded SAE model
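
For example, to load the best checkpoint from an experiment directory (the path is illustrative):

builder = SAEBuilder(device="cuda")
sae = builder.load_from_experiment(
    experiment_path="output/sae_drop_footbag_into_di_838a8c8b",
    checkpoint="best",  # or 'latest', or a specific epoch number
)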

load_from_files(model_path: str, config_path: str) MultiModalSAE[source]

Load SAE model from specific model and config files.

Args:

model_path: Path to safetensors model file
config_path: Path to config.json file

Returns:

Loaded SAE model

load_from_hub(repo_id: str, filename: str = 'model.safetensors', config_filename: str = 'config.json', revision: str = 'main', cache_dir: str | None = None, force_download: bool = False, token: str | None = None) MultiModalSAE[source]

Load SAE model from Hugging Face Hub.

Args:

repo_id: Repository ID on Hugging Face Hub
filename: Model filename to download
config_filename: Config filename to download
revision: Git revision (branch/tag/commit)
cache_dir: Local cache directory
force_download: Force re-download even if cached
token: Hugging Face token for private repos

Returns:

Loaded SAE model
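
A short sketch of loading from the Hub (the repository ID is a placeholder):

builder = SAEBuilder(device="cuda")
sae = builder.load_from_hub(
    repo_id="username/my-sae",  # placeholder; use your own repository ID
    revision="main",
)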

load_with_auto_config(model_path: str, config_path: str | None = None) MultiModalSAE[source]

Load SAE model with automatic config discovery.

Args:

model_path: Path to safetensors model file
config_path: Optional path to config.json. If None, searches automatically

Returns:

Loaded SAE model

src.sae.builder.load_sae_model_simple(experiment_path: str, checkpoint: str = 'best', device: str = 'cuda') MultiModalSAE[source]

Simple convenience function to load SAE model from experiment directory.

Args:

experiment_path: Path to experiment directory
checkpoint: Which checkpoint to load
device: Device to load on

Returns:

Loaded SAE model
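
For example (the experiment path is illustrative):

from src.sae import load_sae_model_simple

sae = load_sae_model_simple(
    experiment_path="output/sae_drop_footbag_into_di_838a8c8b",
    checkpoint="best",
    device="cuda",
)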

src.sae.config module

class src.sae.config.SAETrainingConfig(num_tokens: int | None = None, token_dim: int | None = None, expansion_factor: float = 1, activation_fn: str = 'relu', use_token_sampling: bool = True, fixed_tokens: list = <factory>, sampling_strategy: str = 'block_average', sampling_stride: int = 8, max_sampled_tokens: int = 200, block_size: int = 8, batch_size: int = 128, learning_rate: float = 0.0001, num_epochs: int = 20, validation_split: float = 0.1, l1_penalty: float = 0.3, optimizer: str = 'adam', weight_decay: float = 1e-05, lr_schedule: str = 'constant', warmup_epochs: int = 2, gradient_clip_norm: float = 1.0, early_stopping_patience: int = 10, early_stopping_min_delta: float = 1e-05, log_every: int = 5, save_every: int = 1000, validate_every: int = 500, device: str = 'cpu')[source]

Bases: object

Configuration for SAE training

activation_fn: str = 'relu'
batch_size: int = 128
block_size: int = 8
device: str = 'cpu'
early_stopping_min_delta: float = 1e-05
early_stopping_patience: int = 10
expansion_factor: float = 1
property feature_dim: int | None

Calculate feature dimension based on expansion factor

fixed_tokens: list
gradient_clip_norm: float = 1.0
infer_model_params_from_cache(cache_path: str, token_sampler_config=None)[source]

Infer model parameters from cached activation data.

Args:

cache_path: Path to cached activation data
token_sampler_config: TokenSamplerConfig object for token sampling.

Returns:

Self for method chaining

l1_penalty: float = 0.3
learning_rate: float = 0.0001
log_every: int = 5
lr_schedule: str = 'constant'
max_sampled_tokens: int = 200
num_epochs: int = 20
num_tokens: int | None = None
optimizer: str = 'adam'
sampling_strategy: str = 'block_average'
sampling_stride: int = 8
save_every: int = 1000
token_dim: int | None = None
use_token_sampling: bool = True
validate_every: int = 500
validation_split: float = 0.1
warmup_epochs: int = 2
weight_decay: float = 1e-05
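
A minimal configuration sketch. The feature_dim property is derived from the expansion factor; per the MultiModalSAE notes it is typically num_tokens * token_dim * expansion_factor, which is what the printed value below assumes:

from src.sae import SAETrainingConfig

config = SAETrainingConfig(
    num_tokens=200,
    token_dim=512,
    expansion_factor=2,
    l1_penalty=0.3,
    device="cuda",
)
print(config.feature_dim)  # derived from num_tokens, token_dim and expansion_factor

# Alternatively, infer num_tokens/token_dim from a cached activation directory
# (the cache path is a placeholder); the method returns self for chaining.
config = SAETrainingConfig(device="cuda").infer_model_params_from_cache("path/to/sae_activations")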

src.sae.sae module

class src.sae.sae.MultiModalSAE(num_tokens: int, token_dim: int = 512, feature_dim: int = 4096, use_bias: bool = True, activation_fn: str = 'leaky_relu', dropout_rate: float = 0.0, device: str = 'cuda', use_bfloat16: bool = False)[source]

Bases: Module

Sparse Autoencoder that processes all tokens from the ACT model simultaneously.

num_tokens (int): should match the number of tokens of your ACT model, or the number of sampled tokens when using sampling
token_dim (int): should match the dim_model hyperparam in your ACT model
feature_dim (int): the number of features the SAE should learn to represent. Usually this would be (num_tokens * token_dim * expansion_factor)

compute_loss(x: Tensor, l1_penalty: float = 0.0) Dict[str, Tensor][source]

Compute loss with optional regularization. Note: loss computation often benefits from float32 precision for stability.

decode(features: Tensor) Tensor[source]

Decode features back to token space

encode(x: Tensor) Tensor[source]

Encode a flattened token representation to feature space. Handles bfloat16 conversion automatically.

forward(x: Tensor) Tuple[Tensor, Tensor][source]

Full forward pass: encode then decode
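
A minimal sketch of the encode/decode round trip and loss computation, assuming the flattened input has shape (batch_size, num_tokens * token_dim); the ordering of forward's returned tuple is not documented here, so encode and decode are called directly:

import torch
from src.sae import MultiModalSAE

sae = MultiModalSAE(num_tokens=200, token_dim=512, feature_dim=4096, device="cpu")

x = torch.randn(8, 200 * 512)                 # assumed layout: flattened tokens
features = sae.encode(x)                      # -> (batch_size, feature_dim)
recon = sae.decode(features)                  # -> (batch_size, num_tokens * token_dim)
losses = sae.compute_loss(x, l1_penalty=0.3)  # dict of loss tensors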

src.sae.token_sampler module

class src.sae.token_sampler.TokenSampler(config: TokenSamplerConfig, total_tokens: int)[source]

Bases: object

Handles consistent token sampling strategies

get_sampling_info() Dict[str, Any][source]

Get information about the sampling configuration

sample_tokens(activations: Tensor) Tensor[source]

Sample tokens from activation tensor

Args:

activations: Tensor of shape (batch_size, num_tokens, token_dim)

Returns:

Sampled activations of shape (batch_size, num_sampled_tokens, token_dim)

class src.sae.token_sampler.TokenSamplerConfig(fixed_tokens: List[int] = None, sampling_strategy: str = 'block_average', sampling_stride: int = 8, max_sampled_tokens: int = 100, random_seed: int = 42, block_size: int = 8)[source]

Bases: object

block_size: int = 8
fixed_tokens: List[int] = None
max_sampled_tokens: int = 100
random_seed: int = 42
sampling_strategy: str = 'block_average'
sampling_stride: int = 8
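
A short sketch combining TokenSamplerConfig and TokenSampler (total_tokens and the tensor sizes are illustrative):

import torch
from src.sae import TokenSampler, TokenSamplerConfig

config = TokenSamplerConfig(sampling_strategy="block_average", block_size=8, max_sampled_tokens=100)
sampler = TokenSampler(config, total_tokens=478)

activations = torch.randn(16, 478, 512)       # (batch_size, num_tokens, token_dim)
sampled = sampler.sample_tokens(activations)  # (batch_size, num_sampled_tokens, token_dim)
print(sampler.get_sampling_info())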

src.sae.trainer module

class src.sae.trainer.SAETrainer(repo_id: str, policy_path: Path, batch_size: int = 16, num_workers: int = 4, output_directory: Path = 'output', resume_checkpoint: Path | None = None, activation_cache_path: str = '/home/runner/.cache/physical_ai_interpretability/sae_activations', force_cache_refresh: bool = False, use_wandb: bool = False, wandb_project_name: str = 'physical_ai_interpretability', sae_config: SAETrainingConfig | None = None, upload_to_hub: bool = False, hub_repo_id: str | None = None, hub_private: bool = True, hub_license: str = 'mit', hub_tags: list | None = None)[source]

Bases: object

collect_activations()[source]

Collect activations and return cached dataloader with resumption support

create_model() Module[source]

Create SAE model based on config

create_optimizer_and_scheduler(model: Module, train_loader: DataLoader)[source]

Create optimizer and learning rate scheduler

generate_model_card() str[source]

Generate a model card for the SAE model

load_checkpoint(model: Module, optimizer: Optimizer | None = None, scheduler=None, checkpoint_path: str | None = None, load_best: bool = False)[source]

Load model checkpoint from safetensors format

push_model_to_hub(complete_model_dir: Path)[source]

Push the complete model to Hugging Face Hub

save_checkpoint(model: Module, optimizer: Optimizer, scheduler, epoch: int, is_best: bool = False)[source]

Save model checkpoint using safetensors for model weights

save_complete_model(model: Module, epoch: int | None = None)[source]

Save the complete model in a ‘complete’ folder ready for Hugging Face upload. This includes model.safetensors, config.json, and training_state.pt

train()[source]

Main training method

Returns:

Trained SAE model

train_step(model: Module, batch: Tensor) Dict[str, float][source]

Single training step - returns both scalars and tensors
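
An end-to-end training sketch (the dataset repo_id and policy_path are placeholders):

from pathlib import Path
from src.sae import SAETrainer, SAETrainingConfig

trainer = SAETrainer(
    repo_id="username/my-dataset",           # placeholder dataset repository
    policy_path=Path("path/to/act_policy"),  # placeholder ACT policy checkpoint
    batch_size=16,
    output_directory=Path("output"),
    sae_config=SAETrainingConfig(device="cuda"),
)
sae = trainer.train()  # returns the trained SAE model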

src.sae.trainer.load_sae_from_hub(repo_id: str, filename: str = 'model.safetensors', config_filename: str = 'config.json', revision: str = 'main', cache_dir: str | None = None, force_download: bool = False, token: str | None = None, device: str = 'cuda')[source]

Load SAE model from Hugging Face Hub

Args:

repo_id: Repository ID on Hugging Face Hub
filename: Model filename to download
config_filename: Config filename to download
revision: Git revision (branch/tag/commit)
cache_dir: Local cache directory
force_download: Force re-download even if cached
token: Hugging Face token for private repos
device: Device to load model on

Returns:

Loaded SAE model
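
For example (the repository ID is a placeholder):

from src.sae.trainer import load_sae_from_hub

sae = load_sae_from_hub(repo_id="username/my-sae", device="cuda")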

src.sae.trainer.load_sae_model(model_path: str, config_path: str | None = None, device: str = 'cuda')[source]

Standalone function to load a trained SAE model from safetensors checkpoint

Args:

model_path: Path to the safetensors model file
config_path: Optional path to config.json file. If None, tries to find it automatically
device: Device to load the model on

Returns:

Loaded SAE model
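
A minimal sketch (the model path is illustrative; it follows the ‘complete’ folder layout described under save_complete_model):

from src.sae import load_sae_model

sae = load_sae_model(
    model_path="output/sae_drop_footbag_into_di_838a8c8b/complete/model.safetensors",
    device="cuda",
)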

Module contents

class src.sae.MultiModalSAE(num_tokens: int, token_dim: int = 512, feature_dim: int = 4096, use_bias: bool = True, activation_fn: str = 'leaky_relu', dropout_rate: float = 0.0, device: str = 'cuda', use_bfloat16: bool = False)[source]

Bases: Module

Sparse Autoencoder that processes all tokens from the ACT model simultaneously.

num_tokens (int): should match the number of tokens of your ACT model, or the number of sampled tokens when using sampling
token_dim (int): should match the dim_model hyperparam in your ACT model
feature_dim (int): the number of features the SAE should learn to represent. Usually this would be (num_tokens * token_dim * expansion_factor)

compute_loss(x: Tensor, l1_penalty: float = 0.0) Dict[str, Tensor][source]

Compute loss with optional regularization. Note: loss computation often benefits from float32 precision for stability.

decode(features: Tensor) Tensor[source]

Decode features back to token space

encode(x: Tensor) Tensor[source]

Encode a flattened token representation to feature space. Handles bfloat16 conversion automatically.

forward(x: Tensor) Tuple[Tensor, Tensor][source]

Full forward pass: encode then decode

class src.sae.SAEBuilder(device: str = 'cuda')[source]

Bases: object

Builder class for loading SAE models with their configurations. Provides convenient methods to load trained SAE models from standard paths.

classmethod from_default_path(experiment_name: str, base_output_dir: str = 'output', device: str = 'cuda') SAEBuilder[source]

Create SAEBuilder and load model from default output structure.

Args:

experiment_name: Name of experiment (e.g., “sae_drop_footbag_into_di_838a8c8b”)
base_output_dir: Base output directory
device: Device to load model on

Returns:

SAEBuilder instance with loaded model

load_from_experiment(experiment_path: str, checkpoint: str = 'latest', config_filename: str = 'config.json') MultiModalSAE[source]

Load SAE model from experiment directory.

Args:

experiment_path: Path to experiment directory (e.g., “output/sae_drop_footbag_into_di_838a8c8b”)
checkpoint: Which checkpoint to load - ‘best’, ‘latest’, or specific epoch number
config_filename: Name of config file

Returns:

Loaded SAE model

load_from_files(model_path: str, config_path: str) MultiModalSAE[source]

Load SAE model from specific model and config files.

Args:

model_path: Path to safetensors model file
config_path: Path to config.json file

Returns:

Loaded SAE model

load_from_hub(repo_id: str, filename: str = 'model.safetensors', config_filename: str = 'config.json', revision: str = 'main', cache_dir: str | None = None, force_download: bool = False, token: str | None = None) MultiModalSAE[source]

Load SAE model from Hugging Face Hub.

Args:

repo_id: Repository ID on Hugging Face Hub
filename: Model filename to download
config_filename: Config filename to download
revision: Git revision (branch/tag/commit)
cache_dir: Local cache directory
force_download: Force re-download even if cached
token: Hugging Face token for private repos

Returns:

Loaded SAE model

load_with_auto_config(model_path: str, config_path: str | None = None) MultiModalSAE[source]

Load SAE model with automatic config discovery.

Args:

model_path: Path to safetensors model file
config_path: Optional path to config.json. If None, searches automatically

Returns:

Loaded SAE model

class src.sae.SAETrainer(repo_id: str, policy_path: Path, batch_size: int = 16, num_workers: int = 4, output_directory: Path = 'output', resume_checkpoint: Path | None = None, activation_cache_path: str = '/home/runner/.cache/physical_ai_interpretability/sae_activations', force_cache_refresh: bool = False, use_wandb: bool = False, wandb_project_name: str = 'physical_ai_interpretability', sae_config: SAETrainingConfig | None = None, upload_to_hub: bool = False, hub_repo_id: str | None = None, hub_private: bool = True, hub_license: str = 'mit', hub_tags: list | None = None)[source]

Bases: object

collect_activations()[source]

Collect activations and return cached dataloader with resumption support

create_model() Module[source]

Create SAE model based on config

create_optimizer_and_scheduler(model: Module, train_loader: DataLoader)[source]

Create optimizer and learning rate scheduler

generate_model_card() str[source]

Generate a model card for the SAE model

load_checkpoint(model: Module, optimizer: Optimizer | None = None, scheduler=None, checkpoint_path: str | None = None, load_best: bool = False)[source]

Load model checkpoint from safetensors format

push_model_to_hub(complete_model_dir: Path)[source]

Push the complete model to Hugging Face Hub

save_checkpoint(model: Module, optimizer: Optimizer, scheduler, epoch: int, is_best: bool = False)[source]

Save model checkpoint using safetensors for model weights

save_complete_model(model: Module, epoch: int | None = None)[source]

Save the complete model in a ‘complete’ folder ready for Hugging Face upload. This includes model.safetensors, config.json, and training_state.pt

train()[source]

Main training method

Returns:

Trained SAE model

train_step(model: Module, batch: Tensor) Dict[str, float][source]

Single training step - returns both scalars and tensors

class src.sae.SAETrainingConfig(num_tokens: int | None = None, token_dim: int | None = None, expansion_factor: float = 1, activation_fn: str = 'relu', use_token_sampling: bool = True, fixed_tokens: list = <factory>, sampling_strategy: str = 'block_average', sampling_stride: int = 8, max_sampled_tokens: int = 200, block_size: int = 8, batch_size: int = 128, learning_rate: float = 0.0001, num_epochs: int = 20, validation_split: float = 0.1, l1_penalty: float = 0.3, optimizer: str = 'adam', weight_decay: float = 1e-05, lr_schedule: str = 'constant', warmup_epochs: int = 2, gradient_clip_norm: float = 1.0, early_stopping_patience: int = 10, early_stopping_min_delta: float = 1e-05, log_every: int = 5, save_every: int = 1000, validate_every: int = 500, device: str = 'cpu')[source]

Bases: object

Configuration for SAE training

activation_fn: str = 'relu'
batch_size: int = 128
block_size: int = 8
device: str = 'cpu'
early_stopping_min_delta: float = 1e-05
early_stopping_patience: int = 10
expansion_factor: float = 1
property feature_dim: int | None

Calculate feature dimension based on expansion factor

fixed_tokens: list
gradient_clip_norm: float = 1.0
infer_model_params_from_cache(cache_path: str, token_sampler_config=None)[source]

Infer model parameters from cached activation data.

Args:

cache_path: Path to cached activation data
token_sampler_config: TokenSamplerConfig object for token sampling.

Returns:

Self for method chaining

l1_penalty: float = 0.3
learning_rate: float = 0.0001
log_every: int = 5
lr_schedule: str = 'constant'
max_sampled_tokens: int = 200
num_epochs: int = 20
num_tokens: int | None = None
optimizer: str = 'adam'
sampling_strategy: str = 'block_average'
sampling_stride: int = 8
save_every: int = 1000
token_dim: int | None = None
use_token_sampling: bool = True
validate_every: int = 500
validation_split: float = 0.1
warmup_epochs: int = 2
weight_decay: float = 1e-05

class src.sae.TokenSampler(config: TokenSamplerConfig, total_tokens: int)[source]

Bases: object

Handles consistent token sampling strategies

get_sampling_info() Dict[str, Any][source]

Get information about the sampling configuration

sample_tokens(activations: Tensor) Tensor[source]

Sample tokens from activation tensor

Args:

activations: Tensor of shape (batch_size, num_tokens, token_dim)

Returns:

Sampled activations of shape (batch_size, num_sampled_tokens, token_dim)

class src.sae.TokenSamplerConfig(fixed_tokens: List[int] = None, sampling_strategy: str = 'block_average', sampling_stride: int = 8, max_sampled_tokens: int = 100, random_seed: int = 42, block_size: int = 8)[source]

Bases: object

block_size: int = 8
fixed_tokens: List[int] = None
max_sampled_tokens: int = 100
random_seed: int = 42
sampling_strategy: str = 'block_average'
sampling_stride: int = 8

src.sae.load_original_num_tokens_from_cache(cache_path: str) int | None[source]

Load the original number of tokens from cached activation metadata.

Args:

cache_path: Path to the activation cache directory

Returns:

Original number of tokens if found, None otherwise
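
For example (the cache path is a placeholder):

from src.sae import load_original_num_tokens_from_cache

num_tokens = load_original_num_tokens_from_cache("path/to/sae_activations")  # None if metadata is missing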

src.sae.load_sae_model(model_path: str, config_path: str | None = None, device: str = 'cuda')[source]

Standalone function to load a trained SAE model from safetensors checkpoint

Args:

model_path: Path to the safetensors model file
config_path: Optional path to config.json file. If None, tries to find it automatically
device: Device to load the model on

Returns:

Loaded SAE model

src.sae.load_sae_model_simple(experiment_path: str, checkpoint: str = 'best', device: str = 'cuda') MultiModalSAE[source]

Simple convenience function to load SAE model from experiment directory.

Args:

experiment_path: Path to experiment directory
checkpoint: Which checkpoint to load
device: Device to load on

Returns:

Loaded SAE model