Source code for src.sae.builder

#!/usr/bin/env python

import json
import logging
from pathlib import Path
from typing import Optional, Dict, Any

import torch
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

from .sae import MultiModalSAE, create_multimodal_sae


class SAEBuilder:
    """
    Builder class for loading SAE models with their configurations.

    Provides convenient methods to load trained SAE models from an experiment
    directory, from an explicit model/config file pair, or from the
    Hugging Face Hub. Every loader ends up in :meth:`load_from_files`.
    """

    def __init__(self, device: str = 'cuda'):
        # Device string handed to create_multimodal_sae on every load.
        self.device = device

    @staticmethod
    def _epoch_of(checkpoint_file: Path) -> Optional[int]:
        """Return the integer after the last ``_`` in the file stem, or None if it is not an integer."""
        try:
            return int(checkpoint_file.stem.split('_')[-1])
        except ValueError:
            return None

    def load_from_experiment(
        self,
        experiment_path: str,
        checkpoint: str = 'latest',
        config_filename: str = 'config.json'
    ) -> "MultiModalSAE":
        """
        Load SAE model from experiment directory.

        Args:
            experiment_path: Path to experiment directory
                (e.g., "output/sae_drop_footbag_into_di_838a8c8b")
            checkpoint: Which checkpoint to load:
                - 'best': ``best_model.safetensors``
                - 'latest': the ``model_epoch_<N>.safetensors`` with the largest N
                - anything that parses as an integer: ``model_epoch_<N>.safetensors``
                - anything else: treated as a file name inside the directory
            config_filename: Name of config file

        Returns:
            Loaded SAE model

        Raises:
            FileNotFoundError: If the config file, the selected checkpoint, or
                (for 'latest') any numbered checkpoint cannot be found.
        """
        experiment_path = Path(experiment_path)
        config_path = experiment_path / config_filename

        if not config_path.exists():
            raise FileNotFoundError(f"Config file not found: {config_path}")

        # Determine model path based on checkpoint specification
        if checkpoint == 'best':
            model_path = experiment_path / "best_model.safetensors"
        elif checkpoint == 'latest':
            # Keep only files whose suffix really is an epoch number; a stray
            # match such as "model_epoch_final.safetensors" must not crash
            # the selection with a ValueError.
            numbered = []
            for candidate in experiment_path.glob("model_epoch_*.safetensors"):
                epoch = self._epoch_of(candidate)
                if epoch is not None:
                    numbered.append((epoch, candidate))
            if not numbered:
                raise FileNotFoundError(f"No model checkpoints found in {experiment_path}")
            model_path = max(numbered, key=lambda item: item[0])[1]
        else:
            # Specific epoch or filename
            try:
                epoch_num = int(checkpoint)
            except ValueError:
                # Not a number: treat as filename
                model_path = experiment_path / checkpoint
            else:
                model_path = experiment_path / f"model_epoch_{epoch_num}.safetensors"

        if not model_path.exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        return self.load_from_files(str(model_path), str(config_path))

    def load_from_files(
        self,
        model_path: str,
        config_path: str
    ) -> "MultiModalSAE":
        """
        Load SAE model from specific model and config files.

        The config must provide ``num_tokens`` and ``token_dim``. When
        ``feature_dim`` is absent it is derived as
        ``int(num_tokens * token_dim * expansion_factor)`` with
        ``expansion_factor`` defaulting to 1.25. ``use_bfloat16`` defaults
        to False.

        Args:
            model_path: Path to safetensors model file
            config_path: Path to config.json file

        Returns:
            Loaded SAE model, switched to eval mode

        Raises:
            FileNotFoundError: If either file does not exist.
            ValueError: If the config lacks ``num_tokens`` or ``token_dim``.
        """
        model_path = Path(model_path)
        config_path = Path(config_path)

        if not model_path.exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")
        if not config_path.exists():
            raise FileNotFoundError(f"Config file not found: {config_path}")

        # Load config
        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)

        # Extract model parameters
        num_tokens = config_dict.get('num_tokens')
        token_dim = config_dict.get('token_dim')
        feature_dim = config_dict.get('feature_dim')
        use_bfloat16 = config_dict.get('use_bfloat16', False)
        # NOTE(review): the config may also carry 'activation_fn', but it is
        # not forwarded to create_multimodal_sae here - confirm whether the
        # factory should receive it.

        if num_tokens is None or token_dim is None:
            raise ValueError("Config file missing required parameters: num_tokens, token_dim")

        # Calculate feature_dim if not present
        if feature_dim is None:
            expansion_factor = config_dict.get('expansion_factor', 1.25)
            feature_dim = int(num_tokens * token_dim * expansion_factor)

        # Create model
        model = create_multimodal_sae(
            num_tokens=num_tokens,
            token_dim=token_dim,
            feature_dim=feature_dim,
            device=self.device,
            use_bfloat16=use_bfloat16
        )

        # Load weights from safetensors
        model_state = load_file(model_path)
        model.load_state_dict(model_state)
        model.eval()

        logging.info("Loaded SAE model from %s", model_path)
        logging.info(
            "Model config: %s tokens, %s dim, %s features",
            num_tokens, token_dim, feature_dim
        )

        return model

    def load_with_auto_config(
        self,
        model_path: str,
        config_path: Optional[str] = None
    ) -> "MultiModalSAE":
        """
        Load SAE model with automatic config discovery.

        When ``config_path`` is None, ``config.json`` is looked up first next
        to the model file and then one directory above it.

        Args:
            model_path: Path to safetensors model file
            config_path: Optional path to config.json. If None, searches automatically

        Returns:
            Loaded SAE model

        Raises:
            FileNotFoundError: If no config.json is found in the searched
                locations, or if the model/config file does not exist.
        """
        model_path = Path(model_path)

        # Auto-discover config if not provided
        if config_path is None:
            potential_configs = [
                model_path.parent / "config.json",
                model_path.parent.parent / "config.json",
            ]
            for potential_config in potential_configs:
                if potential_config.exists():
                    config_path = potential_config
                    break

            if config_path is None:
                raise FileNotFoundError("Could not find config.json file. Please provide config_path explicitly.")

        return self.load_from_files(str(model_path), str(config_path))

    def load_from_hub(
        self,
        repo_id: str,
        filename: str = "model.safetensors",
        config_filename: str = "config.json",
        revision: str = "main",
        cache_dir: Optional[str] = None,
        force_download: bool = False,
        token: Optional[str] = None
    ) -> "MultiModalSAE":
        """
        Load SAE model from Hugging Face Hub.

        Downloads the model file and the config file with
        ``hf_hub_download`` (same repo, revision and cache settings for
        both) and loads them through :meth:`load_from_files`.

        Args:
            repo_id: Repository ID on Hugging Face Hub
            filename: Model filename to download
            config_filename: Config filename to download
            revision: Git revision (branch/tag/commit)
            cache_dir: Local cache directory
            force_download: Force re-download even if cached
            token: Hugging Face token for private repos

        Returns:
            Loaded SAE model
        """
        # Options shared by both downloads
        download_kwargs = dict(
            repo_id=repo_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            token=token,
        )

        # Download model file
        model_file = hf_hub_download(filename=filename, **download_kwargs)

        # Download config file
        config_file = hf_hub_download(filename=config_filename, **download_kwargs)

        return self.load_from_files(model_file, config_file)

    @classmethod
    def from_default_path(
        cls,
        experiment_name: str,
        base_output_dir: str = "output",
        device: str = 'cuda'
    ) -> "MultiModalSAE":
        """
        Create a builder and load a model from the default output structure.

        Note that this returns the loaded model, not the builder: the builder
        is created only to call :meth:`load_from_experiment` on
        ``<base_output_dir>/<experiment_name>`` with its default arguments
        (the 'latest' checkpoint).

        Args:
            experiment_name: Name of experiment (e.g., "sae_drop_footbag_into_di_838a8c8b")
            base_output_dir: Base output directory
            device: Device to load model on

        Returns:
            Loaded SAE model
        """
        builder = cls(device=device)
        experiment_path = Path(base_output_dir) / experiment_name
        return builder.load_from_experiment(str(experiment_path))
def load_sae_model_simple(
    experiment_path: str,
    checkpoint: str = 'best',
    device: str = 'cuda'
) -> MultiModalSAE:
    """
    One-call shortcut for loading an SAE model from an experiment directory.

    Builds a throwaway ``SAEBuilder`` for ``device`` and delegates to its
    ``load_from_experiment`` method.

    Args:
        experiment_path: Path to experiment directory
        checkpoint: Checkpoint selector passed to ``load_from_experiment``
            (defaults to 'best' here)
        device: Device to load on

    Returns:
        Loaded SAE model
    """
    return SAEBuilder(device=device).load_from_experiment(
        experiment_path, checkpoint
    )