src.ood.ood_detector module

class src.ood.ood_detector.OODDetector(policy: ACTPolicy, sae_experiment_path: str | None = None, sae_hub_repo_id: str | None = None, ood_params_path: Path | None = None, force_ood_refresh: bool = False, device: str = 'cuda')

Bases: object

Out-of-distribution (OOD) detector that flags unusual attention patterns.

Builds on a sparse autoencoder (SAE) trained to extract features from a given ACT policy. The SAE's reconstruction error serves as a proxy for out-of-distribution inputs: observations whose features the SAE reconstructs poorly are deemed out of distribution.
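The reconstruction-error idea can be sketched in a few lines of NumPy. This is a minimal sketch, not this module's code: it assumes a common SAE layout (ReLU encoder, linear decoder, decoder bias subtracted before encoding), and the function `sae_reconstruction_error` and its weight arguments are hypothetical names chosen for illustration.

```python
import numpy as np


def sae_reconstruction_error(
    x: np.ndarray,
    w_enc: np.ndarray,
    b_enc: np.ndarray,
    w_dec: np.ndarray,
    b_dec: np.ndarray,
) -> float:
    """Mean squared error between a feature vector and its SAE reconstruction.

    Hypothetical SAE layout (an assumption, not this module's API):
    w_enc has shape (k, d), b_enc (k,), w_dec (d, k), b_dec (d,).
    """
    # Encode into a non-negative (sparse, once trained) code.
    codes = np.maximum(0.0, w_enc @ (x - b_dec) + b_enc)
    # Decode back into the original feature space.
    x_hat = w_dec @ codes + b_dec
    # Average the squared error over feature dimensions (MSE).
    return float(np.mean((x - x_hat) ** 2))


eye, zeros = np.eye(3), np.zeros(3)
print(sae_reconstruction_error(np.array([1.0, 2.0, 3.0]), eye, zeros, eye, zeros))
print(sae_reconstruction_error(np.array([1.0, -2.0, 3.0]), eye, zeros, eye, zeros))
```

With identity weights and zero biases, a non-negative input is reconstructed exactly (error 0), while a negative entry is zeroed by the ReLU and the error rises. A trained SAE is expected to behave analogously: activations resembling its training data reconstruct well, unfamiliar ones reconstruct poorly.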

fit_ood_threshold_to_validation_dataset(dataset: LeRobotDataset, std_threshold: float = 2.5, batch_size: int = 16, max_samples: int | None = None, save_path: str | None = None) → Dict[str, float]

Calibrate the out-of-distribution detector on an unseen validation dataset.

This method runs the OOD detector on every frame in the dataset and fits a Gaussian distribution to the resulting reconstruction errors. Any error more than std_threshold standard deviations above the fitted mean (default: 2.5σ) is deemed out of distribution. The default works for many datasets, but we recommend tuning std_threshold to suit your own data.

Args:

dataset: Validation dataset to fit on.
std_threshold: Number of standard deviations from the mean to use as the threshold.
batch_size: Batch size for processing.
max_samples: Maximum number of samples to process (None for all).
save_path: Path to save the OOD parameters (if None, uses self.ood_params_path).

Returns:

Dictionary with fitted parameters
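The Gaussian calibration step can be sketched as follows, assuming the per-frame reconstruction errors have already been collected. The helper `fit_gaussian_threshold`, its dictionary keys, and the use of the population standard deviation (`ddof=0`) are assumptions for illustration; the real method may compute or store these values differently.

```python
from typing import Dict, Iterable, Optional

import numpy as np


def fit_gaussian_threshold(
    errors: Iterable[float],
    std_threshold: float = 2.5,
    max_samples: Optional[int] = None,
) -> Dict[str, float]:
    """Fit a Gaussian to validation reconstruction errors and derive a cutoff.

    Hypothetical helper: the returned keys are illustrative only.
    """
    values = np.asarray(list(errors), dtype=float)
    # Optionally cap the number of samples, mirroring max_samples.
    if max_samples is not None:
        values = values[:max_samples]
    if values.size == 0:
        raise ValueError("need at least one reconstruction error to fit")
    mean = float(values.mean())
    std = float(values.std())  # population std (ddof=0), an assumption
    return {
        "mean": mean,
        "std": std,
        "std_threshold": std_threshold,
        # Errors above mean + k * std are treated as out of distribution.
        "threshold": mean + std_threshold * std,
    }


params = fit_gaussian_threshold([1.0, 2.0, 3.0, 4.0, 5.0])
print(params["threshold"])
```

Storing the mean and standard deviation alongside the threshold keeps the fit reusable: std_threshold can be retuned later without re-running the detector over the whole dataset.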

get_ood_stats() → Dict[str, float] | None

Get the current OOD parameters and statistics.

get_reconstruction_error(observation: dict) → float

Get the reconstruction error of the observation using the SAE model.

Args:

observation: Input observation dictionary (same format as policy input)

Returns:

Reconstruction error (MSE loss)

is_out_of_distribution(observation: dict) → Tuple[bool, float]

Detect if the observation is OOD using the SAE model.

Args:

observation: Input observation dictionary

Returns:

Tuple of (is_ood, reconstruction_error)
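One way to turn a fitted Gaussian into the (is_ood, reconstruction_error) decision is a z-score test. The helper `classify_error` and the `params` keys ("mean", "std", "std_threshold") are hypothetical; the strict comparison mirrors the rule that errors above the threshold count as out of distribution.

```python
from typing import Dict, Tuple


def classify_error(error: float, params: Dict[str, float]) -> Tuple[bool, float]:
    """Return (is_ood, error) for a single reconstruction error.

    Hypothetical helper; params is assumed to come from a Gaussian fit.
    """
    mean, std, k = params["mean"], params["std"], params["std_threshold"]
    if std == 0.0:
        # Degenerate fit: treat any error above the mean as unusual.
        return error > mean, error
    z_score = (error - mean) / std
    # Equivalent to error > mean + k * std when std > 0.
    return z_score > k, error


fitted = {"mean": 3.0, "std": 1.0, "std_threshold": 2.5}
print(classify_error(6.0, fitted))
print(classify_error(5.0, fitted))
```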

needs_ood_fitting() → bool

Check whether the OOD threshold needs to be fitted.

src.ood.ood_detector.create_default_ood_params_path(experiment_name: str, base_dir: str = 'output') → str

Create standard path for OOD parameters based on experiment name.

Args:

experiment_name: SAE experiment name.
base_dir: Base output directory.

Returns:

Path string for OOD parameters file