Source code for src.sae.config
from dataclasses import dataclass, field
from typing import Optional
import torch
import logging
@dataclass
class SAETrainingConfig:
"""Configuration for SAE training"""
# Model config - these will be auto-inferred
num_tokens: Optional[int] = None # Auto-inferred from token sampling config
token_dim: Optional[int] = None # Auto-inferred from ACT model
    expansion_factor: float = 1.0  # Feature expansion factor (feature_dim = num_tokens * token_dim * expansion_factor)
activation_fn: str = 'relu' # 'tanh', 'relu', 'leaky_relu'
# Token sampling config - affects num_tokens calculation
use_token_sampling: bool = True
fixed_tokens: list = field(default_factory=lambda: [0, 1]) # VAE latent + proprioception tokens
sampling_strategy: str = "block_average" # "uniform", "stride", "random_fixed", "block_average"
sampling_stride: int = 8
max_sampled_tokens: int = 200
block_size: int = 8
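
    # Worked example (illustrative; assumes TokenSampler keeps the fixed
    # tokens and samples from the remaining image tokens): with 602 total
    # tokens, "stride" sampling at sampling_stride=8 keeps the 2 fixed
    # tokens plus 600 / 8 = 75 image tokens, so num_tokens = 77, well under
    # max_sampled_tokens=200. The exact semantics live in TokenSampler.
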
# Training config
batch_size: int = 128
learning_rate: float = 1e-4
num_epochs: int = 20
validation_split: float = 0.1
# Loss config
l1_penalty: float = 0.3
# Optimization config
optimizer: str = 'adam' # 'adam', 'adamw', 'sgd'
weight_decay: float = 1e-5
lr_schedule: str = 'constant' # 'cosine', 'linear', 'constant'
warmup_epochs: int = 2
gradient_clip_norm: float = 1.0
# Early stopping
early_stopping_patience: int = 10
early_stopping_min_delta: float = 1e-5
# Logging and saving
log_every: int = 5
save_every: int = 1000
validate_every: int = 500
# Device
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
@property
def feature_dim(self) -> Optional[int]:
"""Calculate feature dimension based on expansion factor"""
if self.num_tokens is not None and self.token_dim is not None:
return int(self.num_tokens * self.token_dim * self.expansion_factor)
return None
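
    # Worked example (illustrative; 602 tokens and token_dim=512 are
    # assumptions for the arithmetic, not guarantees about any checkpoint):
    #
    #   cfg = SAETrainingConfig(num_tokens=602, token_dim=512, expansion_factor=2.0)
    #   cfg.feature_dim  # int(602 * 512 * 2.0) = 616448
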
def _infer_original_num_tokens(self, policy) -> int:
"""
Infer the original number of tokens from the ACT policy model.
The total number of tokens in ACT models is calculated as:
- 2 fixed tokens (VAE latent + proprioception)
        - Plus width/32 × height/32 image tokens from each camera
Args:
policy: The ACT policy model
Returns:
Original number of tokens in the model
"""
original_num_tokens = None
# Method 1: Try to infer from model configuration
if hasattr(policy, 'config'):
config = policy.config
# Check if config has image features information
            if hasattr(config, 'image_features') and config.image_features:
                # ACT models typically prepend 2 fixed tokens (VAE latent + proprioception)
                fixed_tokens = 2
                num_cameras = len(config.image_features)
                # Assumes 480x640 inputs: (480/32) * (640/32) = 15 * 20 = 300 tokens per camera
                original_num_tokens = fixed_tokens + num_cameras * 300
        # Method 2: Fall back to an established default for ACT models if nothing else worked
if original_num_tokens is None:
# For common ACT setups with 2 cameras at 480x640 resolution
original_num_tokens = 602 # 2 + 2*((480/32) * (640/32)) = 2 + 2*(15*20) = 602
logging.warning(f"Using default ACT token count: {original_num_tokens} (configure your model for automatic detection)")
return original_num_tokens
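
    # Sketch of what Method 1 probes (illustrative; the attribute values are
    # assumptions, and sae_cfg stands for an SAETrainingConfig instance):
    #
    #   policy.config.image_features  # e.g. {"cam_high": ..., "cam_wrist": ...}
    #   sae_cfg._infer_original_num_tokens(policy)  # 2 + 2 * 300 = 602
    #
    # Without such a config the method warns and returns the 602-token default.
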
def infer_model_params_from_cache(self, cache_path: str, token_sampler_config=None):
"""
Infer model parameters from cached activation data.
Args:
cache_path: Path to cached activation data
token_sampler_config: TokenSamplerConfig object for token sampling.
Returns:
Self for method chaining
"""
# Load original_num_tokens from cache metadata
from .activation_collector import load_original_num_tokens_from_cache
original_num_tokens = load_original_num_tokens_from_cache(cache_path)
if original_num_tokens is None:
raise ValueError(f"Could not load original_num_tokens from cache at {cache_path}")
# Infer token_dim from activation shape in cache metadata
import json
from pathlib import Path
metadata_file = Path(cache_path) / "metadata.json"
if metadata_file.exists():
with open(metadata_file, 'r') as f:
metadata = json.load(f)
activation_shape = metadata.get('activation_shape')
if activation_shape and len(activation_shape) == 2:
# activation_shape is [num_tokens, token_dim] after excluding batch dimension
self.token_dim = activation_shape[1]
logging.info(f"Inferred token_dim from cache metadata: {self.token_dim}")
if self.token_dim is None:
raise ValueError("Could not infer token_dim from cache metadata")
# Calculate num_tokens based on token sampling
if token_sampler_config is not None:
from src.sae import TokenSampler
sampler = TokenSampler(token_sampler_config, total_tokens=original_num_tokens)
sampling_info = sampler.get_sampling_info()
if sampling_info['use_token_sampling']:
self.num_tokens = sampling_info['num_sampled_tokens']
else:
self.num_tokens = original_num_tokens
else:
# No token sampling, use the fixed tokens only
self.num_tokens = len(self.fixed_tokens)
logging.info(f"Inferred model parameters from cache:")
logging.info(f" original_num_tokens: {original_num_tokens}")
logging.info(f" token_dim: {self.token_dim}")
logging.info(f" num_tokens (after sampling): {self.num_tokens}")
logging.info(f" feature_dim: {self.feature_dim}")
logging.info(f" expansion_factor: {self.expansion_factor}")
return self
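
    # Typical usage (illustrative; assumes a cache directory whose
    # metadata.json records 'activation_shape' and the original token count):
    #
    #   cfg = SAETrainingConfig(expansion_factor=2.0)
    #   cfg.infer_model_params_from_cache("act_cache/")  # path is hypothetical
    #   cfg.num_tokens, cfg.token_dim, cfg.feature_dim   # now populated
    #
    # Note: without a token_sampler_config, num_tokens falls back to
    # len(self.fixed_tokens), i.e. 2 by default.
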
def infer_model_params(self, policy, token_sampler_config=None):
"""
Infer num_tokens and token_dim from the policy model and token sampling configuration.
Args:
policy: The ACT policy model
token_sampler_config: TokenSamplerConfig object for token sampling.
:meta private:
"""
# Infer token_dim from the model
# Typically this is the hidden dimension of the encoder
if hasattr(policy, 'model') and hasattr(policy.model, 'encoder'):
if hasattr(policy.model.encoder, 'layers') and len(policy.model.encoder.layers) > 0:
# Get from the first layer's dimension
first_layer = policy.model.encoder.layers[0]
if hasattr(first_layer, 'norm1') and hasattr(first_layer.norm1, 'normalized_shape'):
self.token_dim = first_layer.norm1.normalized_shape[0]
elif hasattr(first_layer, 'self_attn') and hasattr(first_layer.self_attn, 'embed_dim'):
self.token_dim = first_layer.self_attn.embed_dim
elif hasattr(first_layer, 'norm2') and hasattr(first_layer.norm2, 'normalized_shape'):
self.token_dim = first_layer.norm2.normalized_shape[0]
# If we couldn't infer from the model structure, try a different approach
if self.token_dim is None and hasattr(policy, 'config'):
if hasattr(policy.config, 'dim_model'):
self.token_dim = policy.config.dim_model
elif hasattr(policy.config, 'hidden_size'):
self.token_dim = policy.config.hidden_size
        # Infer original_num_tokens from the model itself;
        # _infer_original_num_tokens always returns a value, falling back to
        # the 602-token default (with a warning) when detection fails
        original_num_tokens = self._infer_original_num_tokens(policy)
# Infer num_tokens based on token sampling
if token_sampler_config is not None:
from src.sae import TokenSampler
sampler = TokenSampler(token_sampler_config, total_tokens=original_num_tokens)
sampling_info = sampler.get_sampling_info()
if sampling_info['use_token_sampling']:
self.num_tokens = sampling_info['num_sampled_tokens']
else:
self.num_tokens = original_num_tokens
        else:
            # No sampler configured: fall back to the fixed tokens only
            self.num_tokens = len(self.fixed_tokens)
# Validate that we have the required values
if self.token_dim is None:
raise ValueError("Could not infer token_dim from the model. Please set it manually.")
if self.num_tokens is None:
raise ValueError("Could not infer num_tokens. Please check token sampling configuration.")
logging.info(f"Auto-inferred model parameters:")
logging.info(f" original_num_tokens: {original_num_tokens}")
logging.info(f" token_dim: {self.token_dim}")
logging.info(f" num_tokens (after sampling): {self.num_tokens}")
logging.info(f" feature_dim: {self.feature_dim}")
logging.info(f" expansion_factor: {self.expansion_factor}")
return self
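
# Usage sketch (illustrative; how the policy is loaded is out of scope here,
# and the encoder/config attributes probed above are assumptions about
# ACT-style models, not a guaranteed API):
#
#   policy = ...  # a loaded ACT policy
#   cfg = SAETrainingConfig(expansion_factor=2.0).infer_model_params(policy)
#   print(cfg.num_tokens, cfg.token_dim, cfg.feature_dim)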