Examples¶

This page provides practical code examples for using SpaRRTa in your research.

Quick Examples¶

Evaluate a Single Model¶

from sparrta import SpaRRTaEvaluator, load_vfm

# Load model
model = load_vfm("dinov2_vitb14")

# Create evaluator
evaluator = SpaRRTaEvaluator(
    data_path="data/sparrta",
    probe_type="efficient",
)

# Run evaluation
results = evaluator.evaluate(
    model=model,
    environments=["forest"],
    tasks=["ego"],
)

print(f"Forest Ego Accuracy: {results['forest']['ego']['accuracy']:.2f}%")

Compare Multiple Models¶

from sparrta import SpaRRTaEvaluator, load_vfm
import pandas as pd

models = [
    "dinov2_vitb14",
    "dinov2_vitl14", 
    "clip_vitb16",
    "mae_vitb16",
]

evaluator = SpaRRTaEvaluator(data_path="data/sparrta")
results = {}

for model_name in models:
    model = load_vfm(model_name)
    results[model_name] = evaluator.evaluate(
        model=model,
        environments="all",
        tasks=["ego", "allo"],
    )

# Create comparison table
df = pd.DataFrame({
    name: {
        "Ego": r.mean_accuracy("ego"),
        "Allo": r.mean_accuracy("allo"),
    }
    for name, r in results.items()
}).T

print(df)
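
To keep the comparison around, the pandas table can be written straight to disk or rendered as Markdown for docs (standard pandas API; to_markdown needs the tabulate package installed):

# Save the comparison table for later use
df.to_csv("model_comparison.csv")

# Or render it as a Markdown table for documentation
print(df.to_markdown(floatfmt=".2f"))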

Dataset Examples¶

Load and Visualize Samples¶

import matplotlib.pyplot as plt
from sparrta import SpaRRTaDataset

dataset = SpaRRTaDataset(
    data_path="data/sparrta",
    environment="forest",
    task="ego",
    split="test",
)

# Visualize samples
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for idx, ax in enumerate(axes.flat):
    image, label, metadata = dataset[idx]

    # Convert tensor to numpy
    img_np = image.permute(1, 2, 0).numpy()

    ax.imshow(img_np)
    ax.set_title(f"{metadata['source_object']['class']} → {metadata['target_object']['class']}\nLabel: {['Front', 'Back', 'Left', 'Right'][label]}")
    ax.axis("off")

plt.tight_layout()
plt.savefig("sample_images.png", dpi=150)
plt.show()

Custom Data Filtering¶

from sparrta import SpaRRTaDataset

# Filter by specific object combinations
dataset = SpaRRTaDataset(
    data_path="data/sparrta",
    environment="desert",
    task="allo",
    object_triples=[
        ("cactus", "camel", "human"),
        ("rock", "barrel", "human"),
    ],
)

print(f"Filtered dataset size: {len(dataset)}")

# Filter by label
left_samples = [
    (img, label, meta)
    for img, label, meta in dataset
    if label == 2  # Left
]

print(f"Samples with 'Left' label: {len(left_samples)}")

Training Examples¶

Custom Training Loop¶

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from sparrta import SpaRRTaDataset, load_vfm
from sparrta.probes import EfficientProbe
from tqdm import tqdm

# Configuration
config = {
    "model": "dinov2_vitb14",
    "environment": "forest",
    "task": "ego",
    "batch_size": 256,
    "lr": 1e-3,
    "epochs": 100,
    "num_queries": 4,
}

# Setup data
train_dataset = SpaRRTaDataset("data/sparrta", config["environment"], config["task"], "train")
val_dataset = SpaRRTaDataset("data/sparrta", config["environment"], config["task"], "val")

train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)

# Setup model and probe
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_vfm(config["model"]).to(device).eval()
probe = EfficientProbe(
    input_dim=model.feature_dim,
    num_classes=4,
    num_queries=config["num_queries"],
).to(device)

# Setup training
optimizer = AdamW(probe.parameters(), lr=config["lr"], weight_decay=1e-3)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)
criterion = nn.CrossEntropyLoss()

# Training loop
best_val_acc = 0

for epoch in range(config["epochs"]):
    # Train
    probe.train()
    train_loss = 0

    for images, labels, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            features = model.extract_features(images)

        logits, _ = probe(features)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    scheduler.step()

    # Validate
    probe.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for images, labels, _ in val_loader:
            images, labels = images.to(device), labels.to(device)
            features = model.extract_features(images)
            logits, _ = probe(features)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += len(labels)

    val_acc = 100 * correct / total

    print(f"Epoch {epoch+1}: Train Loss = {train_loss/len(train_loader):.4f}, Val Acc = {val_acc:.2f}%")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(probe.state_dict(), "best_probe.pth")

print(f"Best Val Accuracy: {best_val_acc:.2f}%")

Training with Weights & Biases Logging¶

import wandb
import torch
from sparrta import SpaRRTaEvaluator, load_vfm

# Initialize W&B
wandb.init(
    project="sparrta-evaluation",
    config={
        "model": "dinov2_vitl14",
        "probe": "efficient",
        "environments": ["forest", "desert", "city"],
    }
)

model = load_vfm(wandb.config.model)

evaluator = SpaRRTaEvaluator(
    data_path="data/sparrta",
    probe_type=wandb.config.probe,
    wandb_logging=True,  # Enable W&B logging
)

results = evaluator.evaluate(
    model=model,
    environments=wandb.config.environments,
    tasks=["ego", "allo"],
)

# Log final results
wandb.log({
    "mean_ego_accuracy": results.mean_accuracy("ego"),
    "mean_allo_accuracy": results.mean_accuracy("allo"),
})
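
# Optionally, log a per-environment breakdown as a W&B table
# (a sketch assuming the nested-dict access shown in the first example on this page)
table = wandb.Table(columns=["environment", "ego_acc", "allo_acc"])
for env in wandb.config.environments:
    table.add_data(
        env,
        results[env]["ego"]["accuracy"],
        results[env]["allo"]["accuracy"],
    )
wandb.log({"per_environment_results": table})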

wandb.finish()

Visualization Examples¶

Attention Map Visualization¶

import cv2
import torch
import matplotlib.pyplot as plt
import numpy as np
from sparrta import load_vfm, SpaRRTaDataset
from sparrta.probes import EfficientProbe

# Load model and probe
model = load_vfm("dinov2_vitb14").cuda().eval()
probe = EfficientProbe(input_dim=768, num_classes=4, num_queries=4).cuda()
probe.load_state_dict(torch.load("best_probe.pth"))
probe.eval()

# Load sample image
dataset = SpaRRTaDataset("data/sparrta", "forest", "ego", "test")
image, label, metadata = dataset[0]

# Get attention maps
with torch.no_grad():
    features = model.extract_features(image.unsqueeze(0).cuda())
    logits, attention = probe(features)  # attention: [1, 4, 196]

# Visualize
fig, axes = plt.subplots(1, 5, figsize=(20, 4))

# Original image
axes[0].imshow(image.permute(1, 2, 0).numpy())
axes[0].set_title("Original Image")
axes[0].axis("off")

# Attention maps for each query
attention = attention[0].cpu().numpy()  # [4, 196]

for i in range(4):
    attn_map = attention[i].reshape(14, 14)  # Reshape to spatial grid
    attn_map = np.clip(attn_map, 0, 1)

    # Upsample the 14×14 attention map to the 224×224 image size
    attn_resized = cv2.resize(attn_map, (224, 224))

    axes[i+1].imshow(image.permute(1, 2, 0).numpy())
    axes[i+1].imshow(attn_resized, alpha=0.6, cmap="jet")
    axes[i+1].set_title(f"Query {i+1}")
    axes[i+1].axis("off")

plt.tight_layout()
plt.savefig("attention_visualization.png", dpi=150)
plt.show()

Results Comparison Plot¶

import matplotlib.pyplot as plt
import numpy as np

# Data from evaluation
models = ["DINO", "DINO-v2", "DINOv3", "VGGT", "MAE", "CLIP"]
ego_acc = [89.28, 91.91, 93.93, 96.18, 93.10, 56.33]
allo_acc = [64.38, 71.20, 72.05, 76.65, 70.66, 54.36]

x = np.arange(len(models))
width = 0.35

fig, ax = plt.subplots(figsize=(12, 6))

bars1 = ax.bar(x - width/2, ego_acc, width, label='Egocentric', color='#7c4dff')
bars2 = ax.bar(x + width/2, allo_acc, width, label='Allocentric', color='#536dfe')

ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('SpaRRTa Performance by Model (Efficient Probing)', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(models, fontsize=11)
ax.legend()
ax.set_ylim(0, 100)

# Add value labels above each bar
for bars in (bars1, bars2):
    for bar in bars:
        height = bar.get_height()
        ax.annotate(f'{height:.1f}',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig("model_comparison.png", dpi=150)
plt.show()

Video Demo Integration¶

Embedding YouTube Video¶

To showcase video demos on your documentation pages, embed the video in a responsive wrapper and replace VIDEO_ID with the YouTube video ID. The video-container class assumes a matching CSS rule in your theme:

<div class="video-container">
  <iframe 
    src="https://www.youtube.com/embed/VIDEO_ID" 
    title="SpaRRTa Demo"
    allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
    allowfullscreen>
  </iframe>
</div>

Advanced Examples¶

Custom Model Evaluation¶

import torch
import torch.nn as nn
from sparrta.models import register_model, VFMBase

@register_model("my_vit")
class MyViT(VFMBase):
    """Custom Vision Transformer wrapper."""

    def __init__(self, checkpoint_path=None):
        super().__init__()

        # Load your custom model
        from timm import create_model
        self.model = create_model(
            "vit_base_patch16_224",
            pretrained=True,
            num_classes=0,  # Remove classification head
        )

        if checkpoint_path:
            # Load to CPU first so the wrapper works regardless of save device
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.model.load_state_dict(state_dict)

    def extract_features(self, images):
        """Extract patch features."""
        # Get patch embeddings from ViT
        x = self.model.patch_embed(images)
        x = self.model._pos_embed(x)
        x = self.model.blocks(x)
        x = self.model.norm(x)

        # Return patch tokens (excluding CLS)
        return x[:, 1:, :]  # [B, 196, 768]

    @property
    def feature_dim(self):
        return 768

    @property
    def num_patches(self):
        return 196

# Use custom model
from sparrta import SpaRRTaEvaluator, load_vfm

model = load_vfm("my_vit", checkpoint_path="my_weights.pth")
evaluator = SpaRRTaEvaluator(data_path="data/sparrta")
results = evaluator.evaluate(model)
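
Before launching a full evaluation, it can be worth sanity-checking that the wrapper emits patch tokens of the shape the probes expect. A quick check under the assumptions above (224×224 inputs, 16×16 patches, hence 196 tokens of dimension 768):

# Sanity-check the wrapper's output shape before a full run
dummy = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    feats = model.extract_features(dummy)

assert feats.shape == (2, model.num_patches, model.feature_dim), feats.shape
print(f"OK: patch features of shape {tuple(feats.shape)}")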

Multi-GPU Evaluation¶

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from sparrta import SpaRRTaEvaluator, load_vfm
import os

def setup_distributed():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank

def main():
    local_rank = setup_distributed()

    model = load_vfm("dinov2_vitl14").cuda()

    evaluator = SpaRRTaEvaluator(
        data_path="data/sparrta",
        distributed=True,
        local_rank=local_rank,
    )

    results = evaluator.evaluate(
        model=model,
        environments="all",
    )

    if local_rank == 0:
        print(results.summary())
        results.save("distributed_results.json")

if __name__ == "__main__":
    main()

Run with:

torchrun --nproc_per_node=4 distributed_eval.py

Jupyter Notebook¶

For interactive exploration, the examples below can be run directly in a Jupyter notebook:

# In Jupyter
%load_ext autoreload
%autoreload 2

from sparrta import SpaRRTaDataset, SpaRRTaEvaluator, load_vfm
import matplotlib.pyplot as plt

# Interactive visualization
dataset = SpaRRTaDataset("data/sparrta", "forest", "ego", "test")

# Use ipywidgets for interactive exploration
from ipywidgets import interact, IntSlider

@interact(idx=IntSlider(min=0, max=len(dataset)-1, step=1))
def show_sample(idx):
    image, label, meta = dataset[idx]
    plt.figure(figsize=(8, 8))
    plt.imshow(image.permute(1, 2, 0).numpy())
    plt.title(f"Source: {meta['source_object']['class']}\n"
              f"Target: {meta['target_object']['class']}\n"
              f"Label: {['Front', 'Back', 'Left', 'Right'][label]}")
    plt.axis("off")
    plt.show()