import json
import logging
from typing import List, Optional, Dict, Tuple
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from huggingface_hub import hf_hub_download, HfApi
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from src.sae.sae import MultiModalSAE
from src.sae.builder import SAEBuilder
from src.sae.token_sampler import TokenSampler, TokenSamplerConfig
class OODDetector:
"""
    Out-of-distribution (OOD) detector that identifies unusual attention patterns.
    Builds on an SAE trained to extract features from a given ACT policy,
    using the SAE's reconstruction error as a proxy for scenarios that are out of distribution.
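
    Example (illustrative sketch; the repo id, dataset, and observation are placeholders):
        >>> detector = OODDetector(policy, sae_hub_repo_id="user/sae-for-act")
        >>> if detector.needs_ood_fitting():
        ...     detector.fit_ood_threshold_to_validation_dataset(val_dataset)
        >>> is_ood, error = detector.is_out_of_distribution(observation)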
"""
def __init__(
self,
policy: ACTPolicy,
sae_experiment_path: Optional[str] = None,
sae_hub_repo_id: Optional[str] = None,
ood_params_path: Optional[Path] = None,
force_ood_refresh: bool = False,
device: str = 'cuda',
):
self.policy = policy
self.device = device
# Validate input - need either experiment path or hub repo_id
if not sae_experiment_path and not sae_hub_repo_id:
raise ValueError("Must provide either sae_experiment_path or sae_hub_repo_id")
if sae_experiment_path and sae_hub_repo_id:
raise ValueError("Cannot provide both sae_experiment_path and sae_hub_repo_id")
# Load SAE model and config
self.sae_config = None
self.sae_source = None # 'local' or 'hub'
self.sae_hub_repo_id = sae_hub_repo_id
builder = SAEBuilder(device=device)
if sae_hub_repo_id:
# Load from Hugging Face Hub
logging.info(f"Loading SAE model from Hub: {sae_hub_repo_id}")
self.sae_model = builder.load_from_hub(
repo_id=sae_hub_repo_id,
)
self.sae_source = 'hub'
# Try to download config from Hub
try:
config_file = hf_hub_download(
repo_id=sae_hub_repo_id,
filename="config.json",
)
with open(config_file, 'r') as f:
self.sae_config = json.load(f)
except Exception as e:
logging.warning(f"Could not load config from Hub: {e}")
else:
# Load from local experiment path
logging.info(f"Loading SAE model from experiment: {sae_experiment_path}")
self.sae_model = builder.load_from_experiment(sae_experiment_path)
self.sae_source = 'local'
# Load config if exists
config_path = Path(sae_experiment_path) / "config.json"
if config_path.exists():
with open(config_path, 'r') as f:
self.sae_config = json.load(f)
self.layer_name = self._infer_layer_name_from_policy()
# Set SAE to eval mode
self.sae_model.eval()
# Initialize token sampler from SAE config
self.token_sampler = None
if self.sae_config is not None:
# Create token sampler config from SAE config
if self.sae_config.get('use_token_sampling', False):
sampler_config = TokenSamplerConfig(
fixed_tokens=self.sae_config.get('fixed_tokens', [0, 1]),
sampling_strategy=self.sae_config.get('sampling_strategy', 'block_average'),
sampling_stride=self.sae_config.get('sampling_stride', 8),
max_sampled_tokens=self.sae_config.get('max_sampled_tokens', 100),
block_size=self.sae_config.get('block_size', 8)
)
# Infer total_tokens from the policy model (same as SAE training)
from src.sae.config import SAETrainingConfig
temp_config = SAETrainingConfig()
total_tokens = temp_config._infer_original_num_tokens(self.policy)
if total_tokens is None:
total_tokens = 602 # Fallback default
logging.warning("Could not infer token count from model, using default 602")
else:
logging.info(f"Inferred {total_tokens} tokens from policy model for OOD detection")
self.token_sampler = TokenSampler(sampler_config, total_tokens=total_tokens)
logging.info(f"Initialized token sampler for OOD detection: {sampler_config.sampling_strategy}")
else:
logging.info("Token sampling disabled for OOD detection")
else:
logging.warning("No SAE config available, token sampling disabled")
# OOD distribution parameters
self.ood_params = None
self.ood_params_path = ood_params_path
self.force_ood_refresh = force_ood_refresh
# Handle existing OOD parameters based on force_ood_refresh flag
if force_ood_refresh:
logging.info("Force refresh requested - will not load existing OOD params")
else:
# Try to load OOD parameters from various sources
loaded = False
# 1. Try local path if provided
if ood_params_path is not None and Path(ood_params_path).exists():
logging.info(f"Loading existing OOD parameters from {ood_params_path}")
self._load_ood_params()
loaded = True
# 2. Try downloading from Hub if SAE is from Hub and no local params found
elif self.sae_source == 'hub' and not loaded:
try:
ood_params_file = hf_hub_download(
repo_id=sae_hub_repo_id,
filename="ood_params.json",
)
with open(ood_params_file, 'r') as f:
self.ood_params = json.load(f)
logging.info(f"Loaded OOD parameters from Hub: {sae_hub_repo_id}")
loaded = True
except Exception as e:
logging.info(f"Could not load OOD params from Hub: {e}")
if not loaded:
if ood_params_path is not None:
logging.info(f"OOD params path specified but file doesn't exist: {ood_params_path}")
else:
logging.info("No OOD params found - will need to fit threshold")
# Hook for activation extraction
self._hook = None
self._register_activation_hook()
def _register_activation_hook(self):
"""Register forward hook to capture activations from the specified layer"""
        def hook_fn(module, input, output):
            # Store the activations for later use; clone so the captured tensor
            # is not mutated if the layer's output buffer is reused
            if isinstance(output, tuple):
                self._captured_activations = output[0].clone()
            else:
                self._captured_activations = output.clone()
logging.debug(f"Captured activations shape: {self._captured_activations.shape}")
# Get layer by name
layer = self.policy
for attr in self.layer_name.split('.'):
layer = getattr(layer, attr)
# Verify the layer exists and is the right type
logging.info(f"Target layer: {layer} (type: {type(layer)})")
self._hook = layer.register_forward_hook(hook_fn)
logging.info(f"Registered activation hook on layer: {self.layer_name}")
# Verify hook was registered
if hasattr(layer, '_forward_hooks') and len(layer._forward_hooks) > 0:
logging.info(f"Hook successfully registered. Layer has {len(layer._forward_hooks)} forward hooks.")
else:
logging.warning("Hook registration may have failed - no forward hooks found on layer")
def _remove_hook(self):
"""Remove the forward hook"""
if self._hook is not None:
self._hook.remove()
self._hook = None
def __del__(self):
"""Cleanup hook when object is destroyed"""
self._remove_hook()
def _infer_layer_name_from_policy(self) -> str:
"""
Infer the layer name from the policy structure, using the same logic as SAE trainer.
Returns the last encoder layer's norm2 by default.
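        For example, a policy whose encoder has 4 layers yields "model.encoder.layers.3.norm2".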
"""
# Default layer name
default_layer = "model.encoder.layers.3.norm2"
if hasattr(self.policy, 'model') and hasattr(self.policy.model, 'encoder'):
if hasattr(self.policy.model.encoder, 'layers') and len(self.policy.model.encoder.layers) > 0:
# Use the last layer's norm2 by default
layer_idx = len(self.policy.model.encoder.layers) - 1
inferred_layer = f"model.encoder.layers.{layer_idx}.norm2"
logging.info(f"Inferred layer name from policy structure: {inferred_layer}")
return inferred_layer
logging.info(f"Could not infer layer from policy structure, using default: {default_layer}")
return default_layer
def fit_ood_threshold_to_validation_dataset(
self,
dataset: LeRobotDataset,
std_threshold: float = 2.5,
batch_size: int = 16,
max_samples: Optional[int] = None,
save_path: Optional[str] = None,
) -> Dict[str, float]:
"""
        Calibrate the out-of-distribution detector on an unseen validation dataset.
        This method computes the SAE reconstruction error for each frame in the dataset and fits a Gaussian distribution to the results.
        Anything above the specified standard deviation threshold (default σ=2.5) is deemed out of distribution.
        While the default works for many datasets, we recommend tuning it to the value that works best for your own data.
Args:
dataset: Validation dataset to fit on
std_threshold: Number of standard deviations from mean to use as threshold
batch_size: Batch size for processing
max_samples: Maximum number of samples to process (None for all)
save_path: Path to save OOD parameters (if None, uses self.ood_params_path)
Returns:
Dictionary with fitted parameters
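
        Example (illustrative; ``val_dataset`` is a placeholder validation split):
            >>> params = detector.fit_ood_threshold_to_validation_dataset(
            ...     val_dataset, std_threshold=2.5, max_samples=1000)
            >>> print(f"threshold: {params['threshold']:.6f}")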
"""
logging.info("Fitting OOD threshold using validation dataset...")
# Collect reconstruction errors from validation dataset
reconstruction_errors = []
# Create dataloader
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
)
num_processed = 0
        max_batches = ((max_samples + batch_size - 1) // batch_size) if max_samples else None  # ceil so the last partial batch is processed
with torch.no_grad():
for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing validation data")):
if max_batches and batch_idx >= max_batches:
break
# Prepare batch for policy - convert to observation format
batch_observations = self._prepare_batch_for_policy(batch)
for obs in batch_observations:
# Get reconstruction error for this observation
recon_error = self.get_reconstruction_error(obs)
reconstruction_errors.append(recon_error)
num_processed += 1
if max_samples and num_processed >= max_samples:
break
if max_samples and num_processed >= max_samples:
break
if len(reconstruction_errors) == 0:
raise RuntimeError("No reconstruction errors were collected. Check dataset and model compatibility.")
# Convert to numpy array for analysis
reconstruction_errors = np.array(reconstruction_errors)
# Fit Gaussian distribution
mean = float(np.mean(reconstruction_errors))
std = float(np.std(reconstruction_errors))
# Calculate threshold
threshold = mean + std_threshold * std
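        # Under the Gaussian fit, roughly 0.6% of in-distribution samples exceed
        # mean + 2.5 * std, so the default threshold flags about 1 in 160 frames.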
# Additional statistics
percentiles = np.percentile(reconstruction_errors, [50, 90, 95, 99, 99.5, 99.9])
# Store parameters
self.ood_params = {
'mean': mean,
'std': std,
'threshold': threshold,
'std_threshold': std_threshold,
'num_samples': len(reconstruction_errors),
'percentiles': {
'50': float(percentiles[0]),
'90': float(percentiles[1]),
'95': float(percentiles[2]),
'99': float(percentiles[3]),
'99.5': float(percentiles[4]),
'99.9': float(percentiles[5]),
},
'min': float(np.min(reconstruction_errors)),
'max': float(np.max(reconstruction_errors)),
}
# Save parameters locally
save_path = save_path or self.ood_params_path
if save_path:
self._save_ood_params(save_path)
logging.info(f"OOD parameters {'refreshed and ' if self.force_ood_refresh else ''}saved to {save_path}")
# Upload to Hub if the SAE model came from Hub
if self.sae_source == 'hub':
try:
self._upload_ood_params_to_hub()
logging.info(f"OOD parameters uploaded to Hub: {self.sae_hub_repo_id}")
except Exception as e:
logging.warning(f"Failed to upload OOD parameters to Hub: {e}")
logging.info(f"OOD threshold fitted:")
logging.info(f" Mean: {mean:.6f}")
logging.info(f" Std: {std:.6f}")
logging.info(f" Threshold ({std_threshold}σ): {threshold:.6f}")
logging.info(f" Samples processed: {len(reconstruction_errors)}")
return self.ood_params
def is_out_of_distribution(self, observation: dict) -> Tuple[bool, float]:
"""
Detect if the observation is OOD using the SAE model.
Args:
observation: Input observation dictionary
Returns:
Tuple of (is_ood, reconstruction_error)
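
        Example (illustrative; ``obs`` is a policy-format observation dict):
            >>> is_ood, err = detector.is_out_of_distribution(obs)
            >>> if is_ood:
            ...     print(f"OOD detected (reconstruction error {err:.4f})")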
"""
if self.ood_params is None:
raise RuntimeError("OOD parameters not fitted. Call fit_ood_threshold_to_validation_dataset first.")
reconstruction_error = self.get_reconstruction_error(observation)
threshold = self.ood_params['threshold']
is_ood = reconstruction_error > threshold
return is_ood, reconstruction_error
def get_reconstruction_error(self, observation: dict) -> float:
"""
Get the reconstruction error of the observation using the SAE model.
Args:
observation: Input observation dictionary (same format as policy input)
Returns:
Reconstruction error (MSE loss)
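
        The error is computed by running the policy forward, capturing the hooked
        layer's activations, optionally token-sampling them, reconstructing them
        with the SAE, and taking the MSE between input and reconstruction.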
"""
with torch.inference_mode():
# Reset captured activations
self._captured_activations = None
# Run policy forward pass to capture activations
# We don't need the actual output, just the activations
_ = self.policy.select_action(observation)
self.policy.reset()
            # Get activations and prepare for SAE
            if self._captured_activations is None:
                raise RuntimeError("No activations captured - verify the hook is attached to the correct layer")
            self._captured_activations = self._captured_activations.detach()
            # Encoder output is (num_tokens, batch, dim); flip to (batch, num_tokens, dim)
            activations = self._captured_activations.permute(1, 0, 2).contiguous()
# Apply token sampling if configured (same as SAE training)
if self.token_sampler is not None:
original_shape = activations.shape
activations = self.token_sampler.sample_tokens(activations)
logging.debug(f"Token sampling: {original_shape} -> {activations.shape}")
else:
logging.debug(f"No token sampling applied, activations shape: {activations.shape}")
# Handle batch dimension - we expect single sample
if activations.dim() == 3 and activations.shape[0] == 1:
activations = activations.squeeze(0) # Remove batch dim
elif activations.dim() == 3:
# Multiple samples in batch - take first one
activations = activations[0]
# Ensure activations are in the right format for SAE
# SAE expects (num_tokens, token_dim)
if activations.dim() != 2:
raise RuntimeError(f"Expected 2D activations, got shape {activations.shape}")
# Add batch dimension for SAE
activations_batch = activations.unsqueeze(0).to(self.device)
# Get reconstruction from SAE
            reconstruction, _ = self.sae_model(activations_batch)
# Calculate reconstruction error (MSE)
mse_loss = torch.nn.functional.mse_loss(
reconstruction.squeeze(0),
activations_batch.squeeze(0),
reduction='mean'
)
return float(mse_loss.item())
def _prepare_batch_for_policy(self, batch: dict) -> List[dict]:
"""
Convert dataset batch to list of policy observation dictionaries.
Args:
batch: Batch from dataset
Returns:
List of observation dictionaries
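
        Each tensor is sliced as ``value[i:i+1]`` so every observation keeps a
        leading batch dimension of 1, as the policy expects.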
"""
batch_size = None
observations = []
# Determine batch size from first tensor
for key, value in batch.items():
if torch.is_tensor(value):
batch_size = value.shape[0]
break
if batch_size is None:
raise ValueError("Could not determine batch size from batch data")
# Convert batch to list of individual observations
for i in range(batch_size):
obs = {}
for key, value in batch.items():
if torch.is_tensor(value):
# Extract single sample and add batch dimension
obs[key] = value[i:i+1].to(self.device)
else:
# Handle non-tensor data
if hasattr(value, '__getitem__'):
obs[key] = value[i]
else:
obs[key] = value
observations.append(obs)
return observations
def _load_ood_params(self):
"""Load OOD parameters from file"""
if self.ood_params_path and Path(self.ood_params_path).exists():
with open(self.ood_params_path, 'r') as f:
self.ood_params = json.load(f)
logging.info(f"Loaded OOD parameters from {self.ood_params_path}")
logging.info(f" Threshold: {self.ood_params.get('threshold', 'N/A')}")
def _save_ood_params(self, save_path: str):
"""Save OOD parameters to file"""
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
with open(save_path, 'w') as f:
json.dump(self.ood_params, f, indent=2)
logging.info(f"Saved OOD parameters to {save_path}")
def _upload_ood_params_to_hub(self):
"""Upload OOD parameters to Hugging Face Hub"""
if not self.sae_hub_repo_id:
raise ValueError("No Hub repo ID available for upload")
if not self.ood_params:
raise ValueError("No OOD parameters to upload")
# Create temporary file with OOD parameters
from tempfile import NamedTemporaryFile
import os
with NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(self.ood_params, f, indent=2)
temp_path = f.name
try:
# Upload to Hub
api = HfApi()
api.upload_file(
path_or_fileobj=temp_path,
path_in_repo="ood_params.json",
repo_id=self.sae_hub_repo_id,
commit_message="Update OOD parameters"
)
finally:
# Clean up temp file
os.unlink(temp_path)
def get_ood_stats(self) -> Optional[Dict[str, float]]:
"""Get current OOD parameters and statistics"""
return self.ood_params.copy() if self.ood_params else None
def needs_ood_fitting(self) -> bool:
"""Check if OOD threshold needs to be fitted"""
return self.ood_params is None or self.force_ood_refresh
def create_default_ood_params_path(experiment_name: str, base_dir: str = "output") -> str:
"""
Create standard path for OOD parameters based on experiment name.
Args:
experiment_name: SAE experiment name
base_dir: Base output directory
Returns:
Path string for OOD parameters file
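
    Example (separator shown for POSIX systems):
        >>> create_default_ood_params_path("sae_act_v1")
        'output/sae_act_v1/ood_params.json'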
"""
return str(Path(base_dir) / experiment_name / "ood_params.json")