src.attention_maps package
Submodules
src.attention_maps.act_attention_mapper module
- class src.attention_maps.act_attention_mapper.ACTPolicyWithAttention(policy, image_shapes=None, specific_decoder_token_index: int | None = None)
Bases: object
Wrapper for ACTPolicy that provides transformer attention visualizations.
- select_action(observation: Dict[str, Tensor]) → Tuple[Tensor, Tensor, List[ndarray]]
Extends policy.select_action to also compute attention maps.
- Args:
observation: Dictionary of observations
- Returns:
action: The predicted action tensor
attention_maps: List of attention maps, one for each image
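One way per-image maps like these could be derived from transformer attention is sketched below. This is a hypothetical illustration, not this module's implementation. It assumes the attention weights arrive as an (n_decoder_tokens, n_encoder_tokens) array whose first n_prefix_tokens entries are non-image tokens (for example latent and proprioceptive tokens), followed by each camera's flattened feature map in camera order. The helper name split_attention_per_camera and its feature_shapes and n_prefix_tokens arguments are invented for this sketch; averaging over all decoder queries when no index is given is only one plausible reading of specific_decoder_token_index.

```python
from typing import List, Optional, Sequence, Tuple

import numpy as np


def split_attention_per_camera(
    attn: np.ndarray,
    feature_shapes: Sequence[Tuple[int, int]],
    n_prefix_tokens: int,
    decoder_token_index: Optional[int] = None,
) -> Tuple[np.ndarray, List[np.ndarray]]:
    """Hypothetical helper: split encoder-token attention into per-camera 2D maps.

    attn: (n_decoder_tokens, n_encoder_tokens) attention weights (assumed layout).
    feature_shapes: (h, w) of each camera's feature map, in token order (assumed).
    n_prefix_tokens: number of non-image tokens preceding the image tokens (assumed).
    """
    if decoder_token_index is None:
        # Average the attention of every decoder query.
        weights = attn.mean(axis=0)
    else:
        # Use the attention of a single decoder query.
        weights = attn[decoder_token_index]

    expected = n_prefix_tokens + sum(h * w for h, w in feature_shapes)
    if weights.shape[0] != expected:
        raise ValueError(f"expected {expected} encoder tokens, got {weights.shape[0]}")

    prefix = weights[:n_prefix_tokens]  # weights on the non-image tokens
    maps = []
    offset = n_prefix_tokens
    for h, w in feature_shapes:
        # Image tokens are assumed to be a row-major flattening of an (h, w) grid.
        maps.append(weights[offset:offset + h * w].reshape(h, w))
        offset += h * w
    return prefix, maps
```

For instance, with two prefix tokens and two cameras whose feature maps are 15 x 20, attn would need 2 + 2 * 300 columns, and the call returns the two prefix weights plus two (15, 20) maps.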
- visualize_attention(images: List[Tensor] | None = None, attention_maps: List[ndarray] | None = None, observation: Dict[str, Tensor] | None = None, use_rgb: bool = False, overlay_alpha: float = 0.5, show_proprio_border: bool = True, proprio_border_width: int = 15) → List[ndarray]
Create visualizations by overlaying attention maps on images.
- Args:
images: List of image tensors (optional)
attention_maps: List of attention maps (optional)
observation: Observation dict (optional; used if images is not provided)
use_rgb: Whether to use RGB for visualization
overlay_alpha: Alpha value for attention overlay
- Returns:
List of visualization images as numpy arrays
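The overlay step can be illustrated with plain NumPy. This is a minimal sketch, not this method's actual code. It assumes each image is an (H, W, 3) float array in [0, 1] in RGB order, that each attention map's size divides the image size evenly (it upsamples by pixel repetition rather than interpolation), and that use_rgb selects RGB rather than OpenCV-style BGR output, as in some Grad-CAM style helpers. The function name overlay_attention and its simple blue-to-red colour ramp are hypothetical stand-ins for whatever colormap the real implementation uses; the proprioceptive border options are not modelled.

```python
import numpy as np


def overlay_attention(
    image: np.ndarray,
    attention_map: np.ndarray,
    overlay_alpha: float = 0.5,
    use_rgb: bool = False,
) -> np.ndarray:
    """Hypothetical sketch: alpha-blend a heat-coloured attention map onto an image.

    image: (H, W, 3) float array in [0, 1], RGB order (assumed).
    attention_map: (h, w) non-negative weights with H % h == 0 and W % w == 0 (assumed).
    Returns an (H, W, 3) uint8 array, RGB order if use_rgb else BGR.
    """
    H, W, _ = image.shape
    h, w = attention_map.shape
    if H % h or W % w:
        raise ValueError("image size must be an integer multiple of the attention map size")

    # Nearest-neighbour upsampling by repeating each attention cell.
    attn = np.repeat(np.repeat(attention_map, H // h, axis=0), W // w, axis=1)

    # Min-max normalise to [0, 1]; a constant map becomes all zeros.
    attn = attn - attn.min()
    peak = attn.max()
    if peak > 0:
        attn = attn / peak

    # Simple colour ramp in RGB: low attention -> blue, high attention -> red.
    heat = np.stack([attn, np.zeros_like(attn), 1.0 - attn], axis=-1)

    blended = (1.0 - overlay_alpha) * image + overlay_alpha * heat
    if not use_rgb:
        blended = blended[..., ::-1]  # reorder channels RGB -> BGR
    return (np.clip(blended, 0.0, 1.0) * 255).astype(np.uint8)
```

Calling the helper once per (image, attention map) pair yields a list of uint8 arrays comparable to this method's return value. Pixel repetition keeps the sketch dependency-free; a production version would more likely resize with bilinear interpolation.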