Source code for src.sae.activation_collector

from dataclasses import dataclass
import json
import logging
from pathlib import Path
import time
from typing import Dict, List, Tuple, Optional, Any

from safetensors.torch import save_file, load_file
import torch
from tqdm import tqdm

from .token_sampler import TokenSamplerConfig, TokenSampler


@dataclass
class ActivationCacheConfig:
    """
    Configuration for activation caching.
    
    :meta private:
    """
    # Cache settings
    cache_dir: str = "./output/activation_cache"
    buffer_size: int = 128  # Number of samples per cache file
    
    # Data organization
    layer_name: str = "encoder.layers.4"
    experiment_name: str = "act_activations"
    use_token_sampling: bool = True  # Enable token sampling
    
    # Memory management
    max_memory_gb: float = 4.0  # Maximum memory to use for caching
    cleanup_on_start: bool = False  # Clean old cache files (set True to force fresh start)
    
    # Validation
    validate_cache: bool = True  # Validate cached files on load
    cache_metadata: bool = True  # Save metadata with cache

class ActivationCache:
    """
    Manages caching of activations to disk for memory-efficient SAE training.

    :meta private:
    """
    
    def __init__(
            self,
            config: ActivationCacheConfig,
            sampler_config: TokenSamplerConfig,
            total_tokens: int = None,
            policy_model = None,
        ):
        self.config = config
        self.cache_dir = Path(config.cache_dir) / config.experiment_name
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        
        # For now, create sampler with a placeholder value - we'll update it when we see real data
        if total_tokens is None:
            total_tokens = 602  # Temporary placeholder, will be updated from actual data
            logging.info("Using placeholder token count 602, will be updated from actual activation data")
        
        self.sampler = TokenSampler(sampler_config, total_tokens)
        self._sampler_config = sampler_config  # Store config for potential recreation
        
        # Initialize cache state
        self.current_buffer = []
        self.buffer_count = 0
        self.total_samples = 0
        self.cache_files = []
        
        # Metadata
        self.metadata = {
            'layer_name': config.layer_name,
            'experiment_name': config.experiment_name,
            'created_at': time.time(),
            'buffer_size': config.buffer_size,
            'cache_files': [],
            'total_samples': 0,
            'activation_shape': None,
            # Progress tracking for resumability
            'collection_status': 'in_progress',  # 'in_progress', 'completed', 'failed'
            'last_batch_idx': -1,  # Last successfully processed batch
            'last_updated': time.time(),
            'resume_info': {
                'dataloader_position': 0,
                'can_resume': True,
                'interruption_count': 0
            }
        }
        
        # Try to load existing cache state for resumption first
        cache_loaded = self._try_load_existing_cache()
        
        # Only cleanup if we couldn't load a valid resumable cache
        if config.cleanup_on_start and not cache_loaded:
            logging.info("No resumable cache found, cleaning up any invalid cache files")
            self._cleanup_cache()
        
        logging.info(f"Initialized activation cache at {self.cache_dir}")
    
    def update_batch_idx(self, batch_idx):
        self.metadata['last_batch_idx'] = batch_idx

    def _try_load_existing_cache(self):
        """Try to load existing cache metadata for resumption"""
        metadata_path = self.cache_dir / "cache_metadata.json"
        
        if metadata_path.exists():
            try:
                # Load existing metadata
                with open(metadata_path, 'r') as f:
                    existing_metadata = json.load(f)
                                
                # Check if cache is resumable
                if (
                    existing_metadata.get('collection_status') == 'in_progress' and 
                    self._validate_existing_cache_files(existing_metadata)
                ):
                    
                    # Restore state for resumption
                    self.metadata.update(existing_metadata)
                    self.total_samples = existing_metadata.get('total_samples', 0)
                    
                    # If metadata doesn't have cache_files info, reconstruct it from directory
                    cache_files_info = existing_metadata.get('cache_files', [])
                    if not cache_files_info:
                        # Scan directory for existing cache files
                        cache_files_info = self._reconstruct_cache_files_info()
                        self.metadata['cache_files'] = cache_files_info
                    
                    self.buffer_count = len(cache_files_info)
                    self.cache_files = [
                        self.cache_dir / info['filename'] 
                        for info in cache_files_info
                    ]
                    
                    # Increment interruption count
                    self.metadata['resume_info']['interruption_count'] += 1
                    
                    logging.info(f"Found resumable cache with {self.total_samples} samples")
                    logging.info(f"Last batch: {self.metadata['last_batch_idx']}, "
                               f"Interruptions: {self.metadata['resume_info']['interruption_count']}")
                    return True
                else:
                    logging.info("Existing cache found but not resumable - starting fresh")
                    
            except (json.JSONDecodeError, KeyError, FileNotFoundError) as e:
                logging.warning(f"Could not load existing cache state: {e}")
        
        return False
    
    def _validate_existing_cache_files(self, metadata: dict) -> bool:
        """Validate that all referenced cache files actually exist and are valid"""
        try:
            cache_files_info = metadata.get('cache_files', [])
            
            # If no cache_files info in metadata, check if any safetensors files exist
            if not cache_files_info:
                safetensors_files = list(self.cache_dir.glob("activations_buffer_*.safetensors"))
                if not safetensors_files:
                    logging.warning("No cache files found in directory")
                    return False
                logging.info(f"Found {len(safetensors_files)} cache files without metadata")
                return True  # We can reconstruct the metadata
            
            # Validate referenced files
            for file_info in cache_files_info:
                file_path = self.cache_dir / file_info['filename']
                if not file_path.exists():
                    logging.warning(f"Missing cache file: {file_path}")
                    return False
                    
                # Basic size check - file should have reasonable size
                if file_path.stat().st_size < 100:  # Very small files are likely corrupted
                    logging.warning(f"Cache file too small (likely corrupted): {file_path}")
                    return False
            
            return True
        except Exception as e:
            logging.warning(f"Cache validation failed: {e}")
            return False
    
    def _reconstruct_cache_files_info(self) -> List[Dict[str, Any]]:
        """Reconstruct cache files info by scanning the directory"""
        cache_files_info = []
        
        # Find all safetensors files in the cache directory
        safetensors_files = sorted(self.cache_dir.glob("activations_buffer_*.safetensors"))
        
        for filepath in safetensors_files:
            # Extract buffer index from filename
            filename = filepath.name
            try:
                buffer_idx = int(filename.split('_')[-1].split('.')[0])
            except (ValueError, IndexError):
                continue
            
            # Try to determine number of samples by loading metadata
            metadata_path = filepath.with_suffix('.json')
            num_samples = 0
            sample_range = (0, 0)
            
            if metadata_path.exists():
                try:
                    with open(metadata_path, 'r') as f:
                        buffer_metadata = json.load(f)
                    num_samples = len(buffer_metadata)
                    if buffer_metadata:
                        # Calculate sample range
                        sample_indices = [item.get('sample_idx', 0) for item in buffer_metadata]
                        sample_range = (min(sample_indices), max(sample_indices))
                except (json.JSONDecodeError, KeyError):
                    pass
            
            cache_files_info.append({
                'filename': filename,
                'buffer_idx': buffer_idx,
                'num_samples': num_samples,
                'sample_range': sample_range
            })
        
        logging.info(f"Reconstructed cache info for {len(cache_files_info)} buffers")
        return cache_files_info
    
    def _cleanup_cache(self):
        """Remove old cache files"""
        if self.cache_dir.exists():
            for file in self.cache_dir.iterdir():
                if file.suffix in ['.safetensors', '.pt', '.json']:
                    file.unlink()
            logging.info("Cleaned up old cache files")
    
    def add_activations(self, activations: torch.Tensor, batch_metadata: Optional[Dict[str, Any]] = None):
        """
        Add activations to cache buffer
        
        Args:
            activations: Tensor of shape (batch_size, num_tokens, token_dim)
            batch_metadata: Optional metadata for this batch
        """
        # Ensure activations are on CPU and detached
        if activations.requires_grad:
            activations = activations.detach()
        if activations.device != torch.device('cpu'):
            activations = activations.cpu()

        activations = activations.permute(1, 0, 2).contiguous() # flip batch size and tokens_length dims
        
        # Store activation shape for metadata
        if self.metadata['activation_shape'] is None:
            self.metadata['activation_shape'] = list(activations.shape[1:])  # Exclude batch dimension
        
        # Record the original number of tokens (before sampling) - this is the key insight!
        original_num_tokens = activations.shape[1]  # num_tokens from (batch_size, num_tokens, token_dim)
        if 'original_num_tokens' not in self.metadata:
            self.metadata['original_num_tokens'] = original_num_tokens
            logging.info(f"Recorded original_num_tokens from actual data: {original_num_tokens}")
            
            # Update the sampler with the correct token count if it was using placeholder
            if self.sampler.total_tokens != original_num_tokens:
                logging.info(f"Updating TokenSampler with correct token count: {self.sampler.total_tokens} -> {original_num_tokens}")
                self.sampler = TokenSampler(self._sampler_config, original_num_tokens)
                
        elif self.metadata['original_num_tokens'] != original_num_tokens:
            logging.warning(f"Token count mismatch! Expected {self.metadata['original_num_tokens']}, got {original_num_tokens}")
        
        # Add to buffer
        activations = self.sampler.sample_tokens(activations)
        batch_size = activations.shape[0]
        for i in range(batch_size):
            sample = {
                'activation': activations[i],  # Shape: (num_tokens, token_dim)
                'sample_idx': self.total_samples,
                'buffer_idx': len(self.current_buffer)
            }
            
            # Add metadata if provided
            if batch_metadata:
                sample['metadata'] = self._extract_sample_metadata(batch_metadata, i)
            
            self.current_buffer.append(sample)
            self.total_samples += 1
            
        # Flush buffer if full
        if len(self.current_buffer) >= self.config.buffer_size:
            self._flush_buffer()
    
    def _extract_sample_metadata(self, batch_metadata: Dict[str, Any], sample_idx: int) -> Dict[str, Any]:
        """Extract metadata for a single sample from batch metadata"""
        sample_metadata = {}
        
        for key, value in batch_metadata.items():
            if isinstance(value, torch.Tensor):
                if len(value.shape) > 0 and value.shape[0] > sample_idx:
                    # Store only essential metadata to save space
                    if key in ['episode_idx', 'frame_idx', 'action']:
                        sample_metadata[key] = value[sample_idx].tolist() if hasattr(value[sample_idx], 'tolist') else value[sample_idx]
            elif isinstance(value, (list, tuple)) and len(value) > sample_idx:
                sample_metadata[key] = value[sample_idx]
        
        return sample_metadata
    
    def _flush_buffer(self):
        """Save current buffer to disk"""
        if not self.current_buffer:
            return
        
        # Create filename
        filename = f"activations_buffer_{self.buffer_count:06d}"
        
        filepath = self.cache_dir / f"{filename}.safetensors"
        self._save_buffer_safetensors(filepath)
        
        # Update metadata
        self.cache_files.append(str(filepath))
        self.metadata['cache_files'].append({
            'filename': filepath.name,
            'buffer_idx': self.buffer_count,
            'num_samples': len(self.current_buffer),
            'sample_range': (
                self.current_buffer[0]['sample_idx'],
                self.current_buffer[-1]['sample_idx']
            )
        })
        
        logging.info(f"Saved buffer {self.buffer_count} with {len(self.current_buffer)} samples to {filepath.name}")
        
        # Update and save incremental metadata after each buffer
        self.metadata['total_samples'] = self.total_samples
        self.metadata['num_buffers'] = self.buffer_count + 1  # +1 because we're about to increment
        self.metadata['last_updated'] = time.time()
        
        # Save incremental metadata so the cache can be loaded even during collection
        metadata_path = self.cache_dir / "cache_metadata.json" 
        with open(metadata_path, 'w') as f:
            json.dump(self.metadata, f, indent=2)
        
        # Clear buffer
        self.current_buffer = []
        self.buffer_count += 1
    
    def _save_buffer_safetensors(self, filepath: Path):
        """Save buffer using safetensors format"""
        # Prepare tensors for safetensors (flat structure required)
        tensors = {}
        metadata_list = []
        
        for i, sample in enumerate(self.current_buffer):
            # Store activation tensor
            tensors[f"activation_{i}"] = sample['activation']
            
            # Store metadata separately (safetensors doesn't support nested structures)
            sample_metadata = {
                'sample_idx': sample['sample_idx'],
                'buffer_idx': sample['buffer_idx']
            }
            if 'metadata' in sample:
                sample_metadata.update(sample['metadata'])
            
            metadata_list.append(sample_metadata)
        
        # Save tensors
        save_file(tensors, filepath)
        
        # Save metadata separately
        metadata_path = filepath.with_suffix('.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata_list, f)
    
    def _save_buffer_torch(self, filepath: Path):
        """Save buffer using PyTorch format"""
        # Stack activations into single tensor
        activations = torch.stack([sample['activation'] for sample in self.current_buffer])
        
        # Prepare metadata
        metadata_list = []
        for sample in self.current_buffer:
            sample_metadata = {
                'sample_idx': sample['sample_idx'],
                'buffer_idx': sample['buffer_idx']
            }
            if 'metadata' in sample:
                sample_metadata.update(sample['metadata'])
            metadata_list.append(sample_metadata)
        
        # Save everything
        torch.save({
            'activations': activations,
            'metadata': metadata_list,
            'buffer_info': {
                'buffer_idx': self.buffer_count,
                'num_samples': len(self.current_buffer)
            }
        }, filepath)
    
    def finalize(self):
        """Flush any remaining activations and save metadata"""
        # Flush remaining buffer
        if self.current_buffer:
            self._flush_buffer()
        
        # Update final metadata
        self.metadata['total_samples'] = self.total_samples
        self.metadata['num_buffers'] = self.buffer_count
        self.metadata['finalized_at'] = time.time()
        
        # Only mark as completed if collection was actually complete
        if self.metadata.get('collection_complete', False):
            self.metadata['collection_status'] = 'completed'
            self.metadata['resume_info']['can_resume'] = False
        else:
            # Collection was interrupted - keep as in_progress and resumable
            self.metadata['collection_status'] = 'in_progress'
            self.metadata['resume_info']['can_resume'] = True
            logging.info("Cache finalized but collection incomplete - remains resumable")
        
        # Save metadata
        metadata_path = self.cache_dir / "cache_metadata.json"
        with open(metadata_path, 'w') as f:
            json.dump(self.metadata, f, indent=2)
        
        logging.info(f"Finalized cache with {self.total_samples} samples in {self.buffer_count} buffers")
        
        return {
            'total_samples': self.total_samples,
            'num_buffers': self.buffer_count,
            'cache_dir': str(self.cache_dir),
            'cache_files': self.cache_files
        }


class CachedActivationDataset(torch.utils.data.Dataset):
    """
    Dataset that loads activations from cached files.

    :meta private:
    """
    
    def __init__(self, cache_dir: str, shuffle: bool = True, preload_buffers: int = 2):
        self.cache_dir = Path(cache_dir)
        self.shuffle = shuffle
        self.preload_buffers = preload_buffers
        
        # Load metadata
        metadata_path = self.cache_dir / "cache_metadata.json"
        if not metadata_path.exists():
            raise FileNotFoundError(f"Cache metadata not found at {metadata_path}")
        
        with open(metadata_path, 'r') as f:
            self.metadata = json.load(f)
        
        self.total_samples = self.metadata['total_samples']
        self.activation_shape = tuple(self.metadata['activation_shape'])
        self.cache_files_info = self.metadata['cache_files']
        
        # Create sample index mapping
        self._create_sample_mapping()
        
        # Buffer management for preloading
        self._buffer_cache = {}
        self._buffer_access_order = []
        
        logging.info(f"Loaded cached dataset with {self.total_samples} samples")
    
    def _create_sample_mapping(self):
        """Create mapping from global sample index to (buffer_file, local_index)"""
        self.sample_to_buffer = {}
        
        for buffer_info in self.cache_files_info:
            start_idx, end_idx = buffer_info['sample_range']
            buffer_filename = buffer_info['filename']
            
            for global_idx in range(start_idx, end_idx + 1):
                local_idx = global_idx - start_idx
                self.sample_to_buffer[global_idx] = (buffer_filename, local_idx)
    
    def _load_buffer(self, filename: str) -> Tuple[torch.Tensor, List[Dict[str, Any]]]:
        """Load a specific buffer file"""
        filepath = self.cache_dir / filename
        
        if filename.endswith('.safetensors'):
            return self._load_buffer_safetensors(filepath)
        else:
            return self._load_buffer_torch(filepath)
    
    def _load_buffer_safetensors(self, filepath: Path) -> Tuple[torch.Tensor, List[Dict[str, Any]]]:
        """Load buffer from safetensors format"""
        # Load tensors
        tensors = load_file(filepath)
        
        # Reconstruct activations
        activations = []
        i = 0
        while f"activation_{i}" in tensors:
            activations.append(tensors[f"activation_{i}"])
            i += 1
        
        activations_tensor = torch.stack(activations)
        
        # Load metadata
        metadata_path = filepath.with_suffix('.json')
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)
        else:
            metadata = [{}] * len(activations)
        
        return activations_tensor, metadata
    
    def _load_buffer_torch(self, filepath: Path) -> Tuple[torch.Tensor, List[Dict[str, Any]]]:
        """Load buffer from PyTorch format"""
        data = torch.load(filepath, map_location='cpu')
        return data['activations'], data['metadata']
    
    def _get_buffer_with_cache(self, filename: str) -> Tuple[torch.Tensor, List[Dict[str, Any]]]:
        """Get buffer with LRU caching"""
        if filename in self._buffer_cache:
            # Move to end (most recently used)
            self._buffer_access_order.remove(filename)
            self._buffer_access_order.append(filename)
            return self._buffer_cache[filename]
        
        # Load buffer
        activations, metadata = self._load_buffer(filename)
        
        # Add to cache
        self._buffer_cache[filename] = (activations, metadata)
        self._buffer_access_order.append(filename)
        
        # Evict old buffers if cache is full
        while len(self._buffer_cache) > self.preload_buffers:
            oldest_file = self._buffer_access_order.pop(0)
            del self._buffer_cache[oldest_file]
        
        return activations, metadata
    
    def __len__(self) -> int:
        return self.total_samples
    
    def __getitem__(self, idx: int) -> torch.Tensor:
        """Get activation by global index"""
        if idx >= self.total_samples:
            raise IndexError(f"Index {idx} out of range for dataset of size {self.total_samples}")
        
        # Get buffer info
        buffer_filename, local_idx = self.sample_to_buffer[idx]
        
        # Load buffer (with caching)
        activations, metadata = self._get_buffer_with_cache(buffer_filename)
        
        # Return specific activation
        return activations[local_idx]


