Source code for src.sae.sae

#!/usr/bin/env python

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple


class MultiModalSAE(nn.Module):
    """
    Sparse Autoencoder that processes all tokens from the ACT model simultaneously.

    Args:
        num_tokens (int): Should match the number of tokens in your ACT model,
            or the number of sampled tokens when token sampling is used.
        token_dim (int): Should match the dim_model hyperparameter of your ACT model.
        feature_dim (int): Number of features the SAE should learn to represent.
            Usually this is (num_tokens * token_dim * expansion_factor).
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int = 512,
        feature_dim: int = 4096,
        use_bias: bool = True,
        activation_fn: str = 'leaky_relu',
        dropout_rate: float = 0.0,
        device: str = 'cuda',
        use_bfloat16: bool = False,
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.token_dim = token_dim
        self.feature_dim = feature_dim
        self.input_dim = num_tokens * token_dim
        self.device = device
        self.use_bfloat16 = use_bfloat16

        # Default dtype for the model's parameters
        self.model_dtype = torch.bfloat16 if use_bfloat16 else torch.float32

        # Encoder: compress all tokens into the feature space
        self.encoder = nn.Linear(self.input_dim, feature_dim, bias=use_bias, dtype=self.model_dtype)

        # Decoder: reconstruct all tokens from the features
        self.decoder = nn.Linear(feature_dim, self.input_dim, bias=use_bias, dtype=self.model_dtype)

        # Activation function - these work natively with bfloat16
        if activation_fn == 'tanh':
            self.activation = nn.Tanh()
        elif activation_fn == 'relu':
            self.activation = nn.ReLU()
        elif activation_fn == 'leaky_relu':
            self.activation = nn.LeakyReLU(0.1)
        else:
            raise ValueError(f"Unknown activation: {activation_fn}")

        # Optional dropout
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()

        self._init_weights()

        # Move the model to the target device and convert to bfloat16 if requested
        self.to(device)
        if use_bfloat16:
            self.to(dtype=torch.bfloat16)

    def _init_weights(self):
        """Initialize weights with Xavier uniform, computed in float32 for stability."""
        with torch.no_grad():
            for layer in (self.encoder, self.decoder):
                # Initialize in float32, then copy the result back into the
                # layer's own dtype (float32 or bfloat16).
                weight_fp32 = layer.weight.detach().to(torch.float32)
                nn.init.xavier_uniform_(weight_fp32)
                layer.weight.copy_(weight_fp32)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def _ensure_correct_dtype(self, x: torch.Tensor) -> torch.Tensor:
        """Ensure input tensors match the model dtype."""
        if self.use_bfloat16 and x.dtype != torch.bfloat16:
            return x.to(dtype=torch.bfloat16)
        elif not self.use_bfloat16 and x.dtype != torch.float32:
            return x.to(dtype=torch.float32)
        return x

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode the flattened token representation into the feature space.
        Handles the bfloat16 conversion automatically.
        """
        x = self._ensure_correct_dtype(x)
        features = self.encoder(x)
        features = self.activation(features)
        features = self.dropout(features)
        return features

    def decode(self, features: torch.Tensor) -> torch.Tensor:
        """
        Decode features back into token space.
        """
        features = self._ensure_correct_dtype(features)
        reconstruction = self.decoder(features)
        return reconstruction

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Full forward pass: encode, then decode.
        """
        # Handle both shaped and flattened input
        if len(x.shape) == 3:
            # (batch_size, num_tokens, token_dim)
            batch_size = x.shape[0]
            x_flat = x.view(batch_size, -1)
        else:
            # Already flattened
            batch_size = x.shape[0]
            x_flat = x

        # Ensure correct dtype
        x_flat = self._ensure_correct_dtype(x_flat)

        # Encode to features
        features = self.encode(x_flat)

        # Decode back to token space
        reconstruction_flat = self.decode(features)

        # Reshape the reconstruction to match the input
        if len(x.shape) == 3:
            reconstruction = reconstruction_flat.view(batch_size, self.num_tokens, self.token_dim)
        else:
            reconstruction = reconstruction_flat

        return reconstruction, features

    def compute_loss(
        self,
        x: torch.Tensor,
        l1_penalty: float = 0.0,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the loss with optional sparsity regularization.

        Note: the loss and metrics are computed in float32 for numerical stability.
        """
        reconstruction, features = self.forward(x)

        # Flatten the inputs for loss computation
        if len(x.shape) == 3:
            x_flat = x.view(x.shape[0], -1)
            reconstruction_flat = reconstruction.view(reconstruction.shape[0], -1)
        else:
            x_flat = x
            reconstruction_flat = reconstruction

        # Upcast to float32 so the loss and metrics are computed at full precision
        x_flat_loss = x_flat.float()
        reconstruction_flat_loss = reconstruction_flat.float()
        features_loss = features.float()

        # Reconstruction loss (MSE)
        mse_loss = F.mse_loss(reconstruction_flat_loss, x_flat_loss, reduction='mean')

        # L1 penalty on features (sparsity)
        l1_loss = torch.mean(torch.abs(features_loss)) if l1_penalty > 0 else torch.tensor(0.0, device=x.device)

        # Total loss
        total_loss = mse_loss + l1_penalty * l1_loss

        # Metrics for monitoring (no gradients needed)
        with torch.no_grad():
            feature_mean = features_loss.mean()
            feature_std = features_loss.std()
            feature_sparsity = (torch.abs(features_loss) < 0.1).float().mean()

            # Reconstruction quality (R^2)
            ss_res = torch.sum((x_flat_loss - reconstruction_flat_loss) ** 2)
            ss_tot = torch.sum((x_flat_loss - x_flat_loss.mean()) ** 2)
            r_squared = 1 - (ss_res / (ss_tot + 1e-8))

        return {
            'total_loss': total_loss,
            'mse_loss': mse_loss,
            'l1_loss': l1_loss * l1_penalty,
            'feature_mean': feature_mean,
            'feature_std': feature_std,
            'feature_sparsity': feature_sparsity,
            'r_squared': r_squared,
        }
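
# Usage sketch (illustrative only; the shapes and hyperparameters below are
# assumptions, not values taken from this module):
#
#   sae = MultiModalSAE(num_tokens=100, token_dim=512, feature_dim=4096, device='cpu')
#   acts = torch.randn(8, 100, 512)                    # (batch, num_tokens, token_dim)
#   recon, feats = sae(acts)                           # recon: (8, 100, 512), feats: (8, 4096)
#   losses = sae.compute_loss(acts, l1_penalty=1e-3)
#   losses['total_loss'].backward()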

def create_multimodal_sae(
    num_tokens: Optional[int] = None,
    token_dim: int = 512,
    feature_dim: int = 1024,
    device: str = 'cuda',
    use_bfloat16: bool = False,
) -> nn.Module:
    """
    Factory function for creating SAE models with bfloat16 support.

    Args:
        num_tokens: Number of tokens in the model. If None, an error is raised.
        token_dim: Dimension of each token.
        feature_dim: Dimension of the feature space.
        device: Device to place the model on.
        use_bfloat16: Whether to use bfloat16 precision.

    :meta private:
    """
    if num_tokens is None:
        raise ValueError(
            "num_tokens must be specified. Use SAETrainingConfig.infer_model_params() "
            "to automatically infer this value from your ACT policy model."
        )

    return MultiModalSAE(
        num_tokens=num_tokens,
        token_dim=token_dim,
        feature_dim=feature_dim,
        device=device,
        use_bfloat16=use_bfloat16,
    )


# Utility function for converting input data to bfloat16
def prepare_batch_for_bfloat16(batch: torch.Tensor, device: str = 'cuda') -> torch.Tensor:
    """
    Convert a batch of activations to bfloat16 and move it to the target device.

    Args:
        batch: Input tensor (usually float32).
        device: Target device.

    Returns:
        Tensor converted to bfloat16 on the target device.

    :meta private:
    """
    return batch.to(device=device, dtype=torch.bfloat16)
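

# Minimal smoke test for the factory and the bfloat16 helper. This block is an
# illustrative sketch (the shapes and hyperparameters are assumptions) and only
# runs when the module is executed directly.
if __name__ == '__main__':
    # Build a small SAE on CPU in float32 so the sketch runs anywhere.
    sae = create_multimodal_sae(num_tokens=10, token_dim=512, feature_dim=1024, device='cpu')
    batch = torch.randn(4, 10, 512)

    # When training in bfloat16 on GPU, inputs can be converted up front:
    #   batch = prepare_batch_for_bfloat16(batch, device='cuda')

    losses = sae.compute_loss(batch, l1_penalty=1e-3)
    print({name: float(value) for name, value in losses.items()})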