⏱️ 65 min

Explainable AI & Model Interpretation

Understand and explain AI model decisions

Why Explainability Matters

Deep learning models are often "black boxes": we can see their predictions but not the reasoning behind them.

**Why We Need Explainability:**

1. **Trust**: Users need to trust AI decisions
2. **Debugging**: Find model errors and biases
3. **Compliance**: Legal requirements (GDPR, etc.)
4. **Safety**: Critical in healthcare, finance, and autonomous vehicles
5. **Scientific Discovery**: Learn from what models learn

**Types of Explanations:**

**Global Explanations**: How does the model work overall?
- Feature importance
- Model behavior patterns
- Decision rules

**Local Explanations**: Why this specific prediction?
- Which features influenced this decision?
- What would need to change to get a different prediction?

**Interpretability vs Explainability:**
- **Interpretable**: Inherently understandable (linear regression, decision trees)
- **Explainable**: Explained post hoc (neural networks with LIME or SHAP)
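The interpretable/explainable split is easy to see in a toy example: for a linear model the learned weights *are* the explanation, with no post-hoc method needed. A minimal numpy sketch (the weights and input here are hypothetical, not from any trained model):

```python
import numpy as np

# An inherently interpretable model: the prediction is a weighted sum,
# so each feature's contribution is simply weight * value.
weights = np.array([0.8, -0.5, 0.1])   # hypothetical learned coefficients
bias = 0.2
x = np.array([1.0, 2.0, -1.0])         # one input instance

contributions = weights * x            # per-feature contribution to the score
prediction = contributions.sum() + bias

print("prediction:", prediction)       # 0.8 - 1.0 - 0.1 + 0.2 = -0.1
for i, c in enumerate(contributions):
    print(f"Feature_{i}: {c:+.2f}")
```

For a neural network there is no such direct readout, which is why the post-hoc methods below exist.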

Feature Importance & SHAP

Explain model predictions with SHAP values:

python
import numpy as np
import torch
import torch.nn as nn
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Create synthetic dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5, 
                          n_redundant=2, random_state=42)
feature_names = [f'Feature_{i}' for i in range(10)]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Simple neural network
class SimpleNN(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 2)
        )
    
    def forward(self, x):
        return self.network(x)

# Train model
model = SimpleNN(10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

X_train_t = torch.FloatTensor(X_train)
y_train_t = torch.LongTensor(y_train)

print("Training model...")
for epoch in range(50):
    optimizer.zero_grad()
    outputs = model(X_train_t)
    loss = criterion(outputs, y_train_t)
    loss.backward()
    optimizer.step()

print(f"Training complete. Final loss: {loss.item():.4f}")

# Simple feature importance (gradient-based)
def compute_feature_importance(model, X):
    """Compute feature importance using gradients"""
    model.eval()
    X_tensor = torch.FloatTensor(X).requires_grad_(True)
    
    outputs = model(X_tensor)
    # Sum the class-1 logits so one backward pass covers the whole batch
    class_score = outputs[:, 1].sum()
    
    # Backward pass
    class_score.backward()
    
    # Importance = absolute gradient per feature, averaged over the batch
    importance = torch.abs(X_tensor.grad).mean(dim=0).detach().numpy()
    return importance

# Compute importance
importance = compute_feature_importance(model, X_test)

print(f"\nFeature Importance (Gradient-based):")
sorted_idx = np.argsort(importance)[::-1]
for idx in sorted_idx[:5]:
    print(f"  {feature_names[idx]}: {importance[idx]:.4f}")

# SHAP-style explanation (simplified occlusion, not true Shapley values)
class SimpleSHAP:
    """Simplified SHAP-style explainer.

    Occludes one feature at a time with its baseline value; unlike true
    SHAP it ignores feature interactions and coalition averaging.
    """
    def __init__(self, model, background_data):
        self.model = model
        self.background_data = background_data
        self.baseline = background_data.mean(axis=0)
    
    def explain(self, instance):
        """Explain single prediction"""
        self.model.eval()
        
        # Baseline prediction (raw class-1 logit, not a probability)
        baseline_t = torch.FloatTensor(self.baseline).unsqueeze(0)
        with torch.no_grad():
            baseline_pred = self.model(baseline_t)[0, 1].item()
        
        # Instance prediction
        instance_t = torch.FloatTensor(instance).unsqueeze(0)
        with torch.no_grad():
            instance_pred = self.model(instance_t)[0, 1].item()
        
        # Approximate per-feature attributions by single-feature occlusion
        shap_values = np.zeros(len(instance))
        
        for i in range(len(instance)):
            # Create modified instance with feature i at baseline
            modified = instance.copy()
            modified[i] = self.baseline[i]
            modified_t = torch.FloatTensor(modified).unsqueeze(0)
            
            with torch.no_grad():
                modified_pred = self.model(modified_t)[0, 1].item()
            
            # Attribution = change in the class-1 logit when feature i is occluded
            shap_values[i] = instance_pred - modified_pred
        
        return shap_values, instance_pred, baseline_pred

# Explain a test instance
explainer = SimpleSHAP(model, X_train)
test_instance = X_test[0]

shap_values, pred, baseline = explainer.explain(test_instance)

print(f"\nSHAP Explanation for test instance:")
print(f"Baseline prediction: {baseline:.4f}")
print(f"Instance prediction: {pred:.4f}")
print(f"Prediction change: {pred - baseline:.4f}")

print(f"\nTop contributing features:")
sorted_idx = np.argsort(np.abs(shap_values))[::-1]
for idx in sorted_idx[:5]:
    direction = "increases" if shap_values[idx] > 0 else "decreases"
    print(f"  {feature_names[idx]}: {direction} prediction by {abs(shap_values[idx]):.4f}")
    print(f"    Value: {test_instance[idx]:.3f}")
Output:
Training model...
Training complete. Final loss: 0.3245

Feature Importance (Gradient-based):
  Feature_2: 0.0823
  Feature_1: 0.0756
  Feature_4: 0.0691
  Feature_3: 0.0534
  Feature_0: 0.0478

SHAP Explanation for test instance:
Baseline prediction: 0.4523
Instance prediction: 0.8934
Prediction change: 0.4411

Top contributing features:
  Feature_2: increases prediction by 0.1823
    Value: 1.456
  Feature_1: increases prediction by 0.1534
    Value: 0.923
  Feature_4: increases prediction by 0.0891
    Value: -0.234
  Feature_7: decreases prediction by 0.0267
    Value: -1.123
  Feature_3: increases prediction by 0.0198
    Value: 0.567
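The single-feature occlusion above is only a rough approximation: true SHAP values average a feature's marginal contribution over *all* coalitions of the other features, which matters when features interact. For a model small enough to enumerate, the exact computation looks like this (a sketch using a hypothetical 3-feature toy model; for real models use the `shap` library, which approximates this efficiently):

```python
import itertools
import math
import numpy as np

def exact_shapley(f, x, baseline):
    """Exact Shapley values by enumerating all coalitions of features.

    f: model mapping a 1-D array to a scalar score
    x: instance to explain; baseline: reference values for "absent" features
    """
    n = len(x)
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for S in itertools.combinations(others, size):
                # Shapley coalition weight: |S|! (n - |S| - 1)! / n!
                w = math.factorial(len(S)) * math.factorial(n - len(S) - 1) / math.factorial(n)
                with_i = baseline.copy()
                without_i = baseline.copy()
                for j in S:
                    with_i[j] = x[j]
                    without_i[j] = x[j]
                with_i[i] = x[i]
                phi[i] += w * (f(with_i) - f(without_i))
    return phi

# Hypothetical model with an interaction term (where occlusion alone misleads)
f = lambda z: 2 * z[0] + z[1] * z[2]
x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros(3)

phi = exact_shapley(f, x, baseline)
print("Shapley values:", phi)
# Efficiency property: the attributions sum to f(x) - f(baseline)
print("sum(phi):", phi.sum(), "f(x) - f(baseline):", f(x) - f(baseline))
```

Here the interaction credit z[1] * z[2] = 6 is split evenly between features 1 and 2 (3 each), something the one-at-a-time occlusion in `SimpleSHAP` cannot do.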

Grad-CAM for CNN Visualization

Visualize which parts of an image a CNN focuses on:

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 10)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class GradCAM:
    """Gradient-weighted Class Activation Mapping"""
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        
        # Register hooks to capture activations and gradients at the target layer
        self.target_layer.register_forward_hook(self.save_activation)
        self.target_layer.register_full_backward_hook(self.save_gradient)
    
    def save_activation(self, module, input, output):
        self.activations = output.detach()
    
    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()
    
    def generate_cam(self, input_image, target_class):
        """Generate Class Activation Map"""
        # Forward pass
        self.model.eval()
        output = self.model(input_image)
        
        # Zero gradients
        self.model.zero_grad()
        
        # Backward pass for target class
        target = output[0, target_class]
        target.backward()
        
        # Get gradients and activations
        gradients = self.gradients[0]  # (C, H, W)
        activations = self.activations[0]  # (C, H, W)
        
        # Global average pooling of gradients
        weights = gradients.mean(dim=(1, 2))  # (C,)
        
        # Weighted combination of activation maps
        cam = torch.zeros(activations.shape[1:], dtype=torch.float32)
        for i, w in enumerate(weights):
            cam += w * activations[i]
        
        # ReLU (only positive influences)
        cam = F.relu(cam)
        
        # Normalize to [0, 1] (epsilon guards against an all-zero map)
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        
        return cam.numpy()

# Create model
model = SimpleCNN()
print("CNN Model for Grad-CAM")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Create Grad-CAM
gradcam = GradCAM(model, model.conv3)

# Dummy input (simulating an image)
dummy_image = torch.randn(1, 3, 32, 32)

# Generate CAM for class 5
target_class = 5
cam = gradcam.generate_cam(dummy_image, target_class)

print(f"\nGenerated Grad-CAM for class {target_class}")
print(f"CAM shape: {cam.shape}")
print(f"CAM value range: [{cam.min():.3f}, {cam.max():.3f}]")

# Visualize important regions (simulate)
print(f"\nHigh-attention regions (value > 0.7):")
high_attention = np.where(cam > 0.7)
if len(high_attention[0]) > 0:
    print(f"  {len(high_attention[0])} pixels with high attention")
    print(f"  Centered around: row {high_attention[0].mean():.1f}, col {high_attention[1].mean():.1f}")
else:
    print(f"  Attention more distributed")

print(f"\n--- Grad-CAM Benefits ---")
print(f"✓ Visual explanation of CNN decisions")
print(f"✓ Identify which regions influence prediction")
print(f"✓ Debug model focus (e.g., looking at background vs object)")
print(f"✓ Works with any CNN architecture")
Output:
CNN Model for Grad-CAM
Total parameters: 1,394,762

Generated Grad-CAM for class 5
CAM shape: (4, 4)
CAM value range: [0.000, 1.000]

High-attention regions (value > 0.7):
  3 pixels with high attention
  Centered around: row 1.7, col 2.3

--- Grad-CAM Benefits ---
✓ Visual explanation of CNN decisions
✓ Identify which regions influence prediction
✓ Debug model focus (e.g., looking at background vs object)
✓ Works with any CNN architecture
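Note that the 4x4 CAM has the resolution of the last conv feature map, not the input image; to overlay it on the 32x32 input it must be upsampled. A minimal nearest-neighbour sketch with `np.repeat` (real pipelines typically use bilinear resizing, e.g. `torch.nn.functional.interpolate` or `cv2.resize`; the CAM values below are illustrative):

```python
import numpy as np

def upsample_cam(cam, target_size):
    """Nearest-neighbour upsampling of a CAM to the input resolution."""
    h, w = cam.shape
    # Replicate each cell: a 4x4 map covers a 32x32 image with factor 8 per axis
    return np.repeat(np.repeat(cam, target_size // h, axis=0),
                     target_size // w, axis=1)

cam = np.array([[0.0, 0.2, 0.1, 0.0],
                [0.1, 0.9, 1.0, 0.3],
                [0.0, 0.6, 0.8, 0.2],
                [0.0, 0.1, 0.2, 0.0]])

cam_up = upsample_cam(cam, 32)
print(cam_up.shape)   # (32, 32)
print(cam_up.max())   # 1.0 -- values are preserved, only replicated
```

Each CAM cell then maps onto an 8x8 patch of the input, which is what the familiar Grad-CAM heatmap overlays visualize.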

Explainability Tools & Best Practices

**Popular XAI Libraries:**

| Tool | Best For | Methods |
|------|----------|---------|
| **SHAP** | Any model type | Shapley values, unified framework |
| **LIME** | Local explanations | Perturb inputs, fit linear model |
| **Captum** | PyTorch models | Integrated gradients, Grad-CAM |
| **InterpretML** | Glass-box & black-box models | EBMs, plus SHAP/LIME wrappers (from Microsoft) |
| **Alibi** | Production ML | Counterfactuals, anchors |

**Explanation Methods:**

**Feature Attribution:**
- **SHAP**: Game theory-based, consistent
- **Integrated Gradients**: Path integral from a baseline
- **Layer-wise Relevance Propagation (LRP)**: Backpropagate relevance

**Example-Based:**
- **Prototypes**: Representative examples
- **Counterfactuals**: "What would need to change?"
- **Influence Functions**: Trace a prediction back to influential training points

**Model-Specific:**
- **Attention Weights**: For transformers
- **Grad-CAM**: For CNNs
- **Decision Rules**: For tree ensembles

**Best Practices:**
1. **Multiple Methods**: Use complementary techniques
2. **Sanity Checks**: Verify explanations make sense
3. **User Studies**: Test with domain experts
4. **Document Limitations**: Explanations can be misleading
5. **Actionability**: Provide actionable insights

**Pitfalls to Avoid:**
- ❌ Over-trusting explanations (they can be manipulated)
- ❌ Explaining wrong predictions (fix the model first)
- ❌ One-size-fits-all (different stakeholders need different explanations)
- ❌ Ignoring context (explanations need domain knowledge)
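One concrete way to apply the "multiple methods" and "sanity checks" advice: permutation importance is model-agnostic and makes a good cross-check for gradient or SHAP attributions, since shuffling a truly important feature should hurt accuracy while shuffling a noise feature should not. A minimal numpy sketch (the `predict` function here is a hypothetical stand-in for a trained model, chosen so the ground-truth importances are known):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: only the first two of four features matter
X = rng.normal(size=(500, 4))
y = (2 * X[:, 0] - 3 * X[:, 1] > 0).astype(int)

# Hypothetical "model" matching the data-generating rule
def predict(X):
    return (2 * X[:, 0] - 3 * X[:, 1] > 0).astype(int)

def permutation_importance(predict, X, y, n_repeats=10):
    """Mean accuracy drop when each feature column is shuffled."""
    base_acc = (predict(X) == y).mean()
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        drops = []
        for _ in range(n_repeats):
            Xp = X.copy()
            Xp[:, j] = rng.permutation(Xp[:, j])  # break the feature-target link
            drops.append(base_acc - (predict(Xp) == y).mean())
        importances[j] = np.mean(drops)
    return importances

imp = permutation_importance(predict, X, y)
print(imp)  # large drops for features 0 and 1, ~zero for features 2 and 3
```

If a gradient or SHAP attribution ranks the noise features highly while permutation importance gives them zero, that disagreement is exactly the kind of red flag the sanity-check practice is meant to catch.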