class ActivationCollector:
    """
    Memory-efficient activation collector that caches to disk.

    :meta private:
    """
    
    def __init__(
            self,
            act_model,
            config: ActivationCacheConfig,
            sampler_config: TokenSamplerConfig,
        ):
        self.act_model = act_model
        self.config = config
        self.cache = ActivationCache(config, sampler_config, policy_model=act_model)
        self.hook = None
        
        # Setup hook
        self._register_hook()
    
    def _register_hook(self):
        """Register forward hook to capture activations"""
        def hook_fn(module, input, output):
            # Handle different output formats
            if isinstance(output, tuple):
                activation = output[0]
            else:
                activation = output
            
            # Store activations to cache
            if len(activation.shape) == 3:  # (batch_size, seq_len, hidden_dim)
                self.cache.add_activations(activation)
            else:
                logging.warning(f"Unexpected activation shape: {activation.shape}")
        
        # Get layer by name
        layer = self.act_model
        for attr in self.config.layer_name.split('.'):
            layer = getattr(layer, attr)
        
        self.hook = layer.register_forward_hook(hook_fn)
        logging.info(f"Registered hook on layer: {self.config.layer_name}")
    
    def collect_activations(
        self, 
        dataloader: torch.utils.data.DataLoader, 
        max_samples: Optional[int] = None,
        device: str = 'cuda'
    ) -> str:
        """
        Collect activations and cache to disk with resumption support
        
        Args:
            dataloader: DataLoader with input data
            max_samples: Maximum number of samples to collect
            device: Device to run model on
            
        Returns:
            Path to cache directory
        """
        self.act_model.eval()
        self.act_model = self.act_model.to(device)
        
        # Determine starting point for resumption
        start_batch_idx = self.cache.metadata.get('last_batch_idx', -1) + 1
        total_batches = len(dataloader)
        
        # Set completion tracking metadata on first run or update if needed
        if self.cache.metadata.get('total_batches_expected') is None:
            self.cache.metadata['total_batches_expected'] = total_batches
            self.cache.metadata['max_samples_target'] = max_samples
            logging.info(f"Set completion target: {total_batches} batches, max_samples: {max_samples}")
        
        if start_batch_idx > 0:
            logging.info(f"Resuming activation collection from batch {start_batch_idx}/{total_batches}")
            logging.info(f"Already collected {self.cache.total_samples} samples")
        
        samples_collected = self.cache.total_samples
        processed_batches = 0
        
        try:
            with torch.inference_mode():
                for batch_idx, batch in enumerate(tqdm(dataloader, desc="Collecting activations", 
                                                     initial=start_batch_idx, total=total_batches)):
                    # Skip batches we've already processed (for resumption)
                    if batch_idx < start_batch_idx:
                        continue
                    
                    # Move batch to device
                    for key in batch:
                        if isinstance(batch[key], torch.Tensor):
                            batch[key] = batch[key].to(device)
                    
                    # Extract dataset indices if available for better tracking
                    dataset_idx = None
                    if 'dataset_index' in batch:
                        dataset_idx = batch['dataset_index'][0].item() if torch.is_tensor(batch['dataset_index']) else batch['dataset_index'][0]
                    
                    # Forward pass (triggers hook and caching)
                    self.cache.update_batch_idx(batch_idx)
                    _ = self.act_model.select_action(batch)
                    self.act_model.reset()
                    samples_collected = self.cache.total_samples
                    processed_batches += 1
                    
                    # Check memory usage periodically
                    if batch_idx % 100 == 0:
                        self._check_memory_usage()
                    
                    if max_samples and samples_collected >= max_samples:
                        logging.info(f"Reached maximum samples limit: {max_samples}")
                        self.cache.metadata['collection_complete'] = True
                        break
                
                # Check if we completed all batches
                if batch_idx >= total_batches - 1:
                    logging.info(f"Completed all {total_batches} batches")
                    self.cache.metadata['collection_complete'] = True
                        
        except Exception as e:
            # Mark cache as failed but keep progress for potential resumption
            self.cache.metadata['collection_status'] = 'failed'
            self.cache.metadata['resume_info']['can_resume'] = True
            
            logging.error(f"Activation collection failed at batch {batch_idx if 'batch_idx' in locals() else start_batch_idx}: {e}")
            raise
        
        # Finalize cache
        cache_info = self.cache.finalize()
        logging.info(f"Collected {cache_info['total_samples']} activations in {cache_info['num_buffers']} buffers")
        
        return str(self.cache.cache_dir)
    
    def _check_memory_usage(self):
        """Check and log memory usage"""
        try:
            import psutil
            memory_gb = psutil.virtual_memory().used / (1024**3)
            if memory_gb > self.config.max_memory_gb:
                logging.warning(f"Memory usage ({memory_gb:.1f}GB) exceeds limit ({self.config.max_memory_gb}GB)")
        except ImportError:
            pass  # psutil not available
    
    def remove_hook(self):
        """Remove the registered hook"""
        if self.hook:
            self.hook.remove()
            self.hook = None


