Deep Learning & XAI Academic Research ICLR 2025

XAIguiFormer: Explainable AI Guided Transformer for Brain Disorder Classification

ML Reproducibility Challenge entry for an ICLR 2025 paper. A novel architecture that leverages explainable AI not just for interpretability, but as an active component that enhances transformer performance in EEG-based brain disorder classification through connectome analysis.

1. Context & Research Challenge

This project is part of the ML Reproducibility Challenge, reproducing and validating the ICLR 2025 paper "XAIguiFormer: Explainable Artificial Intelligence Guided Transformer for Brain Disorder Identification" by Guo et al. The challenge focuses on EEG-based classification of brain disorders (ADHD, MDD, OCD) using a novel transformer architecture that integrates XAI as a performance enhancer rather than just an interpretability tool.

Research Objectives

  • Reproducibility Validation: Implement and verify the original paper's claims on TDBRAIN dataset (88 patients, 4 diagnostic categories)
  • XAI-Guided Learning: Leverage DeepLIFT explanations to refine self-attention mechanisms during training
  • Connectome Preservation: Maintain graph topological structure through atomic tokenization of frequency bands
  • Performance Enhancement: Achieve ~66% balanced accuracy with XAI guidance vs ~63% baseline

2. Core Architectural Innovations

XAIguiFormer introduces four architectural innovations that repurpose explainable AI from a passive, post-hoc interpretability tool into an active component of training.

1. Connectome Tokenizer

  • Atomic Graph Tokens: Treats single-band graphs as indivisible units to preserve topological structure
  • GNN Encoding: GINEConv-based encoder with inductive bias for small datasets
  • Frequency Sequences: Generates temporal sequences without fragmenting connectivity patterns
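The atomic-token idea above can be sketched as follows. This is an illustrative simplification, not the paper's implementation: the original uses PyTorch Geometric's GINEConv, while here a minimal GINE-style layer is written in plain PyTorch, and all class and variable names are hypothetical.

```python
import torch
import torch.nn as nn

class MiniGINELayer(nn.Module):
    """Minimal GINE-style message passing (Hu et al., 2020):
    h_i' = MLP((1 + eps) * h_i + sum_j ReLU(h_j + e_ji))."""
    def __init__(self, dim):
        super().__init__()
        self.eps = nn.Parameter(torch.zeros(1))
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index                 # edges run src -> dst
        msg = torch.relu(x[src] + edge_attr)  # edge-conditioned messages
        agg = torch.zeros_like(x).index_add_(0, dst, msg)
        return self.mlp((1 + self.eps) * x + agg)

class ConnectomeTokenizer(nn.Module):
    """Encode each single-band connectome into one token:
    GINE message passing, then mean pooling over all nodes,
    so the graph is never fragmented into per-node tokens."""
    def __init__(self, dim, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(MiniGINELayer(dim) for _ in range(num_layers))

    def forward(self, band_graphs):
        # band_graphs: list of (x [N, d], edge_index [2, E], edge_attr [E, d]),
        # one graph per frequency band; returns tokens [num_bands, d].
        tokens = []
        for x, edge_index, edge_attr in band_graphs:
            for layer in self.layers:
                x = layer(x, edge_index, edge_attr)
            tokens.append(x.mean(dim=0))      # whole graph -> one atomic token
        return torch.stack(tokens)
```

Pooling a whole band graph into a single token is what keeps the connectome's topology intact inside each token: the transformer then attends across frequency bands, not across electrodes.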

2. dRoFE (Demographic Rotary Frequency Encoding)

  • Demographic Integration: Encodes age and gender into token embeddings via rotary matrices
  • RoPE Adaptation: Inspired by Rotary Position Embedding (RoPE) but tailored to EEG frequency characteristics
  • Individual Difference Mitigation: Reduces inter-subject variability while preserving diagnostic signals
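A rough sketch of the rotary mechanism follows. The rotation angles here are driven by the band's frequency rather than sequence position, and age/gender are folded into the rotated coordinates; the exact demographic mixing in the paper differs, so treat every formula choice below as an illustrative assumption.

```python
import torch

def drofe_sketch(x, freq, age, gender, base=10000.0):
    """RoPE-style rotation of one token embedding, with the angle
    driven by the band's center frequency instead of token position.
    Demographics additively shift the rotated halves (an assumption,
    not the paper's exact rule). x: [d] with d even; others: scalar tensors."""
    d = x.shape[-1]
    k = torch.arange(d // 2, dtype=x.dtype)
    theta = freq * base ** (-2 * k / d)   # per-pair rotation angles
    cos, sin = torch.cos(theta), torch.sin(theta)
    x1, x2 = x[0::2], x[1::2]             # interleaved coordinate pairs
    r1 = x1 * cos - x2 * sin + age        # 2D rotation of each pair,
    r2 = x1 * sin + x2 * cos + gender     # plus demographic offsets
    out = torch.empty_like(x)
    out[0::2], out[1::2] = r1, r2
    return out
```

With freq, age, and gender all zero the transform reduces to the identity, which makes it easy to sanity-check that the rotation itself is not distorting the embedding.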

3. XAI-Guided Self-Attention

  • Concurrent Explanations: Multi-layer DeepLIFT applied during the forward pass for real-time guidance
  • Query/Key Refinement: Feature importance scores modulate attention matrices dynamically
  • Dual-Pass Architecture: Standard transformer followed by XAI-refined predictions
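The refinement step can be sketched as below. The relevance scores would come from DeepLIFT (e.g., via Captum) in the real pipeline; here they are passed in precomputed, and element-wise modulation of Q/K is one simple choice of refinement rule, not necessarily the paper's exact one.

```python
import torch
import torch.nn.functional as F

def xai_guided_attention(q, k, v, rel_q, rel_k):
    """Second ('refined') attention pass: feature-importance scores for
    the queries and keys modulate Q/K before the scaled dot product.
    q, k, v: [T, d]; rel_q, rel_k: [T, d] importance scores."""
    d = q.shape[-1]
    q_ref = q * rel_q                                    # importance-weighted queries
    k_ref = k * rel_k                                    # importance-weighted keys
    attn = F.softmax(q_ref @ k_ref.T / d ** 0.5, dim=-1)
    return attn @ v, attn
```

In the dual-pass setup, the first (standard) pass produces the explanations, and this refined pass reuses the same projection weights, so the guidance signal adds no extra attention parameters.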

4. Dual Loss Function

  • Combined Objectives: Coarse (standard) + refined (XAI-guided) predictions with α weighting
  • End-to-End Training: XAI guidance as trainable supervision signal (α=0.7 optimal)
  • Class Balancing: Weighted loss for imbalanced diagnostic categories
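A minimal sketch of the combined objective, assuming a convex (1 − α)/α split between the coarse and refined terms; this convention is common but may differ from the paper's exact formulation:

```python
import torch
import torch.nn.functional as F

def dual_loss(coarse_logits, refined_logits, target, alpha=0.7, class_weights=None):
    """Weighted sum of the coarse (standard transformer) and refined
    (XAI-guided) cross-entropy losses. alpha=0.7 is the value reported
    as optimal for TDBRAIN; class_weights handles label imbalance."""
    l_coarse = F.cross_entropy(coarse_logits, target, weight=class_weights)
    l_refined = F.cross_entropy(refined_logits, target, weight=class_weights)
    return (1 - alpha) * l_coarse + alpha * l_refined
```

Because both terms are differentiable, the XAI guidance acts as a trainable supervision signal end to end rather than as a frozen post-hoc filter.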

3. Architecture Pipeline

End-to-end pipeline from raw EEG signals to explainable brain disorder predictions.

  • Step 1: Raw EEG signals → Multi-band connectomes (COH + wPLI connectivity)
  • Step 2: GNN Tokenizer → Frequency-embedded graph sequences with dRoFE encoding
  • Step 3: Transformer Encoder → Standard predictions with concurrent DeepLIFT explanations
  • Step 4: XAI-Guided Refinement → Query/Key matrix adjustment based on feature importance
  • Step 5: Final Predictions → Dual loss optimization (coarse + refined)

4. Technical Implementation

Modern deep learning stack for neuroscience research with reproducibility guarantees.

  • Deep Learning: PyTorch 2.0+ with CUDA support, PyTorch Geometric for graph neural networks
  • Explainability: Captum for DeepLIFT multi-layer explanations, custom XAI wrapper classes
  • Neuroscience: MNE-Python for EEG preprocessing (PREP pipeline, ICA), MNE-Connectivity for connectome construction
  • Environment: Poetry for dependency management, YAML configs for reproducibility, fixed random seeds
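A reproducibility config in this spirit might look as follows; every key name below is hypothetical, for illustration only, and not taken from the actual repository:

```yaml
# config/tdbrain.yaml -- illustrative layout; key names are hypothetical
seed: 42
dataset:
  name: tdbrain
  bands: [delta, theta, alpha, beta, gamma]
  connectivity: [coh, wpli]
model:
  embed_dim: 128
  num_layers: 4
  num_heads: 8
training:
  lr: 3.0e-4
  alpha: 0.7          # XAI-guidance loss weight
  grad_clip: 1.0
```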

5. Data Processing Pipeline

Three-stage pipeline for converting raw EEG to graph-structured tokens.

TDBRAIN Dataset

  • Size: 88 patients across 4 diagnostic categories (ADHD, MDD, OCD, Healthy Controls)
  • Modality: Eyes-closed resting-state EEG, 26 channels, BIDS format
  • Challenge: Small dataset requiring careful splitting, regularization, and adaptive batching

Preprocessing Pipeline

  • Stage 1: PREP pipeline (filtering, bad channel detection, ICA artifact removal, epoching)
  • Stage 2: Multi-band connectome construction (COH + wPLI for 5 frequency bands: delta, theta, alpha, beta, gamma)
  • Stage 3: Graph aggregation into PyTorch Geometric batches with metadata preservation
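Stage 2's connectivity metrics can be illustrated with a simplified FFT-based estimator; this is a teaching sketch, and the production pipeline would use MNE-Connectivity's `spectral_connectivity_epochs` instead:

```python
import numpy as np

def band_connectivity(epochs, sfreq, fmin, fmax):
    """Coherence (COH) and weighted phase lag index (wPLI) between two
    channels for one frequency band, averaged over epochs.
    COH  = |E[S_xy]| / sqrt(E[S_xx] E[S_yy])
    wPLI = |E[Im S_xy]| / E[|Im S_xy|]
    epochs: [n_epochs, 2, n_times] array for a single channel pair."""
    spec = np.fft.rfft(epochs, axis=-1)                  # [n_epochs, 2, n_freqs]
    freqs = np.fft.rfftfreq(epochs.shape[-1], 1.0 / sfreq)
    sel = (freqs >= fmin) & (freqs <= fmax)              # keep band of interest
    sxy = spec[:, 0, sel] * np.conj(spec[:, 1, sel])     # cross-spectrum
    sxx = np.abs(spec[:, 0, sel]) ** 2
    syy = np.abs(spec[:, 1, sel]) ** 2
    coh = np.abs(sxy.mean(0)) / np.sqrt(sxx.mean(0) * syy.mean(0))
    im = sxy.imag
    wpli = np.abs(im.mean(0)) / (np.abs(im).mean(0) + 1e-12)
    return coh.mean(), wpli.mean()
```

Running this for every channel pair and every band yields the edge weights of the five single-band connectomes; wPLI discounts zero-lag interactions, which makes it more robust to volume conduction than plain coherence.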

6. Training & Evaluation

Comprehensive training pipeline with XAI-guided supervision and robust evaluation metrics.

  • Training Features: Adaptive dataset splitting for small datasets, concurrent XAI explanation generation, dual-pass architecture with shared weights
  • Evaluation Metrics: Balanced Accuracy (BAC) as primary metric for class imbalance, AUROC/AUC-PR for comprehensive assessment
  • Interpretability: Attention entropy for quantifying attention concentration, frequency importance analysis (theta/beta ratio as biomarker)
  • Optimization: Gradient clipping for XAI refinement stability, memory-efficient torch-scatter fallback implementation
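The attention-entropy diagnostic mentioned above is straightforward to compute; a minimal version:

```python
import torch

def attention_entropy(attn):
    """Mean Shannon entropy of attention rows; lower entropy means more
    concentrated (focused) attention. attn: [..., T, T], rows sum to 1."""
    p = attn.clamp_min(1e-12)            # avoid log(0)
    return -(p * p.log()).sum(-1).mean()
```

Uniform attention over T tokens gives the maximum entropy log(T), so values well below that bound indicate the model is focusing on a few diagnostic bands.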

7. Results & Reproducibility

Successful reproduction of original paper's claims with validated performance improvements.

Performance Metrics

  • TDBRAIN BAC: ~66% with XAI guidance vs. ~63% baseline (≈3 percentage points, a ~4.8% relative improvement)
  • Attention Quality: Lower entropy indicating focused attention patterns on diagnostic features
  • Biomarker Discovery: Theta/beta frequency ratio confirmed as primary diagnostic signal
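The theta/beta ratio biomarker is easy to compute from a periodogram; a minimal sketch (the band edges below follow the conventional 4-8 Hz / 13-30 Hz definitions, which may differ slightly from the paper's):

```python
import numpy as np

def theta_beta_ratio(signal, sfreq):
    """Theta (4-8 Hz) to beta (13-30 Hz) power ratio from a raw
    periodogram of a single channel; an elevated ratio is a widely
    studied (though debated) ADHD marker."""
    psd = np.abs(np.fft.rfft(signal)) ** 2
    freqs = np.fft.rfftfreq(signal.size, 1.0 / sfreq)
    theta = psd[(freqs >= 4) & (freqs < 8)].sum()
    beta = psd[(freqs >= 13) & (freqs < 30)].sum()
    return theta / (beta + 1e-12)
```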

Configuration Sensitivity

  • Alpha Parameter: XAI guidance weight (α=0.7 optimal for TDBRAIN)
  • Learning Rate: Dataset-specific tuning required (1e-4 to 5e-4 range)
  • Batch Size: Adaptive sizing for small datasets with memory constraints

8. Conclusion & Future Directions

This reproducibility challenge successfully validates the XAIguiFormer architecture, demonstrating that explainable AI can serve as an active performance enhancer beyond traditional post-hoc interpretability. The dual-pass XAI-guided approach shows consistent improvements on EEG-based brain disorder classification, with potential for broader applications in neuroscience and clinical deployment.

Future directions include validation on larger clinical datasets (TUAB with 2,993 sessions), integration of foundation model pre-training strategies, multi-modal fusion with other neuroimaging modalities (fMRI, MEG), real-time optimization for clinical deployment, and federated learning for privacy-preserving multi-site training.

Technologies & Resources

Academic References

  • Guo et al. "XAIguiFormer: Explainable Artificial Intelligence Guided Transformer for Brain Disorder Identification." ICLR 2025.
  • Shrikumar, A., Greenside, P., & Kundaje, A. "Learning Important Features Through Propagating Activation Differences." ICML 2017.

Project Information

Challenge: ML Reproducibility Challenge for ICLR 2025

Dataset: TDBRAIN (88 patients, 4 diagnostic categories)

Contact: For repository access or technical inquiries, contact Martin LE CORRE

Documentation: see the detailed README in the project repository