src.attention_maps package

Submodules

src.attention_maps.act_attention_mapper module

class src.attention_maps.act_attention_mapper.ACTPolicyWithAttention(policy, image_shapes=None, specific_decoder_token_index: int | None = None)

Bases: object

Wrapper for ACTPolicy that provides transformer attention visualizations.

select_action(observation: Dict[str, Tensor]) → Tuple[Tensor, Tensor, List[ndarray]]

Extends policy.select_action to also compute attention maps.

Args:

observation: Dictionary of observations

Returns:

action: The predicted action tensor.

attention_maps: List of attention maps, one for each image.
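A transformer's cross-attention is defined over a flat sequence of encoder tokens, while select_action returns one map per image. The sketch below shows one plausible way to go from the former to the latter. It is not this package's implementation: the helper attention_to_image_maps is hypothetical, and it assumes that the encoder sequence starts with a fixed number of non-image tokens (here 2, e.g. a latent token and a proprioceptive-state token), followed by each camera's flattened backbone feature map in order, with (h, w) sizes like those presumably passed as image_shapes. The decoder_token_index argument loosely mirrors specific_decoder_token_index, under the assumption that None means averaging over all decoder queries.

```python
import numpy as np


def attention_to_image_maps(attn, image_shapes, n_prefix_tokens=2, decoder_token_index=None):
    """Hypothetical helper: split decoder->encoder attention into per-camera 2D maps.

    attn: (n_decoder_tokens, n_encoder_tokens) weights, e.g. one layer's
        cross-attention averaged over heads (assumed layout).
    image_shapes: list of (h, w) feature-map sizes, one per camera.
    n_prefix_tokens: non-image tokens at the start of the encoder sequence (assumed).
    decoder_token_index: use this decoder query's row; if None, average all rows.
    """
    attn = np.asarray(attn, dtype=np.float64)
    weights = attn.mean(axis=0) if decoder_token_index is None else attn[decoder_token_index]

    expected = n_prefix_tokens + sum(h * w for h, w in image_shapes)
    if weights.shape[0] != expected:
        raise ValueError(f"expected {expected} encoder tokens, got {weights.shape[0]}")

    maps = []
    offset = n_prefix_tokens  # skip the assumed non-image prefix tokens
    for h, w in image_shapes:
        m = weights[offset:offset + h * w].reshape(h, w)
        offset += h * w
        # Min-max normalise each map to [0, 1] so it can be drawn as a heat map
        span = m.max() - m.min()
        maps.append((m - m.min()) / span if span > 0 else np.zeros_like(m))
    return maps


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # 100 decoder queries; 2 prefix tokens + two cameras with 3x4 feature maps
    demo = attention_to_image_maps(rng.random((100, 2 + 2 * 3 * 4)), [(3, 4), (3, 4)])
    print([m.shape for m in demo])  # [(3, 4), (3, 4)]
```

Normalising each map independently makes every camera's map use the full colour range, which is convenient for display but hides how attention is distributed across cameras; normalising over all images jointly would preserve that comparison.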

visualize_attention(images: List[Tensor] | None = None, attention_maps: List[ndarray] | None = None, observation: Dict[str, Tensor] | None = None, use_rgb: bool = False, overlay_alpha: float = 0.5, show_proprio_border: bool = True, proprio_border_width: int = 15) → List[ndarray]

Create visualizations by overlaying attention maps on images.

Args:

images: List of image tensors (optional).

attention_maps: List of attention maps (optional).

observation: Observation dict (optional; used if images is not provided).

use_rgb: Whether to use RGB for visualization.

overlay_alpha: Alpha value for the attention overlay.

Returns:

List of visualization images as numpy arrays
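The overlay step can be sketched with plain NumPy. This is an illustration under stated assumptions, not the package's code: the helper overlay_attention is hypothetical, a simple red-channel heat map stands in for a real colormap (such as OpenCV's cv2.applyColorMap), the image size is assumed to be an integer multiple of the attention-map size so nearest-neighbour upsampling is exact, and overlay_alpha is assumed to weight the heat map linearly against the original image. The optional frame is also a guess at what show_proprio_border and proprio_border_width might draw; border_value is a hypothetical scalar (for example, attention on a proprioceptive token) rendered as the frame's red intensity.

```python
import numpy as np


def overlay_attention(image, attn_map, overlay_alpha=0.5, border_value=None, border_width=15):
    """Hypothetical helper: blend a low-resolution attention map onto an RGB image.

    image: uint8 array of shape (H, W, 3).
    attn_map: float array of shape (h, w) with values in [0, 1]; H and W are
        assumed to be integer multiples of h and w.
    border_value: optional scalar in [0, 1] drawn as a red frame (assumed semantics).
    """
    attn_map = np.asarray(attn_map, dtype=np.float64)
    H, W, _ = image.shape
    h, w = attn_map.shape
    if H % h or W % w:
        raise ValueError("image size must be a multiple of the attention map size")

    # Nearest-neighbour upsampling: each map cell fills an (H/h) x (W/w) pixel block
    up = np.kron(attn_map, np.ones((H // h, W // w)))

    # Stand-in colormap: attention intensity goes into the red channel only
    heat = np.zeros((H, W, 3))
    heat[..., 0] = up * 255.0

    # Linear alpha blend of heat map over the original image
    out = (1.0 - overlay_alpha) * image.astype(np.float64) + overlay_alpha * heat

    if border_value is not None and border_width > 0:
        b = border_width
        colour = np.array([255.0 * border_value, 0.0, 0.0])
        out[:b, :] = colour   # top edge
        out[-b:, :] = colour  # bottom edge
        out[:, :b] = colour   # left edge
        out[:, -b:] = colour  # right edge

    return np.clip(out, 0, 255).astype(np.uint8)


if __name__ == "__main__":
    img = np.full((96, 128, 3), 200, dtype=np.uint8)
    attn = np.random.default_rng(0).random((3, 4))
    vis = overlay_attention(img, attn, overlay_alpha=0.5, border_value=0.8, border_width=15)
    print(vis.shape, vis.dtype)  # (96, 128, 3) uint8
```

Drawing the border after blending keeps its colour independent of overlay_alpha, so a border intensity reads the same on every camera view regardless of how strongly the heat map is blended.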

Module contents