def is_cache_valid(cache_dir: str) -> bool:
    """
    Check if a cache directory contains valid cached activations.
    Now supports checking resumable (incomplete) caches.
    
    Args:
        cache_dir: Directory to check
        
    Returns:
        True if cache is valid (complete or resumable), False otherwise

    :meta private:
    """
    cache_path = Path(cache_dir)
    
    # Check if directory exists
    if not cache_path.exists():
        return False
    
    # Check if metadata file exists
    metadata_path = cache_path / "cache_metadata.json"
    
    # For resumable caches, progress.json might exist without full metadata
    if not metadata_path.exists():
        return False
    
    try:
        # Load metadata if available
        metadata = {}
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)
        
        # Check collection status
        collection_status = metadata.get('collection_status')
        
        if collection_status == 'completed':
            # For completed caches, do full validation
            required_fields = ['total_samples', 'num_buffers', 'cache_files', 'activation_shape']
            if not all(field in metadata for field in required_fields):
                return False
            
            # Check if cache files actually exist
            cache_files_info = metadata.get('cache_files', [])
            for file_info in cache_files_info:
                file_path = cache_path / file_info['filename']
                if not file_path.exists():
                    return False
            
            # Check that we have a reasonable number of samples
            if metadata.get('total_samples', 0) <= 0:
                return False
            
        return True
        
    except (json.JSONDecodeError, KeyError, FileNotFoundError):
        return False

