import json
import logging
import shutil
from pathlib import Path
from typing import Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
from safetensors.torch import save_file, load_file
from huggingface_hub import HfApi, hf_hub_download
from lerobot.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from src.utils import make_dataset_without_config, get_repo_hash
from src.utils.naming import get_experiment_name, get_cache_name
from .config import SAETrainingConfig
from .token_sampler import TokenSamplerConfig
from .activation_collector import (
collect_and_cache_activations,
create_cached_dataloader,
is_cache_valid,
cleanup_invalid_cache,
get_cache_status,
)
from .sae import create_multimodal_sae
class SAETrainer:
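    """Train a sparse autoencoder (SAE) on cached activations from a LeRobot policy layer.

    Example (a minimal sketch; repo IDs and paths are placeholders):

        trainer = SAETrainer(
            repo_id="user/my_robot_dataset",
            policy_path=Path("outputs/act_policy"),
            batch_size=16,
        )
        sae_model = trainer.train()
    """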
def __init__(
self,
repo_id: str,
policy_path: Path,
batch_size: int = 16,
num_workers: int = 4,
        output_directory: Path = Path("output"),
resume_checkpoint: Optional[Path] = None,
activation_cache_path: str = str(Path.home() / ".cache" / "physical_ai_interpretability" / "sae_activations"),
force_cache_refresh: bool = False,
use_wandb: bool = False,
wandb_project_name: str = "physical_ai_interpretability",
sae_config: Optional[SAETrainingConfig] = None,
# Hugging Face integration parameters
upload_to_hub: bool = False,
hub_repo_id: Optional[str] = None,
hub_private: bool = True,
hub_license: str = "mit",
hub_tags: Optional[list] = None,
):
# Load policy
policy_cfg = PreTrainedConfig.from_pretrained(policy_path)
policy_cfg.pretrained_path = policy_path
        # Load dataset with the delta indices the policy expects:
        # ACT policies use only the current observation, while actions span the full chunk
dataset = make_dataset_without_config(
repo_id=repo_id,
action_delta_indices=list(range(policy_cfg.chunk_size)),
)
# Create dataloader
self.dataloader = DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
pin_memory=True,
drop_last=False,
)
ds_meta = dataset.meta if isinstance(dataset, LeRobotDataset) else dataset._datasets[0].meta
self.policy = make_policy(policy_cfg, ds_meta=ds_meta)
# Use provided config or create default
self.config = sae_config if sae_config is not None else SAETrainingConfig()
# Create token sampler config from SAE config
token_sampler_config = TokenSamplerConfig(
fixed_tokens=self.config.fixed_tokens,
sampling_strategy=self.config.sampling_strategy,
sampling_stride=self.config.sampling_stride,
max_sampled_tokens=self.config.max_sampled_tokens,
block_size=self.config.block_size
) if self.config.use_token_sampling else None
# Auto-infer model parameters - will be updated later if using cached activations
self.config.infer_model_params(self.policy, token_sampler_config)
self._token_sampler_config = token_sampler_config # Store for potential cache-based inference
# Determine layer name from policy - pick last layer in encoder
self.layer_name = "model.encoder.layers.3.norm2" # Default layer
if hasattr(self.policy, 'model') and hasattr(self.policy.model, 'encoder'):
if hasattr(self.policy.model.encoder, 'layers') and len(self.policy.model.encoder.layers) > 0:
# Use the last layer's norm2 by default
layer_idx = len(self.policy.model.encoder.layers) - 1
self.layer_name = f"model.encoder.layers.{layer_idx}.norm2"
# Store initialization parameters
self.repo_id = repo_id
self.output_directory = Path(output_directory)
self.resume_checkpoint = resume_checkpoint
self.force_cache_refresh = force_cache_refresh
self.activation_cache_path = activation_cache_path
self.use_wandb = use_wandb
self.wandb_project_name = wandb_project_name
# Store Hugging Face parameters
self.upload_to_hub = upload_to_hub
self.hub_repo_id = hub_repo_id
self.hub_private = hub_private
self.hub_license = hub_license
self.hub_tags = hub_tags or ["sae", "sparse-autoencoder", "robotics", "out-of-distribution"]
# Store token sampler config for activation collection
self.token_sampler_config = token_sampler_config
# Initialize wandb
if use_wandb:
experiment_name = get_experiment_name(repo_id, prefix="sae")
self.wandb = wandb.init(
project=wandb_project_name,
name=experiment_name,
config={
'repo_id': repo_id,
'repo_hash': get_repo_hash(repo_id),
'layer_name': self.layer_name,
'num_tokens': self.config.num_tokens,
'token_dim': self.config.token_dim,
'feature_dim': self.config.feature_dim,
'expansion_factor': self.config.expansion_factor,
**self.config.__dict__
}
)
else:
self.wandb = None
def collect_activations(self):
"""Collect activations and return cached dataloader with resumption support"""
cache_path = Path(self.activation_cache_path) / get_cache_name(self.repo_id)
# Get detailed cache status
cache_status = get_cache_status(str(cache_path))
logging.info(f"Cache status: {cache_status['status']}")
if cache_status['exists']:
if cache_status['status'] == 'completed':
logging.info(f"Found completed cache with {cache_status['total_samples']} samples")
elif cache_status['status'] == 'in_progress' and cache_status['can_resume']:
logging.info(f"Found resumable cache with {cache_status['total_samples']} samples")
logging.info(f"Last batch: {cache_status['last_batch_idx']}, Interruptions: {cache_status['interruption_count']}")
elif cache_status['status'] == 'in_progress':
logging.warning("Found incomplete cache but it's not resumable")
# Check if we should use existing cache
use_existing_cache = (
cache_status['status'] == 'completed' and not self.force_cache_refresh
)
if use_existing_cache:
logging.info(f"Using existing valid activation cache at {cache_path}")
try:
# Update config with parameters from cache metadata
logging.info("Updating model parameters from cached activation data...")
self.config.infer_model_params_from_cache(str(cache_path), self._token_sampler_config)
return create_cached_dataloader(
cache_dir=str(cache_path),
batch_size=self.config.batch_size,
shuffle=True,
preload_buffers=2
)
except Exception as e:
logging.warning(f"Failed to load existing cache: {e}")
logging.info("Will recreate cache from scratch")
# Fall through to cache recreation
# Handle invalid or missing cache
if cache_path.exists():
if not is_cache_valid(str(cache_path)):
logging.info(f"Cache at {cache_path} is invalid or incomplete, cleaning up")
cleanup_invalid_cache(str(cache_path))
elif self.force_cache_refresh:
logging.info(f"Force refresh requested, cleaning existing cache at {cache_path}")
cleanup_invalid_cache(str(cache_path))
logging.info(f"Collecting new activations to {cache_path}")
# Collect activations with token sampling
try:
cache_dir_str = collect_and_cache_activations(
act_model=self.policy,
dataloader=self.dataloader,
layer_name=self.layer_name,
cache_dir=str(cache_path.parent),
experiment_name=cache_path.name,
device=self.config.device,
cleanup_on_start=self.force_cache_refresh, # Force clean start if requested
use_token_sampling=self.config.use_token_sampling,
fixed_tokens=self.config.fixed_tokens if self.config.use_token_sampling else None,
sampling_strategy=self.config.sampling_strategy if self.config.use_token_sampling else "uniform",
sampling_stride=self.config.sampling_stride,
max_sampled_tokens=self.config.max_sampled_tokens,
block_size=self.config.block_size
)
except Exception as e:
# Clean up any partial cache that might have been created
if cache_path.exists():
logging.warning(f"Activation collection failed, cleaning up partial cache: {e}")
cleanup_invalid_cache(str(cache_path))
raise
# Verify the cache was created successfully
if not is_cache_valid(cache_dir_str):
logging.error(f"Cache creation completed but validation failed for {cache_dir_str}")
cleanup_invalid_cache(cache_dir_str)
raise RuntimeError("Failed to create valid activation cache")
# Update config with actual parameters from the newly created cache
logging.info("Updating model parameters from newly created cache data...")
self.config.infer_model_params_from_cache(cache_dir_str, self._token_sampler_config)
# Create dataloader from cached activations
return create_cached_dataloader(
cache_dir=cache_dir_str,
batch_size=self.config.batch_size,
shuffle=True,
preload_buffers=2
)
def create_model(self) -> nn.Module:
"""Create SAE model based on config"""
model = create_multimodal_sae(
num_tokens=self.config.num_tokens,
token_dim=self.config.token_dim,
feature_dim=self.config.feature_dim,
device=self.config.device
)
return model.to(self.config.device)
def create_optimizer_and_scheduler(self, model: nn.Module, train_loader: DataLoader):
"""Create optimizer and learning rate scheduler"""
# Optimizer
if self.config.optimizer == 'adam':
optimizer = optim.Adam(
model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
elif self.config.optimizer == 'adamw':
optimizer = optim.AdamW(
model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
elif self.config.optimizer == 'sgd':
optimizer = optim.SGD(
model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay,
momentum=0.9
)
else:
raise ValueError(f"Unknown optimizer: {self.config.optimizer}")
# Scheduler
total_steps = len(train_loader) * self.config.num_epochs
warmup_steps = len(train_loader) * self.config.warmup_epochs
if self.config.lr_schedule == 'cosine':
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
elif self.config.lr_schedule == 'linear':
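            # Note: this acts as a warmup, scaling the LR from 0.1x to 1x over warmup_steps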
scheduler = optim.lr_scheduler.LinearLR(
optimizer,
start_factor=0.1,
total_iters=warmup_steps
)
else:
scheduler = optim.lr_scheduler.ConstantLR(optimizer)
return optimizer, scheduler
    def train_step(self, model: nn.Module, batch: torch.Tensor) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]:
        """Single training step - returns scalar losses (for logging) and tensor losses (for backprop)"""
batch = batch.to(self.config.device)
# Forward pass - keep the tensor version for backprop
loss_dict_tensors = model.compute_loss(
batch,
l1_penalty=self.config.l1_penalty,
)
# Convert to scalars for logging, but keep tensor versions
loss_dict_scalars = {k: v.item() if torch.is_tensor(v) else v
for k, v in loss_dict_tensors.items()}
return loss_dict_scalars, loss_dict_tensors
def save_checkpoint(self, model: nn.Module, optimizer: optim.Optimizer,
scheduler, epoch: int, is_best: bool = False):
"""Save model checkpoint using safetensors for model weights"""
# Save model weights with safetensors (secure and efficient)
model_path = self.final_output_directory / f"model_epoch_{epoch}.safetensors"
save_file(model.state_dict(), model_path)
# Save training state with torch.save (optimizer/scheduler states need pickle)
training_state = {
'epoch': epoch,
'global_step': getattr(self, 'global_step', 0),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_val_loss': getattr(self, 'best_val_loss', float('inf')),
'config': self.config.__dict__,
'model_path': str(model_path) # Reference to the safetensors model file
}
state_path = self.final_output_directory / f"training_state_epoch_{epoch}.pt"
torch.save(training_state, state_path)
logging.info(f"Saved checkpoint - Model: {model_path.name}, State: {state_path.name}")
# Save best checkpoint
if is_best:
best_model_path = self.final_output_directory / "best_model.safetensors"
best_state_path = self.final_output_directory / "best_training_state.pt"
# Copy current best to dedicated best files
save_file(model.state_dict(), best_model_path)
training_state['model_path'] = str(best_model_path)
torch.save(training_state, best_state_path)
logging.info(f"Saved best checkpoint at epoch {epoch} - Model: {best_model_path.name}")
    def save_complete_model(self, model: nn.Module, epoch: Optional[int] = None):
"""
Save the complete model in a 'complete' folder ready for Hugging Face upload.
        This includes model.safetensors, config.json, and training_state.pt.
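
        Example (a minimal sketch; the folder path depends on the trainer's output directory):

            complete_dir = trainer.save_complete_model(model, epoch=9)
            # complete_dir now holds model.safetensors, config.json, and
            # training_state.pt, ready for push_model_to_hub()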
"""
# Create complete folder
complete_dir = self.final_output_directory / "complete"
complete_dir.mkdir(exist_ok=True)
# Save model weights as model.safetensors (standard HF naming)
model_path = complete_dir / "model.safetensors"
save_file(model.state_dict(), model_path)
# Copy or create config.json
source_config = self.final_output_directory / "config.json"
dest_config = complete_dir / "config.json"
if source_config.exists():
# Copy existing config
shutil.copy2(source_config, dest_config)
else:
# Create minimal config
config_dict = self.config.__dict__.copy()
config_dict.update({
'repo_id': self.repo_id,
'repo_hash': get_repo_hash(self.repo_id),
'layer_name': self.layer_name,
'experiment_name': getattr(self, 'experiment_name', 'unknown')
})
with open(dest_config, 'w') as f:
json.dump(config_dict, f, indent=2)
# Save training state (if available)
training_state_path = complete_dir / "training_state.pt"
if epoch is not None:
# Copy the training state from specific epoch
source_state = self.final_output_directory / f"training_state_epoch_{epoch}.pt"
else:
# Use best training state if available
source_state = self.final_output_directory / "best_training_state.pt"
if not source_state.exists():
# Find latest training state
state_files = list(self.final_output_directory.glob("training_state_epoch_*.pt"))
if state_files:
source_state = max(state_files, key=lambda x: int(x.stem.split('_')[-1]))
if source_state.exists():
shutil.copy2(source_state, training_state_path)
logging.info(f"Saved complete model to: {complete_dir}")
return complete_dir
def push_model_to_hub(self, complete_model_dir: Path):
"""
        Push the complete model folder to the Hugging Face Hub.
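
        Example (a minimal sketch; assumes the trainer was constructed with
        upload_to_hub=True and hub_repo_id="user/my-sae", a placeholder repo ID):

            complete_dir = trainer.save_complete_model(model)
            commit_info = trainer.push_model_to_hub(complete_dir)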
"""
if not self.upload_to_hub:
logging.info("Hub upload disabled, skipping...")
return None
if not self.hub_repo_id:
raise ValueError("hub_repo_id must be specified to upload to Hub")
api = HfApi()
# Create repo
repo_info = api.create_repo(
repo_id=self.hub_repo_id,
private=self.hub_private,
exist_ok=True
)
logging.info(f"Created/accessed Hub repo: {repo_info.repo_id}")
# Generate model card
readme_content = self.generate_model_card()
readme_path = complete_model_dir / "README.md"
with open(readme_path, 'w') as f:
f.write(readme_content)
# Upload folder
commit_info = api.upload_folder(
repo_id=repo_info.repo_id,
repo_type="model",
folder_path=complete_model_dir,
commit_message="Upload SAE model weights, config, and training state",
allow_patterns=["*.safetensors", "*.json", "*.pt", "*.md"],
ignore_patterns=["*.tmp", "*.log", "__pycache__/*"],
)
logging.info(f"Model pushed to Hub: {commit_info.repo_url.url}")
return commit_info
def generate_model_card(self) -> str:
"""Generate a model card for the SAE model"""
# Generate YAML frontmatter
yaml_tags = '\n'.join([f'- {tag}' for tag in self.hub_tags])
card_content = f"""---
license: {self.hub_license}
tags:
{yaml_tags}
datasets:
- {self.repo_id}
library_name: physical-ai-interpretability
---
# Sparse Autoencoder (SAE) Model
This model is a Sparse Autoencoder trained for interpretability analysis of robotics policies using the LeRobot framework.
## Model Details
- **Architecture**: Multi-modal Sparse Autoencoder
- **Training Dataset**: `{self.repo_id}`
- **Base Policy**: LeRobot ACT policy
- **Layer Target**: `{self.layer_name}`
- **Tokens**: {self.config.num_tokens}
- **Token Dimension**: {self.config.token_dim}
- **Feature Dimension**: {self.config.feature_dim}
- **Expansion Factor**: {self.config.expansion_factor}
## Training Configuration
- **Learning Rate**: {self.config.learning_rate}
- **Batch Size**: {self.config.batch_size}
- **L1 Penalty**: {self.config.l1_penalty}
- **Epochs**: {self.config.num_epochs}
- **Optimizer**: {self.config.optimizer}
## Usage
```python
from src.sae.trainer import load_sae_from_hub
# Load model from Hub
model = load_sae_from_hub("{self.hub_repo_id}")
# Or load using builder
from src.sae.builder import SAEBuilder
builder = SAEBuilder(device='cuda')
model = builder.load_from_hub("{self.hub_repo_id}")
```
## Out-of-Distribution Detection
This SAE model can be used for OOD detection with LeRobot policies:
```python
from src.ood import OODDetector
# Create OOD detector with Hub-loaded SAE
ood_detector = OODDetector(
policy=your_policy,
sae_hub_repo_id="{self.hub_repo_id}"
)
# Fit threshold and use for detection
ood_detector.fit_ood_threshold_to_validation_dataset(validation_dataset)
is_ood, error = ood_detector.is_out_of_distribution(observation)
```
## Files
- `model.safetensors`: The trained SAE model weights
- `config.json`: Training and model configuration
- `training_state.pt`: Complete training state (optimizer, scheduler, metrics)
- `ood_params.json`: OOD detection parameters (if fitted)
## Citation
If you use this model in your research, please cite:
```bibtex
@misc{{sae_model,
title={{Sparse Autoencoder for {self.repo_id.split('/')[-1].replace('_', ' ').title()}}},
author={{Your Name}},
year={{2024}},
url={{https://huggingface.co/{self.hub_repo_id}}}
}}
```
## Framework
This model was trained using the [physical-ai-interpretability](https://github.com/your-repo/physical-ai-interpretability) framework with [LeRobot](https://github.com/huggingface/lerobot).
"""
return card_content
    def load_checkpoint(self, model: nn.Module, optimizer: Optional[optim.Optimizer] = None,
                        scheduler=None, checkpoint_path: Optional[str] = None, load_best: bool = False):
"""Load model checkpoint from safetensors format"""
if checkpoint_path is None:
if load_best:
model_path = self.final_output_directory / "best_model.safetensors"
state_path = self.final_output_directory / "best_training_state.pt"
else:
# Find the latest checkpoint
model_files = list(self.final_output_directory.glob("model_epoch_*.safetensors"))
if not model_files:
raise FileNotFoundError("No checkpoint files found")
latest_model = max(model_files, key=lambda x: int(x.stem.split('_')[-1]))
epoch_num = latest_model.stem.split('_')[-1]
model_path = latest_model
state_path = self.final_output_directory / f"training_state_epoch_{epoch_num}.pt"
else:
# Custom checkpoint path provided
checkpoint_path = Path(checkpoint_path)
if checkpoint_path.suffix == '.safetensors':
model_path = checkpoint_path
# Try to find corresponding training state
base_name = checkpoint_path.stem
state_path = checkpoint_path.parent / f"training_state_{base_name.replace('model_', '')}.pt"
else:
# Legacy .pt format - load differently
return self._load_legacy_checkpoint(model, optimizer, scheduler, checkpoint_path)
# Load model weights from safetensors
if model_path.exists():
model_state = load_file(model_path)
model.load_state_dict(model_state)
logging.info(f"Loaded model weights from {model_path}")
else:
raise FileNotFoundError(f"Model file not found: {model_path}")
# Load training state if available
if state_path.exists() and (optimizer is not None or scheduler is not None):
training_state = torch.load(state_path, map_location='cpu')
if optimizer is not None and 'optimizer_state_dict' in training_state:
optimizer.load_state_dict(training_state['optimizer_state_dict'])
if scheduler is not None and 'scheduler_state_dict' in training_state:
scheduler.load_state_dict(training_state['scheduler_state_dict'])
# Restore training state
if hasattr(self, 'global_step'):
self.global_step = training_state.get('global_step', 0)
if hasattr(self, 'best_val_loss'):
self.best_val_loss = training_state.get('best_val_loss', float('inf'))
logging.info(f"Loaded training state from {state_path}")
return training_state
return None
    def _load_legacy_checkpoint(self, model: nn.Module, optimizer: Optional[optim.Optimizer] = None,
                                scheduler=None, checkpoint_path: Optional[str] = None):
"""Load legacy .pt checkpoint format"""
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
if optimizer is not None and 'optimizer_state_dict' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if scheduler is not None and 'scheduler_state_dict' in checkpoint:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# Restore training state
if hasattr(self, 'global_step'):
self.global_step = checkpoint.get('global_step', 0)
if hasattr(self, 'best_val_loss'):
self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
logging.info(f"Loaded legacy checkpoint from {checkpoint_path}")
return checkpoint
def train(self):
"""
Main training method
Returns:
Trained SAE model
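
        Example (a minimal sketch; paths and repo IDs are placeholders):

            trainer = SAETrainer(
                repo_id="user/my_robot_dataset",
                policy_path=Path("outputs/act_policy"),
                resume_checkpoint=Path("output/sae_run/model_epoch_4.safetensors"),
            )
            model = trainer.train()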
"""
# Step 1: Collect activations
cached_dataloader = self.collect_activations()
if cached_dataloader is None:
raise ValueError("No activations collected. Please check the cache path and try again.")
logging.info(f"Activations collected and cached")
# Step 2: Set up training
logging.info("Starting SAE training")
# Create output directory using trainer parameters
experiment_name = get_experiment_name(self.repo_id, prefix="sae")
output_dir = self.output_directory / experiment_name
output_dir.mkdir(parents=True, exist_ok=True)
self.final_output_directory = output_dir
# Save config
config_dict = self.config.__dict__.copy()
# Add trainer-specific info
config_dict.update({
'repo_id': self.repo_id,
'repo_hash': get_repo_hash(self.repo_id),
'layer_name': self.layer_name,
'activation_cache_path': self.activation_cache_path,
'experiment_name': experiment_name
})
with open(output_dir / "config.json", "w") as f:
json.dump(config_dict, f, indent=2)
# Create model
model = self.create_model()
logging.info(f"Created model with {sum(p.numel() for p in model.parameters())} parameters")
# Create optimizer and scheduler
optimizer, scheduler = self.create_optimizer_and_scheduler(model, cached_dataloader)
# Handle checkpoint resuming
start_epoch = 0
if self.resume_checkpoint is not None:
logging.info(f"Resuming training from checkpoint: {self.resume_checkpoint}")
training_state = self.load_checkpoint(model, optimizer, scheduler, str(self.resume_checkpoint))
if training_state:
start_epoch = training_state.get('epoch', 0) + 1
logging.info(f"Resuming from epoch {start_epoch}")
# Training state
best_val_loss = float('inf')
patience_counter = 0
global_step = 0
self.global_step = global_step
self.best_val_loss = best_val_loss
# Training loop
for epoch in range(start_epoch, self.config.num_epochs):
model.train()
epoch_metrics = {}
# Training
progress_bar = tqdm(cached_dataloader, desc=f"Epoch {epoch+1}/{self.config.num_epochs}")
for batch in progress_bar:
# Single forward pass that returns both scalars and tensors
loss_dict_scalars, loss_dict_tensors = self.train_step(model, batch)
# Backward pass using the tensor version
optimizer.zero_grad()
loss_dict_tensors['total_loss'].backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(
model.parameters(),
self.config.gradient_clip_norm
)
optimizer.step()
scheduler.step()
# Update metrics
for key, value in loss_dict_scalars.items():
if key not in epoch_metrics:
epoch_metrics[key] = []
epoch_metrics[key].append(value)
# Use scalar version for logging/progress bar updates
progress_bar.set_postfix({
'loss': f"{loss_dict_scalars['total_loss']:.4f}",
'lr': f"{scheduler.get_last_lr()[0]:.2e}"
})
global_step += 1
self.global_step = global_step
# Logging
if global_step % self.config.log_every == 0:
avg_metrics = {k: np.mean(v[-100:]) for k, v in epoch_metrics.items()}
logging.info(
f"Step {global_step}, Epoch {epoch+1}, "
f"Loss: {avg_metrics['total_loss']:.4f}, "
f"MSE: {avg_metrics['mse_loss']:.4f}, "
f"R²: {avg_metrics.get('r_squared', 0):.4f}"
)
if self.use_wandb and self.wandb is not None:
self.wandb.log({
f"train/{k}": v for k, v in avg_metrics.items()
}, step=global_step)
# Periodic saving
if global_step % self.config.save_every == 0:
self.save_checkpoint(model, optimizer, scheduler, epoch)
# End of epoch
avg_epoch_metrics = {k: np.mean(v) for k, v in epoch_metrics.items()}
logging.info(f"End of epoch {epoch+1}: {avg_epoch_metrics}")
# Final save
final_epoch = self.config.num_epochs - 1
self.save_checkpoint(model, optimizer, scheduler, final_epoch)
# Save complete model for potential Hub upload
complete_dir = self.save_complete_model(model, final_epoch)
# Upload to Hub if requested
if self.upload_to_hub:
try:
commit_info = self.push_model_to_hub(complete_dir)
if commit_info:
logging.info(f"Successfully uploaded model to Hub: {commit_info.repo_url.url}")
except Exception as e:
logging.error(f"Failed to upload model to Hub: {e}")
logging.info("Model training completed successfully, but Hub upload failed")
logging.info("Training completed!")
return model
def load_sae_from_hub(
repo_id: str,
filename: str = "model.safetensors",
config_filename: str = "config.json",
revision: str = "main",
cache_dir: Optional[str] = None,
force_download: bool = False,
token: Optional[str] = None,
device: str = 'cuda'
):
"""
Load SAE model from Hugging Face Hub
Args:
repo_id: Repository ID on Hugging Face Hub
filename: Model filename to download
config_filename: Config filename to download
revision: Git revision (branch/tag/commit)
cache_dir: Local cache directory
force_download: Force re-download even if cached
token: Hugging Face token for private repos
device: Device to load model on
Returns:
Loaded SAE model
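
    Example (a minimal sketch; "user/my-sae" is a placeholder repo ID):

        model = load_sae_from_hub("user/my-sae", device="cpu")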
"""
# Download model file
model_file = hf_hub_download(
repo_id=repo_id,
filename=filename,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
token=token,
)
# Download config file
config_file = hf_hub_download(
repo_id=repo_id,
filename=config_filename,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
token=token,
)
# Load config
with open(config_file, 'r') as f:
config_dict = json.load(f)
# Extract model parameters
num_tokens = config_dict.get('num_tokens')
token_dim = config_dict.get('token_dim')
feature_dim = config_dict.get('feature_dim')
if any(param is None for param in [num_tokens, token_dim, feature_dim]):
raise ValueError("Config file missing required parameters: num_tokens, token_dim, feature_dim")
# Create model
model = create_multimodal_sae(
num_tokens=num_tokens,
token_dim=token_dim,
feature_dim=feature_dim,
device=device
)
# Load weights from safetensors
model_state = load_file(model_file)
model.load_state_dict(model_state)
model.eval()
logging.info(f"Loaded SAE model from Hub: {repo_id}")
logging.info(f"Model config: {num_tokens} tokens, {token_dim} dim, {feature_dim} features")
return model
def load_sae_model(model_path: str, config_path: str = None, device: str = 'cuda'):
"""
Standalone function to load a trained SAE model from safetensors checkpoint
Args:
model_path: Path to the safetensors model file
config_path: Optional path to config.json file. If None, tries to find it automatically
device: Device to load the model on
Returns:
Loaded SAE model
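
    Example (a minimal sketch; the path is a placeholder):

        model = load_sae_model("output/sae_run/complete/model.safetensors")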
"""
model_path = Path(model_path)
# Try to find config automatically if not provided
if config_path is None:
# Look for config.json in the same directory or parent directory
potential_configs = [
model_path.parent / "config.json",
model_path.parent.parent / "config.json"
]
config_path = None
for potential_config in potential_configs:
if potential_config.exists():
config_path = potential_config
break
if config_path is None:
raise FileNotFoundError("Could not find config.json file. Please provide config_path explicitly.")
# Load config
with open(config_path, 'r') as f:
config_dict = json.load(f)
# Extract model parameters
num_tokens = config_dict.get('num_tokens')
token_dim = config_dict.get('token_dim')
feature_dim = config_dict.get('feature_dim')
if any(param is None for param in [num_tokens, token_dim, feature_dim]):
raise ValueError("Config file missing required parameters: num_tokens, token_dim, feature_dim")
# Create model
model = create_multimodal_sae(
num_tokens=num_tokens,
token_dim=token_dim,
feature_dim=feature_dim,
device=device
)
# Load weights from safetensors
model_state = load_file(model_path)
model.load_state_dict(model_state)
model.eval()
logging.info(f"Loaded SAE model from {model_path}")
logging.info(f"Model config: {num_tokens} tokens, {token_dim} dim, {feature_dim} features")
return model