```mermaid
graph TD
Input["Input Image"] --> Transforms["torchvision.transforms<br>(Resize 224x224, Normalize)"]
Transforms --> Backbone["LeViT-128s Backbone<br>(Feature extraction)"]
Backbone --> Head["Custom Conv Head<br>(1x1 Conv + Pooling)"]
Head --> Prediction["Binary Prediction (AI vs Real)"]
Prediction --> Heatmap["Grad-CAM Heatmap"]
Prediction --> Explainer["SHAP Explainer"]
```
AI vs. Real Image Classification
Apr 2025
PyTorch
timm
Transformer
XAI
Machine Learning
A binary image classification pipeline built with PyTorch on a LeViT transformer backbone, with Grad-CAM and SHAP explainability visualizations.
Project Overview
AI vs. Real Image Classification is a binary classification pipeline designed to distinguish authentic photographs from AI-generated images. Built on PyTorch and PyTorch Image Models (timm), it uses a LeViT-128s hybrid vision transformer backbone and applies Explainable AI (XAI) techniques to visualize which image regions drive each prediction.
Problem
- Synthetic Media Propagation: Generative adversarial networks (GANs) and diffusion models can produce highly realistic synthetic images, increasing the risk of fraud.
- Transformer Latency: Deploying standard vision transformers (e.g., ViT-Base, with about 86M parameters) on local edge devices is hindered by large parameter counts and slow inference.
- Black-Box Decisions: Deep neural networks do not natively reveal which image regions drive their predictions, which makes it hard to validate whether those predictions can be trusted.
Features
- Hybrid Vision Transformer: Uses a lightweight LeViT-128s backbone that combines a convolutional stem with attention blocks.
- Transfer Learning Configuration: Loads ImageNet pre-trained weights and replaces the classification head with custom 1x1 convolutions and average pooling.
- Data Preprocessing Pipeline: Resizes images to 224x224, converts them to tensors, and applies per-channel normalization via torchvision.transforms.
- Explainable AI (XAI) Integrations: Implements Grad-CAM and SHAP to render heatmaps that highlight likely artifact locations.
- Inference Speed Optimization: Reduces the saved weight file to a 28 MB footprint for faster loading.
Tech Stack
- Model & Training:
  - PyTorch
  - torchvision
  - timm (PyTorch Image Models)
  - LeViT-128s
- Explainability & Metrics:
  - Grad-CAM
  - SHAP
  - NumPy / Pandas
- Development:
  - Python
Architecture
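The pipeline flows from preprocessed image to LeViT-128s features, then through a 1x1 convolutional head with pooling to a binary prediction. A minimal sketch of how those pieces could be wired together, under stated assumptions: `ConvHead`, `AIvsRealClassifier`, and `predict` are hypothetical names; the head accepts either a (B, C, H, W) map or a (B, N, C) token sequence whose N tokens form a square grid; and the model emits one logit where higher means AI-generated. With timm, a backbone could plausibly come from `timm.create_model("levit_128s", pretrained=True, num_classes=0)` with `feature_dim=384` (the last-stage width in the LeViT-128S configuration), but the exact `forward_features` output shape should be checked for your timm version.

```python
import math

import torch
from torch import nn


class ConvHead(nn.Module):
    """Hypothetical binary head: 1x1 convolutions over the final feature grid, then average pooling."""

    def __init__(self, in_channels: int, hidden_channels: int = 128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, 1, kernel_size=1),  # one logit per spatial position
        )
        self.pool = nn.AdaptiveAvgPool2d(1)  # average the logit map into a single score

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        if feats.dim() == 3:  # (B, N, C) tokens -> (B, C, side, side) grid
            b, n, c = feats.shape
            side = math.isqrt(n)
            if side * side != n:
                raise ValueError(f"cannot reshape {n} tokens into a square grid")
            feats = feats.transpose(1, 2).reshape(b, c, side, side)
        return self.pool(self.conv(feats)).flatten(1)  # (B, 1)


class AIvsRealClassifier(nn.Module):
    """Backbone features -> ConvHead -> one logit (assumed: higher means AI-generated)."""

    def __init__(self, backbone: nn.Module, feature_dim: int, freeze_backbone: bool = True):
        super().__init__()
        self.backbone = backbone
        self.head = ConvHead(feature_dim)
        if freeze_backbone:  # train only the new head at first
            for param in self.backbone.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # timm models expose forward_features(); fall back to a plain call otherwise.
        if hasattr(self.backbone, "forward_features"):
            feats = self.backbone.forward_features(x)
        else:
            feats = self.backbone(x)
        return self.head(feats)


@torch.inference_mode()
def predict(model: nn.Module, batch: torch.Tensor, threshold: float = 0.5):
    """Return ("AI" | "Real") labels and P(AI-generated) for a preprocessed batch."""
    model.eval()
    probs = torch.sigmoid(model(batch)).squeeze(1)
    labels = ["AI" if p >= threshold else "Real" for p in probs.tolist()]
    return labels, probs
```

Pooling the per-position logits, rather than pooling features first, keeps a small spatial logit map available for inspection. Freezing the backbone trains only the head initially; unfreezing it later for fine-tuning is a common follow-up step.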
My Contributions
- Built the binary classification pipeline in PyTorch using timm model loaders.
- Set up ImageNet transfer learning and replaced the classification head.
- Integrated Grad-CAM and SHAP visualizers that attribute prediction scores to image regions.
- Built the torchvision preprocessing transforms and data loading classes.
- Wrote the optimized inference.py script for local execution.
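SHAP attributes a prediction to input features using Shapley values. The sketch below computes exact Shapley values for a coarse grid of image patches, replacing patches outside a coalition with a baseline image. It illustrates the underlying definition rather than the `shap` package's API, which approximates these values because exact enumeration grows as 2^n in the number of patches. `patch_masks`, `patch_shapley`, the baseline choice, and the 2x2 default grid are assumptions for illustration.

```python
import itertools
import math

import torch


def patch_masks(height: int, width: int, grid: int) -> list:
    """Boolean (H, W) masks for a grid x grid partition of the image."""
    ys = [r * height // grid for r in range(grid + 1)]
    xs = [c * width // grid for c in range(grid + 1)]
    masks = []
    for r in range(grid):
        for c in range(grid):
            mask = torch.zeros(height, width, dtype=torch.bool)
            mask[ys[r]:ys[r + 1], xs[c]:xs[c + 1]] = True
            masks.append(mask)
    return masks


@torch.no_grad()
def patch_shapley(score_fn, image: torch.Tensor, baseline: torch.Tensor, grid: int = 2) -> list:
    """Exact Shapley value of each patch.

    score_fn maps a (1, C, H, W) batch to a single score (e.g. the "AI" logit);
    image and baseline are (C, H, W). Patches outside a coalition show the baseline.
    """
    _, height, width = image.shape
    masks = patch_masks(height, width, grid)
    n = len(masks)
    cache = {}

    def value(coalition: frozenset) -> float:
        if coalition not in cache:
            keep = torch.zeros(height, width, dtype=torch.bool)
            for i in coalition:
                keep |= masks[i]
            composite = torch.where(keep, image, baseline)  # (H, W) mask broadcasts over channels
            cache[coalition] = float(score_fn(composite.unsqueeze(0)))
        return cache[coalition]

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S of this size
            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for subset in itertools.combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi
```

With `grid=2` there are only 16 coalitions; a 4x4 grid would already need 65,536 model evaluations, which is why practical SHAP explainers sample coalitions or group features instead of enumerating them.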
What I Learned
- Applying transfer learning to vision transformer models.
- Implementing local explanation frameworks (Grad-CAM, SHAP).
- Evaluating convolutional vs. attention-based feature extraction paths.
- Reducing model weight size for CPU deployment.
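To check a footprint such as the 28 MB figure, one can sum the bytes of every tensor in the model's state dict. For scale, at 4 bytes per float32 value, 28 MB corresponds to roughly 7 million parameters. The helper names below are hypothetical. Casting to float16 is one way to halve stored size; whether to also run CPU inference at half precision depends on operator support, so a loaded copy can be cast back with `.float()` if needed.

```python
import copy
import io

import torch
from torch import nn


def tensor_megabytes(model: nn.Module) -> float:
    """Raw bytes of all parameters and buffers in the state dict, in MB (1 MB = 10**6 bytes)."""
    return sum(t.numel() * t.element_size() for t in model.state_dict().values()) / 1e6


def serialized_megabytes(model: nn.Module) -> float:
    """Size of the state dict as written by torch.save (includes container overhead)."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.tell() / 1e6


def half_precision_copy(model: nn.Module) -> nn.Module:
    """Independent copy with floating-point parameters and buffers cast to float16."""
    return copy.deepcopy(model).half()
```

For example, `nn.Linear(1000, 1000)` holds 1,001,000 float32 values, so `tensor_megabytes` reports about 4.004 MB, and its half-precision copy about 2.002 MB.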
Results
- Trained a lightweight classifier with a 28 MB weight file.
- Grad-CAM heatmaps indicated that the model's decisions were driven by unnatural edge gradients and blending artifacts typical of diffusion models.
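Grad-CAM weights each channel of a convolutional feature map by the spatial average of the target score's gradient, sums the weighted channels, and keeps only the positive part. A minimal sketch, assuming a single-logit model and a target layer whose output is a (1, K, h, w) map (for example, a 1x1 convolution in a convolutional head); `grad_cam` is a hypothetical name. LeViT's attention stages output token sequences, which would first need reshaping into a grid; this sketch does not handle that case.

```python
import torch
import torch.nn.functional as F
from torch import nn


def grad_cam(model: nn.Module, target_layer: nn.Module, image: torch.Tensor) -> torch.Tensor:
    """Grad-CAM map in [0, 1] with shape (H, W) for a single-logit classifier.

    image: (1, C, H, W) preprocessed input.
    target_layer: a module of `model` whose output is a (1, K, h, w) feature map.
    """
    store = {}

    def save_activation_and_grad(module, inputs, output):
        store["activation"] = output
        # Capture d(logit)/d(activation) when backward reaches this tensor.
        output.register_hook(lambda grad: store.__setitem__("gradient", grad))

    # Make the input require grad so activations do too, even if the backbone is frozen.
    image = image.detach().clone().requires_grad_(True)
    handle = target_layer.register_forward_hook(save_activation_and_grad)
    try:
        model.eval()
        model.zero_grad()
        logit = model(image).squeeze()  # assumed: one logit, higher = more likely AI-generated
        logit.backward()
    finally:
        handle.remove()

    activation = store["activation"].detach()
    gradient = store["gradient"].detach()
    weights = gradient.mean(dim=(2, 3), keepdim=True)              # alpha_k: averaged gradients
    cam = F.relu((weights * activation).sum(dim=1, keepdim=True))  # ReLU(sum_k alpha_k * A^k)
    cam = F.interpolate(cam, size=image.shape[-2:], mode="bilinear", align_corners=False)
    cam = cam[0, 0]
    cam = cam - cam.min()
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

Overlaying the returned map on the resized input image (for example with matplotlib's `imshow` and an `alpha` value) produces a heatmap of the kind used to inspect edge and blending artifacts.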
Future Work
- Train on datasets covering multiple image generators (e.g., Midjourney, DALL-E 3).
- Integrate adversarial training loops to improve the model's robustness.
- Expose predictions as a local browser app using WebAssembly.
Links
- GitHub Repository: https://github.com/yuvraj-rathod-1202/HackRush25-ML