def get_cache_status(cache_dir: str) -> dict:
    """
    Get detailed status information about a cache directory.
    
    Args:
        cache_dir: Directory to check
        
    Returns:
        Dictionary with cache status information

    :meta private:
    """
    cache_path = Path(cache_dir)
    
    if not cache_path.exists():
        return {'status': 'missing', 'exists': False}
    
    status_info = {
        'exists': True,
        'path': str(cache_path),
        'status': 'unknown',
        'total_samples': 0,
        'can_resume': False,
        'last_batch_idx': -1,
        'interruption_count': 0,
        'cache_files': 0
    }
    
    try:
        # Load metadata if available
        metadata_path = cache_path / "cache_metadata.json"
        
        metadata = {}
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)
                
        
        # Extract status information
        status_info['status'] = metadata.get('collection_status')
        status_info['total_samples'] = metadata.get('total_samples', 0)
        status_info['can_resume'] = metadata.get('resume_info', {}).get('can_resume', False)
        status_info['last_batch_idx'] = metadata.get('last_batch_idx', -1)
        status_info['interruption_count'] = metadata.get('resume_info', {}).get('interruption_count', 0)
        status_info['cache_files'] = len(metadata.get('cache_files', []))
        
        if metadata.get('created_at'):
            status_info['created_at'] = metadata['created_at']
        if metadata.get('last_updated'):
            status_info['last_updated'] = metadata.get('last_updated')
        if metadata.get('finalized_at'):
            status_info['finalized_at'] = metadata['finalized_at']
            
    except (json.JSONDecodeError, KeyError, FileNotFoundError) as e:
        status_info['status'] = 'corrupted'
        status_info['error'] = str(e)
    
    return status_info


