AI vs. Real Image Classification

Apr 2025

PyTorch
timm
Transformer
XAI
Machine Learning
A machine learning classification pipeline using PyTorch and a LeViT transformer backbone, featuring explainability visualizations (Grad-CAM, SHAP).
Published

April 20, 2025


Project Overview

AI vs. Real Image Classification is a binary classification pipeline designed to distinguish authentic photographs from AI-generated visual content. Built on PyTorch and PyTorch Image Models (timm), it uses a LeViT-128s vision transformer backbone and incorporates Explainable AI (XAI) methods (Grad-CAM and SHAP) to visualize which image regions drive the model's decisions.

Problem

  • Synthetic Media Propagation: Generative adversarial networks (GANs) and diffusion models produce highly realistic synthetic images, increasing fraud risks.
  • Transformer Latency: Standard vision transformers such as ViT-Base are difficult to deploy on local edge devices because of their large parameter counts and slow inference.
  • Black-Box Decisions: Deep neural networks do not natively reveal which image regions drive their predictions, which makes their decisions hard to validate and trust.
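To put the latency point in perspective, raw weight storage grows linearly with parameter count. The sketch below assumes published approximate counts of about 86.6M parameters for ViT-B/16 and 7.8M for LeViT-128S (including its ImageNet heads); these are not measurements from this project, and the helper names are illustrative.

```python
# Rough weight-storage estimate: parameters x bytes per parameter.
# Parameter counts below are ASSUMED published approximations, not project measurements.

BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}

def weights_mb(num_params: int, dtype: str = "fp32") -> float:
    """Approximate size of the raw weights in megabytes (1 MB = 10**6 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e6

PARAM_COUNTS = {
    "vit_base_patch16_224": 86_600_000,  # assumed approximate count for ViT-B/16
    "levit_128s": 7_800_000,             # assumed approximate count for LeViT-128S
}

if __name__ == "__main__":
    for name, n in PARAM_COUNTS.items():
        print(f"{name}: ~{weights_mb(n):.1f} MB in fp32")
```

Under these assumptions the estimate is roughly 346 MB for ViT-Base versus about 31 MB for LeViT-128s in fp32, around an 11x difference before any compression. Storage is only a proxy: actual latency also depends on FLOPs and memory access patterns.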

Features

  • Hybrid Vision Transformer: Uses a lightweight LeViT-128s backbone that pairs a convolutional stem with multi-head attention stages.
  • Transfer Learning Configuration: Loads ImageNet pre-trained weights and replaces the classification head with a custom head built from 1x1 convolutions and average pooling.
  • Data Preprocessing Pipeline: Resizes inputs to 224x224, converts them to tensors, and applies per-channel mean/std normalization via torchvision.transforms.
  • Explainable AI (XAI) Integrations: Uses Grad-CAM and SHAP to render heatmaps showing which image regions, such as generation artifacts, drive each prediction.
  • Inference Speed Optimization: Reduces the saved weight file to a 28MB footprint for faster loading.
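The preprocessing listed above can be sketched without torchvision to make each operation explicit. In torchvision it would typically be written as `transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean, std)])`; the NumPy function below mirrors only the last two steps and assumes the standard ImageNet channel statistics, which the page does not state. `Normalize` standardizes each color channel with fixed statistics rather than learned batch statistics.

```python
import numpy as np

# ASSUMED ImageNet channel statistics; check the pretrained model's config before relying on them.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_normalized_chw(image_hwc_uint8: np.ndarray) -> np.ndarray:
    """Mimic transforms.ToTensor() followed by transforms.Normalize(mean, std).

    Expects an already-resized H x W x 3 uint8 image (e.g. 224 x 224).
    """
    if image_hwc_uint8.ndim != 3 or image_hwc_uint8.shape[2] != 3:
        raise ValueError("expected an H x W x 3 image")
    x = image_hwc_uint8.astype(np.float32) / 255.0      # ToTensor: scale pixels to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD              # Normalize: fixed per-channel mean/std
    return np.ascontiguousarray(x.transpose(2, 0, 1))   # ToTensor: HWC -> CHW layout
```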

Tech Stack

  • Model & Training:
    • PyTorch
    • torchvision
    • timm (PyTorch Image Models)
    • LeViT-128s
  • Explainability & Metrics:
    • Grad-CAM
    • SHAP
    • NumPy / Pandas
  • Development:
    • Python

Architecture

```mermaid
graph TD
    Input["Input Image"] --> Transforms["torchvision.transforms<br>(Resize 224x224, Normalize)"]
    Transforms --> Backbone["LeViT-128s Backbone<br>(Feature extraction)"]
    Backbone --> Head["Custom Conv Head<br>(1x1 Conv + Pooling)"]
    Head --> Prediction["Binary Prediction (AI vs Real)"]
    Prediction --> Heatmap["Grad-CAM Heatmap"]
    Prediction --> Explainer["SHAP Explainer"]
```
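The custom conv head in the diagram can be sketched in NumPy. The shapes are assumptions for illustration: a final feature map with 384 channels on a 4x4 grid (what LeViT-128s's last stage is expected to produce for a 224x224 input), arriving as a sequence of 16 tokens that must first be reshaped into a grid. The weights are random stand-ins rather than trained parameters, and `tokens_to_grid` and `conv1x1_avgpool_head` are hypothetical helper names.

```python
import numpy as np

def tokens_to_grid(tokens: np.ndarray, height: int, width: int) -> np.ndarray:
    """Reshape (B, N, C) tokens into a (B, C, H, W) feature map; N must equal H * W."""
    b, n, c = tokens.shape
    if n != height * width:
        raise ValueError(f"{n} tokens cannot form a {height}x{width} grid")
    return tokens.reshape(b, height, width, c).transpose(0, 3, 1, 2)

def conv1x1_avgpool_head(features: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """1x1 convolution followed by global average pooling.

    features: (B, C, H, W); weight: (K, C); bias: (K,). Returns logits of shape (B, K).
    A 1x1 conv applies the same linear map over channels at every spatial position.
    """
    per_position = np.einsum("kc,bchw->bkhw", weight, features) + bias[None, :, None, None]
    return per_position.mean(axis=(2, 3))  # average the per-position scores over H and W

rng = np.random.default_rng(0)
tokens = rng.standard_normal((2, 16, 384)).astype(np.float32)  # hypothetical backbone output
feats = tokens_to_grid(tokens, 4, 4)                           # -> (2, 384, 4, 4)
w = rng.standard_normal((2, 384)).astype(np.float32) * 0.01    # 2 output classes: AI vs Real
b = np.zeros(2, dtype=np.float32)
logits = conv1x1_avgpool_head(feats, w, b)                     # -> (2, 2)
```

Because the 1x1 convolution and the average pooling are both linear, pooling first and then applying a linear layer yields identical logits. Keeping the convolution before the pooling additionally produces a per-position class score map, which is the idea behind the original Class Activation Mapping (CAM) that Grad-CAM generalizes.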

My Contributions

  • Built the binary classification pipeline in PyTorch, using timm to load the backbone.
  • Implemented ImageNet transfer learning by loading pretrained weights and replacing the classification heads.
  • Integrated Grad-CAM and SHAP visualizers that map prediction scores back to image regions.
  • Structured the torchvision preprocessing transforms and data loading classes.
  • Programmed the optimized inference.py script for local execution.
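Grad-CAM weights each channel of a chosen layer's activations by the spatial average of the target class score's gradient, sums the weighted channels, and keeps only positive values. A minimal NumPy sketch, assuming the activations and gradients for one image have already been captured (in PyTorch this is commonly done with module hooks such as `register_forward_hook` and `register_full_backward_hook`); the function names are illustrative, not taken from the project's code:

```python
import numpy as np

def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM map from one layer's activations and the gradients of a class score.

    activations, gradients: (C, H, W) for a single image, where gradients hold
    d(target class score) / d(activations). Returns an (H, W) map scaled to [0, 1].
    """
    weights = gradients.mean(axis=(1, 2))                  # alpha_k: spatially averaged gradients
    cam = np.tensordot(weights, activations, axes=(0, 0))  # sum_k alpha_k * A_k -> (H, W)
    cam = np.maximum(cam, 0.0)                             # ReLU: keep positive evidence only
    peak = cam.max()
    return cam / peak if peak > 0 else cam

def upsample_nearest(cam: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour upsampling so a coarse map can be overlaid on the input image."""
    rows = (np.arange(out_h) * cam.shape[0]) // out_h
    cols = (np.arange(out_w) * cam.shape[1]) // out_w
    return cam[rows[:, None], cols[None, :]]
```

For a 4x4 feature map, `upsample_nearest(grad_cam(acts, grads), 224, 224)` yields a map the size of the input image that can be blended over it; bilinear interpolation is a common, smoother alternative.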

What I Learned

  • Applying transfer learning to vision transformer models.
  • Implementing local explanation frameworks (Grad-CAM, SHAP).
  • Evaluating convolutional vs. attention-based feature extraction paths.
  • Optimizing weight compression for CPU deployment.
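The page does not say how the weights were compressed. One common storage-side technique is casting floating-point parameters to half precision, which halves their size; in PyTorch that is roughly `model.half()` before `torch.save(model.state_dict(), path)`. The NumPy sketch below uses a hypothetical name-to-array mapping as a stand-in for a state dict and leaves integer buffers (such as a batch-norm step counter) unchanged:

```python
import numpy as np

def state_dict_nbytes(state: dict) -> int:
    """Total bytes held by the arrays in a name -> array mapping."""
    return sum(arr.nbytes for arr in state.values())

def to_half(state: dict) -> dict:
    """Cast floating-point arrays to float16; leave integer buffers untouched."""
    return {
        name: arr.astype(np.float16) if np.issubdtype(arr.dtype, np.floating) else arr
        for name, arr in state.items()
    }

# Hypothetical stand-in for a model's state dict.
demo = {
    "weight": np.ones((384, 384), dtype=np.float32),
    "num_batches_tracked": np.array(7, dtype=np.int64),
}
print(state_dict_nbytes(demo), "->", state_dict_nbytes(to_half(demo)))  # 589832 -> 294920
```

On CPU, fp16 mainly saves disk space and load time; the weights may be cast back to float32 after loading, since many CPU kernels are tuned for fp32 arithmetic.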

Results

  • Successfully trained a lightweight classifier with a 28MB footprint.
  • Grad-CAM heatmaps indicated that the model's decisions were driven by unnatural edge gradients and blending artifacts typical of diffusion models.

Future Work

  • Train on multi-category datasets covering various diffusion engines (Midjourney, DALL-E 3).
  • Integrate adversarial training loops to improve model robustness.
  • Expose predictions as a local browser app using WebAssembly.