Source code for src.attention_maps.act_attention_mapper

import numpy as np
import torch
import cv2
from typing import List, Dict, Tuple, Optional

class ACTPolicyWithAttention:
    """
    Wrapper for ACTPolicy that provides transformer attention visualizations.
    """

    def __init__(self, policy, image_shapes=None, specific_decoder_token_index: Optional[int] = None):
        """
        Initialize the wrapper with an ACTPolicy.

        Args:
            policy: An instance of ACTPolicy
            image_shapes: Optional list of image shapes [(H1, W1), (H2, W2), ...] if known in advance
            specific_decoder_token_index: Experimental; allows visualising attention maps for a
                particular token rather than averaging all outputs.
        """
        self.policy = policy
        self.config = policy.config
        self.specific_decoder_token_index = specific_decoder_token_index

        if self.specific_decoder_token_index is not None:
            if not hasattr(self.config, 'chunk_size'):
                raise AttributeError("Policy's config object does not have 'chunk_size' attribute.")
            if not (0 <= self.specific_decoder_token_index < self.config.chunk_size):
                raise ValueError(
                    f"specific_decoder_token_index ({self.specific_decoder_token_index}) "
                    f"must be between 0 and chunk_size-1 ({self.config.chunk_size - 1})."
                )

        # Determine number of images from config
        if self.config.image_features:
            self.num_images = len(self.config.image_features)
        else:
            self.num_images = 0

        # Store image shapes if provided, otherwise they will be detected at runtime
        self.image_shapes = image_shapes

        # For storing the last processed images and attention
        self.last_observation = None
        self.last_attention_maps = None

        if not hasattr(self.policy, 'model') or \
           not hasattr(self.policy.model, 'decoder') or \
           not hasattr(self.policy.model.decoder, 'layers') or \
           not self.policy.model.decoder.layers:
            raise AttributeError("Policy model structure does not match expected ACT architecture for target_layer.")

        self.target_layer = self.policy.model.decoder.layers[-1].multihead_attn
    def select_action(self, observation: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, List[np.ndarray]]:
        """
        Extends policy.select_action to also compute attention maps.

        Args:
            observation: Dictionary of observations

        Returns:
            action: The predicted action tensor
            attention_maps: List of attention maps, one for each image
        """
        # Store the observation for later use
        self.last_observation = observation.copy()

        # Process the images through the backbone first to determine spatial dimensions
        images = self._extract_images(observation)
        image_spatial_shapes = self._get_image_spatial_shapes(images)

        # Set up a hook to capture attention weights
        attention_weights_capture = []

        def attention_hook(module, input_args, output_tuple):
            # Depending on the MultiheadAttention implementation, the attention weights
            # may be returned with shape [batch_size, tgt_len, src_len] or
            # [batch_size, num_heads, tgt_len, src_len].
            if isinstance(output_tuple, tuple) and len(output_tuple) > 1:
                # Output is a tuple with attention weights as the second element
                attn_weights = output_tuple[1]
            else:
                # Some implementations store the attention weights on the module after the forward pass
                attn_weights = getattr(module, 'attn_weights', None)

            if attn_weights is not None:
                # Store the weights regardless of shape -- reshaping is handled later
                attention_weights_capture.append(attn_weights.detach().cpu())

        # Register the hook
        handle = self.target_layer.register_forward_hook(attention_hook)

        # Call the original policy's select_action
        with torch.inference_mode():
            action = self.policy.select_action(observation)
            # Clear the policy's action queue so the model (and the hook) runs on every call
            self.policy.reset()

        # Remove the hook
        handle.remove()

        # Process the attention weights
        if attention_weights_capture:
            attn = attention_weights_capture[0].to(action.device)
            attention_maps, proprio_attention = self._map_attention_to_images(attn, image_spatial_shapes)
            self.last_attention_maps = attention_maps
            self.last_proprio_attention = proprio_attention  # Store for visualization
        else:
            print("Warning: No attention weights were captured.")
            attention_maps = [None] * self.num_images
            self.last_attention_maps = attention_maps
            self.last_proprio_attention = 0.0  # Store for visualization

        return action, attention_maps
    def _extract_images(self, observation: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
        """Extract image tensors from the observation dictionary."""
        images = []
        for key in self.config.image_features:
            if key in observation:
                images.append(observation[key])
        return images

    def _get_image_spatial_shapes(self, images: List[torch.Tensor]) -> List[Tuple[int, int]]:
        """
        Get the spatial shapes of the feature maps after ResNet processing.
        For ResNet, this is typically H/32 × W/32.
        """
        spatial_shapes = []
        for img_tensor in images:
            if img_tensor is None:
                spatial_shapes.append((0, 0))
                continue

            # Run the image through the backbone to get the feature map shape
            with torch.no_grad():
                if img_tensor.dim() == 3:
                    img_tensor_batched = img_tensor.unsqueeze(0)
                else:
                    img_tensor_batched = img_tensor
                img_tensor_batched = img_tensor_batched.to(next(self.policy.model.backbone.parameters()).device)
                feature_map_dict = self.policy.model.backbone(img_tensor_batched)  # Use batched tensor
                feature_map = feature_map_dict["feature_map"]
                h, w = feature_map.shape[2], feature_map.shape[3]
                spatial_shapes.append((h, w))
        return spatial_shapes

    def _map_attention_to_images(self, attention: torch.Tensor,
                                 image_spatial_shapes: List[Tuple[int, int]]) -> Tuple[List[np.ndarray], float]:
        """
        Map transformer attention weights back to the original images and extract
        proprioception attention. Normalizes attention maps globally across all images
        AND proprioception for this timestep.

        Args:
            attention: Tensor of shape [batch, heads, tgt_len, src_len] (tgt_len is config.chunk_size)
            image_spatial_shapes: List of (height, width) tuples for feature maps

        Returns:
            Tuple of:
                - List of globally normalized attention maps as numpy arrays
                - Proprioception attention value (float, normalized to the same scale as visual attention)
        """
        if attention.dim() == 4:
            attention = attention.mean(dim=1)  # -> [batch, tgt_len, src_len]
        elif attention.dim() != 3:
            raise ValueError(f"Unexpected attention dimension: {attention.shape}. Expected 3 or 4.")

        # Token structure: [latent, (robot_state), (env_state), (image_tokens)]
        n_prefix_tokens = 1  # latent token
        proprio_token_idx = None
        if self.config.robot_state_feature:
            proprio_token_idx = n_prefix_tokens  # proprioception is the next token
            n_prefix_tokens += 1
        if self.config.env_state_feature:
            n_prefix_tokens += 1

        # --- Step 1: Extract proprioception attention ---
        proprio_attention = 0.0
        if proprio_token_idx is not None:
            # Extract attention to the proprioception token
            if self.specific_decoder_token_index is not None:
                if 0 <= self.specific_decoder_token_index < attention.shape[1]:
                    proprio_attention_tensor = attention[:, self.specific_decoder_token_index, proprio_token_idx]
                else:
                    proprio_attention_tensor = attention[:, :, proprio_token_idx].mean(dim=1)
            else:
                proprio_attention_tensor = attention[:, :, proprio_token_idx].mean(dim=1)

            # Take the first batch element
            proprio_attention = proprio_attention_tensor[0].cpu().numpy().item()

        # --- Step 2: Collect all raw (unnormalized) 2D numpy attention maps ---
        raw_numpy_attention_maps = []
        # Per-image token counts, needed later for reshaping
        tokens_per_image = [h * w for h, w in image_spatial_shapes]

        current_src_token_idx = n_prefix_tokens
        for i, (h_feat, w_feat) in enumerate(image_spatial_shapes):
            if h_feat == 0 or w_feat == 0:
                raw_numpy_attention_maps.append(None)
                if tokens_per_image[i] > 0:  # shape was (0, 0) but tokens_per_image[i] was not 0
                    current_src_token_idx += tokens_per_image[i]
                continue

            num_img_tokens = tokens_per_image[i]
            start_idx = current_src_token_idx
            end_idx = start_idx + num_img_tokens
            current_src_token_idx = end_idx

            attention_to_img_features = attention[:, :, start_idx:end_idx]

            if self.specific_decoder_token_index is not None:
                if not (0 <= self.specific_decoder_token_index < attention_to_img_features.shape[1]):
                    print(f"Warning (map_attention): specific_decoder_token_index {self.specific_decoder_token_index} "
                          f"is out of bounds for actual tgt_len {attention_to_img_features.shape[1]}. "
                          f"Falling back to averaging.")
                    img_attn_tensor_for_map = attention_to_img_features.mean(dim=1)
                else:
                    img_attn_tensor_for_map = attention_to_img_features[:, self.specific_decoder_token_index, :]
            else:
                img_attn_tensor_for_map = attention_to_img_features.mean(dim=1)

            if img_attn_tensor_for_map.shape[0] > 1 and i == 0:  # Print once
                print(f"Warning (map_attention): Batch size is {img_attn_tensor_for_map.shape[0]}. "
                      f"Processing first element for attention map.")

            if img_attn_tensor_for_map.shape[1] != num_img_tokens:
                print(f"Warning (map_attention): Mismatch in token count for image {i}. "
                      f"Expected {num_img_tokens}, got {img_attn_tensor_for_map.shape[1]}. "
                      f"Skipping map for this image.")
                raw_numpy_attention_maps.append(None)
                continue

            try:
                # Get the tensor for the first batch item, still on device
                img_attn_map_1d_tensor = img_attn_tensor_for_map[0]  # [num_img_tokens]
                # Reshape to a 2D tensor
                img_attn_map_2d_tensor = img_attn_map_1d_tensor.reshape(h_feat, w_feat)
                raw_numpy_attention_maps.append(img_attn_map_2d_tensor.cpu().numpy())
            except RuntimeError as e:
                print(f"Error (map_attention): Reshaping attention for image {i}: {e}. "
                      f"Shape was {img_attn_tensor_for_map[0].shape}, target HxW: {h_feat}x{w_feat}. "
                      f"Num tokens: {num_img_tokens}. Skipping.")
                raw_numpy_attention_maps.append(None)
                continue

        # --- Step 3: Find global min and max from all valid raw maps AND proprioception ---
        global_min = float('inf')
        global_max = float('-inf')
        found_any_valid_map = False

        # Include proprioception attention in the global scaling
        if proprio_attention is not None:
            if proprio_attention < global_min:
                global_min = proprio_attention
            if proprio_attention > global_max:
                global_max = proprio_attention
            found_any_valid_map = True

        for raw_map_np in raw_numpy_attention_maps:
            if raw_map_np is not None:
                current_min = raw_map_np.min()
                current_max = raw_map_np.max()
                if current_min < global_min:
                    global_min = current_min
                if current_max > global_max:
                    global_max = current_max
                found_any_valid_map = True

        if not found_any_valid_map:
            # All maps were None; return the list of Nones
            return raw_numpy_attention_maps, 0.0

        # If global_min/global_max are still inf/-inf, all maps were empty or invalid.
        # This case should be covered by found_any_valid_map, but as a safeguard:
        if global_min == float('inf') or global_max == float('-inf'):
            print("Warning (map_attention): Could not determine global min/max for attention. "
                  "All maps might be invalid.")
            # Fallback: return zeroed maps (or None where no map was produced)
            return [np.zeros_like(m, dtype=np.float32) if m is not None else None
                    for m in raw_numpy_attention_maps], 0.0

        # --- Step 4: Normalize proprioception attention ---
        if global_max > global_min:
            normalized_proprio_attention = (proprio_attention - global_min) / (global_max - global_min)
        else:
            normalized_proprio_attention = 0.0

        # --- Step 5: Normalize all valid visual attention maps using the global min/max ---
        final_normalized_attention_maps = []
        for raw_map_np in raw_numpy_attention_maps:
            if raw_map_np is None:
                final_normalized_attention_maps.append(None)
                continue

            if global_max > global_min:
                # Perform normalization
                normalized_map = (raw_map_np - global_min) / (global_max - global_min)
            else:
                # All values across all valid maps are identical (e.g. all zeros), so
                # dividing by (global_max - global_min) would be a division by zero.
                # Make such a map uniform, here all zeros.
                normalized_map = np.zeros_like(raw_map_np, dtype=np.float32)
                # If you prefer a mid-gray for perfectly flat attention:
                # normalized_map = np.full_like(raw_map_np, 0.5, dtype=np.float32)

            final_normalized_attention_maps.append(normalized_map)

        return final_normalized_attention_maps, normalized_proprio_attention
    def visualize_attention(self, images: Optional[List[torch.Tensor]] = None,
                            attention_maps: Optional[List[np.ndarray]] = None,
                            observation: Optional[Dict[str, torch.Tensor]] = None,
                            use_rgb: bool = False,
                            overlay_alpha: float = 0.5,
                            show_proprio_border: bool = True,
                            proprio_border_width: int = 15) -> List[np.ndarray]:
        """
        Create visualizations by overlaying attention maps on images.

        Args:
            images: List of image tensors (optional)
            attention_maps: List of attention maps (optional)
            observation: Observation dict (optional, used if images not provided)
            use_rgb: Whether to use RGB for visualization
            overlay_alpha: Alpha value for the attention overlay
            show_proprio_border: Whether to draw a border encoding proprioception attention
            proprio_border_width: Width of the proprioception border in pixels

        Returns:
            List of visualization images as numpy arrays
        """
        # If no images are provided, use the given observation or the last stored one
        if images is None:
            if observation is not None:
                images = self._extract_images(observation)
            elif self.last_observation is not None:
                images = self._extract_images(self.last_observation)
            else:
                raise ValueError("No images provided and no stored observation available")

        # If no attention maps are provided, use the last computed ones
        if attention_maps is None:
            if self.last_attention_maps is not None:
                attention_maps = self.last_attention_maps
            else:
                raise ValueError("No attention maps provided and no stored attention maps available")

        # Get the proprioception attention value
        proprio_attention = getattr(self, 'last_proprio_attention', 0.0)

        visualizations = []
        for i, (img, attn_map) in enumerate(zip(images, attention_maps)):
            if img is None or attn_map is None:
                visualizations.append(None)
                continue

            # Convert tensor to numpy
            if isinstance(img, torch.Tensor):
                # Move channels to the last dimension (H, W, C) for visualization
                if img.dim() == 4:  # (B, C, H, W)
                    img = img.squeeze(0)
                img_np = img.permute(1, 2, 0).cpu().numpy()
                # Normalize if needed
                if img_np.max() > 1.0:
                    img_np = img_np / 255.0
            else:
                img_np = img

            # Get image dimensions
            h, w = img_np.shape[:2]

            # Resize the attention map to match the image size
            attn_map_resized = cv2.resize(attn_map, (w, h))

            # Create heatmap
            heatmap = cv2.applyColorMap(np.uint8(255 * attn_map_resized), cv2.COLORMAP_JET)
            if use_rgb:
                heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

            # Create overlay with attention
            vis = cv2.addWeighted(
                np.uint8(255 * img_np), 1 - overlay_alpha,
                heatmap, overlay_alpha, 0
            )

            # Add proprioception attention border
            if show_proprio_border and proprio_attention > 0:
                # Convert normalized proprioception attention to color intensity
                border_intensity = int(255 * proprio_attention)

                # Use magenta to distinguish from the visual attention colormap;
                # (intensity, 0, intensity) is magenta in both RGB and BGR.
                border_color = (border_intensity, 0, border_intensity)

                # Draw the border rectangle around the full image
                cv2.rectangle(vis, (0, 0), (w - 1, h - 1), border_color, proprio_border_width)

                # Add a text label showing the proprioception attention value
                text = f"Proprio: {proprio_attention:.3f}"
                font = cv2.FONT_HERSHEY_SIMPLEX
                font_scale = 0.6
                thickness = 2

                # Get text size for the background rectangle
                (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, thickness)

                # Draw background rectangle for the text
                cv2.rectangle(vis, (5, 5), (5 + text_width + 10, 5 + text_height + 10), (0, 0, 0), -1)

                # Draw text
                cv2.putText(vis, text, (10, 5 + text_height), font, font_scale, (255, 255, 255), thickness)

            visualizations.append(vis)

        return visualizations
    # Forward other attribute lookups to the original policy.
    # __getattr__ is only invoked when normal lookup fails, so attributes and
    # methods defined on this wrapper take precedence automatically.
    def __getattr__(self, name):
        return getattr(self.policy, name)
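
# A minimal usage sketch, assuming a LeRobot-style ACTPolicy checkpoint. The import
# path, checkpoint path, and observation keys below are hypothetical and depend on
# how the policy was trained; they only illustrate the wrapper's interface.
if __name__ == "__main__":
    from lerobot.common.policies.act.modeling_act import ACTPolicy  # assumed import path

    policy = ACTPolicy.from_pretrained("outputs/train/act_checkpoint")  # hypothetical path
    wrapped = ACTPolicyWithAttention(policy)

    # Observation keys must match policy.config.image_features and the robot state
    # feature used during training; the keys and shapes here are placeholders.
    observation = {
        "observation.images.top": torch.rand(1, 3, 480, 640),
        "observation.state": torch.rand(1, 14),
    }

    action, attention_maps = wrapped.select_action(observation)
    overlays = wrapped.visualize_attention(use_rgb=True, overlay_alpha=0.5)
    print(action.shape, [o.shape for o in overlays if o is not None])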