def cleanup_invalid_cache(cache_dir: str) -> None:
    """
    Remove an invalid cache directory and all its contents.
    
    Args:
        cache_dir: Directory to clean up

    :meta private:
    """
    cache_path = Path(cache_dir)
    
    if cache_path.exists():
        import shutil
        logging.info(f"Cleaning up invalid cache directory: {cache_path}")
        shutil.rmtree(cache_path)


def create_cached_dataloader(
    cache_dir: str,
    batch_size: int = 256,
    shuffle: bool = True,
    num_workers: int = 0,
    preload_buffers: int = 2
) -> torch.utils.data.DataLoader:
    """
    Create DataLoader from cached activations
    
    Args:
        cache_dir: Directory containing cached activations
        batch_size: Batch size for training
        shuffle: Whether to shuffle data
        num_workers: Number of worker processes
        preload_buffers: Number of buffers to keep in memory
        
    Returns:
        DataLoader for cached activations

    :meta private:
    """
    dataset = CachedActivationDataset(
        cache_dir=cache_dir,
        shuffle=shuffle,
        preload_buffers=preload_buffers
    )
    
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True
    )


[docs] def load_original_num_tokens_from_cache(cache_path: str) -> Optional[int]: """ Load the original number of tokens from cached activation metadata. Args: cache_path: Path to the activation cache directory Returns: Original number of tokens if found, None otherwise """ cache_path = Path(cache_path) # Look for metadata.json file metadata_file = cache_path / "cache_metadata.json" if metadata_file.exists(): try: with open(metadata_file, 'r') as f: metadata = json.load(f) original_num_tokens = metadata.get('original_num_tokens') if original_num_tokens is not None: logging.info(f"Loaded original_num_tokens from cache metadata: {original_num_tokens}") return int(original_num_tokens) else: logging.warning("original_num_tokens not found in cache metadata") except Exception as e: logging.warning(f"Could not load metadata from {metadata_file}: {e}") else: logging.warning(f"Metadata file not found at {metadata_file}") return None
def collect_and_cache_activations( act_model, dataloader: torch.utils.data.DataLoader, layer_name: str, cache_dir: str, experiment_name: str = "act_activations", buffer_size: int = 128, max_samples: Optional[int] = None, device: str = 'cuda', # Cache management cleanup_on_start: bool = False, # Set True to force clean start instead of resuming # Token sampling parameters use_token_sampling: bool = True, fixed_tokens: List[int] = None, sampling_strategy: str = "block_average", sampling_stride: int = 8, max_sampled_tokens: int = 100, block_size: int = 8 ) -> str: """ Convenience function to collect and cache activations with optional token sampling Args: act_model: ACT model to collect activations from dataloader: DataLoader with input data layer_name: Name of layer to hook cache_dir: Directory to store cache experiment_name: Name for this experiment buffer_size: Number of samples per cache file max_samples: Maximum samples to collect device: Device to run model on cleanup_on_start: If True, force clean start instead of resuming from cache use_token_sampling: Whether to use token sampling fixed_tokens: Token indices to always include (default: [0, 601]) sampling_strategy: "uniform", "stride", "random_fixed", or "block_average" sampling_stride: Take every Nth token when using stride strategy max_sampled_tokens: Maximum number of tokens to sample block_size: Size of blocks for block_average strategy Returns: Path to cache directory :meta private: """ config = ActivationCacheConfig( cache_dir=cache_dir, layer_name=layer_name, experiment_name=experiment_name, buffer_size=buffer_size, use_token_sampling=use_token_sampling, cleanup_on_start=cleanup_on_start, ) sampler_config = None if use_token_sampling: sampler_config = TokenSamplerConfig( fixed_tokens=fixed_tokens, sampling_strategy=sampling_strategy, sampling_stride=sampling_stride, max_sampled_tokens=max_sampled_tokens, block_size=block_size ) collector = ActivationCollector(act_model, config, sampler_config=sampler_config) try: cache_path = collector.collect_activations(dataloader, max_samples, device) return cache_path finally: collector.remove_hook()