src.ood.ood_detector module
- class src.ood.ood_detector.OODDetector(policy: ACTPolicy, sae_experiment_path: str | None = None, sae_hub_repo_id: str | None = None, ood_params_path: Path | None = None, force_ood_refresh: bool = False, device: str = 'cuda')
Bases: object
Out-of-distribution (OOD) detector based on detecting unusual attention patterns.
It builds on a sparse autoencoder (SAE) trained to extract features from a given ACT policy, and uses the SAE's reconstruction error as a proxy for out-of-distribution scenarios: a high reconstruction error is taken as a sign that the policy is seeing something outside its training distribution.
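To make the reconstruction-error proxy concrete, here is a toy, pure-Python sketch. It is not the module's SAE: the two-feature dictionary, the tied encoder/decoder weights with zero bias, and the helper names encode, decode, and reconstruction_error are all hypothetical simplifications. An input that lies in the span of the learned features is reconstructed exactly, while an input along a direction the dictionary never learned is not, so its error is larger.

```python
from typing import List, Sequence

Vector = List[float]

# Hypothetical toy "SAE" dictionary: two features in a 3-D activation space.
# A real SAE learns many features with separate encoder/decoder weights and biases.
DICTIONARY: List[Vector] = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
]


def encode(x: Sequence[float]) -> Vector:
    """ReLU(W_enc @ x) with tied weights and zero bias (a simplification)."""
    return [max(0.0, sum(w * xi for w, xi in zip(feature, x))) for feature in DICTIONARY]


def decode(codes: Sequence[float]) -> Vector:
    """Sum of dictionary features weighted by their activations."""
    dim = len(DICTIONARY[0])
    return [sum(c * feature[i] for c, feature in zip(codes, DICTIONARY)) for i in range(dim)]


def reconstruction_error(x: Sequence[float]) -> float:
    """Mean squared error between x and its reconstruction decode(encode(x))."""
    x_hat = decode(encode(x))
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


in_dist = [0.5, 0.8, 0.0]   # lies in the span the dictionary covers
out_dist = [0.0, 0.0, 1.0]  # a direction the dictionary never learned

print(reconstruction_error(in_dist))   # 0.0
print(reconstruction_error(out_dist))  # ~0.333
```

In a real SAE the dictionary is learned from the ACT policy's activations, so directions it reconstructs poorly roughly correspond to activation patterns the policy rarely produced during SAE training.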
- fit_ood_threshold_to_validation_dataset(dataset: LeRobotDataset, std_threshold: float = 2.5, batch_size: int = 16, max_samples: int | None = None, save_path: str | None = None) → Dict[str, float]
Calibrate the out-of-distribution detector on an unseen validation dataset.
This method runs the OOD detector on every frame in the dataset and fits a Gaussian distribution to the resulting reconstruction errors. Any reconstruction error more than std_threshold standard deviations above the fitted mean (2.5σ by default) is deemed out of distribution. The default works for many datasets, but we recommend tuning std_threshold for your own.
- Args:
dataset: Validation dataset to fit on
std_threshold: Number of standard deviations from the mean to use as the threshold
batch_size: Batch size for processing
max_samples: Maximum number of samples to process (None for all)
save_path: Path to save the OOD parameters to (if None, uses self.ood_params_path)
- Returns:
Dictionary with fitted parameters
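Once the per-frame reconstruction errors are available, the Gaussian calibration reduces to a few lines. The sketch below assumes the errors have already been collected into a list; fit_gaussian_threshold and the returned key names are hypothetical and not necessarily what the method returns or saves. It also uses the sample standard deviation, whereas the real method may use the population estimate.

```python
import statistics
from typing import Dict, Iterable, Optional


def fit_gaussian_threshold(
    errors: Iterable[float],
    std_threshold: float = 2.5,
    max_samples: Optional[int] = None,
) -> Dict[str, float]:
    """Fit a Gaussian to reconstruction errors and derive a one-sided OOD cut-off.

    Hypothetical helper: key names are illustrative only.
    """
    values = list(errors)
    if max_samples is not None:
        values = values[:max_samples]  # mirror the max_samples cap
    if len(values) < 2:
        raise ValueError("need at least two errors to estimate a standard deviation")
    mean = statistics.fmean(values)
    std = statistics.stdev(values)  # sample standard deviation
    return {
        "mean": mean,
        "std": std,
        "std_threshold": std_threshold,
        "threshold": mean + std_threshold * std,  # errors above this count as OOD
    }


params = fit_gaussian_threshold([0.10, 0.12, 0.11, 0.09, 0.13])
print(params["threshold"])
```

Because the cut-off is one-sided (mean plus k standard deviations), unusually low errors are never flagged. Lowering std_threshold flags more frames as OOD; raising it flags fewer.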
- get_reconstruction_error(observation: dict) → float
Get the reconstruction error of the observation using the SAE model.
- Args:
observation: Input observation dictionary (same format as policy input)
- Returns:
Reconstruction error (MSE loss)
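get_reconstruction_error returns a single float per observation; to turn it into an OOD decision it has to be compared with the calibrated parameters. A minimal sketch, assuming the fitted parameters hold a mean, a standard deviation, and the std_threshold used during calibration. The key names and the helpers z_score and is_out_of_distribution are hypothetical; this page does not document how the detector itself makes the final decision.

```python
from typing import Dict


def z_score(error: float, params: Dict[str, float]) -> float:
    """How many standard deviations the error sits above the calibration mean."""
    if params["std"] == 0.0:
        # Degenerate calibration: any error above the mean is infinitely unusual.
        return float("inf") if error > params["mean"] else 0.0
    return (error - params["mean"]) / params["std"]


def is_out_of_distribution(error: float, params: Dict[str, float]) -> bool:
    """Flag an observation whose error exceeds the fitted one-sided threshold."""
    return z_score(error, params) > params["std_threshold"]


# Illustrative parameters (hypothetical values and key names).
params = {"mean": 0.11, "std": 0.02, "std_threshold": 2.5}

print(is_out_of_distribution(0.12, params))  # False (z is about 0.5)
print(is_out_of_distribution(0.20, params))  # True (z is about 4.5)
```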