
ResNet vs Plain CNN: Understanding Residual Learning

Python 3.8+ · PyTorch · License: MIT

A comprehensive educational project demonstrating why residual connections matter in deep learning through hands-on implementation and visualization.

Project Goal

This project provides a complete, reproducible implementation that:

  1. Implements ResNet-18 from scratch in PyTorch
  2. Compares it against a Plain CNN with the same depth
  3. Visualizes training dynamics, gradient flow, and attention maps
  4. Explains the theory behind residual learning

Key Results

Model          Best Test Accuracy   Parameters
ResNet-18      ~92%                 11.2M
Plain CNN-18   ~82%                 11.2M
Improvement    +10%                 Same

Key Insight: With identical depth and parameters, ResNet dramatically outperforms the plain network, demonstrating that residual connections solve the optimization problem, not just add capacity.

Architecture

Residual Block

Input x
    │
    ├──────────────────┐
    │                  │
Conv 3×3              │
    │               Identity
BN + ReLU            Shortcut
    │                  │
Conv 3×3              │
    │                  │
   BN                  │
    │                  │
    └───── + ─────────┘
           │
        ReLU
           │
       Output y = F(x) + x

Why This Matters

Instead of learning H(x) directly, the network learns the residual F(x) = H(x) - x:

  • Easy identity: if the optimal mapping is the identity, the network only needs to push F(x) → 0
  • Gradient highway: the identity path lets gradients flow directly back to earlier layers
  • Additive learning: each block learns a small correction on top of its input
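The "easy identity" point can be checked numerically: zero out the residual branch and the block reproduces its input exactly. A minimal sketch (the 8-channel, 4×4 sizes are arbitrary choices, not taken from model.py):

```python
import torch
import torch.nn as nn

# Residual branch F as a tiny conv stack; zero the last weight so F(x) = 0
f_branch = nn.Sequential(
    nn.Conv2d(8, 8, 3, padding=1, bias=False),
    nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1, bias=False),
)
nn.init.zeros_(f_branch[2].weight)

x = torch.rand(1, 8, 4, 4)          # non-negative input
y = torch.relu(f_branch(x) + x)     # block output: F(x) + x, then ReLU
assert torch.equal(y, x)            # the block reduces to an exact identity
```

A plain stack has no such shortcut: to represent the identity it must learn it through every convolution, which is exactly what deep plain networks struggle to do.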

Project Structure

resnet_project/
│
├── model.py              # ResNet-18 and Plain CNN implementations
├── train.py              # Training script with comparison experiment
├── utils.py              # Training utilities and metrics
├── visualize.py          # Static visualization generation
├── gradcam.py            # Grad-CAM attention visualization
├── app.py                # Interactive Streamlit dashboard
├── requirements.txt      # Python dependencies
└── README.md             # This file

Quick Start

1. Installation

# Clone the repository
git clone https://github.com/imjbassi/ResNet.git
cd ResNet

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

2. Run Training Comparison

# Train both models and compare (50 epochs, ~30 min on GPU)
python train.py --compare --epochs 50

# Or train individually
python train.py --model resnet --epochs 50
python train.py --model plain --epochs 50

3. Generate Visualizations

# Generate all visualizations
python visualize.py --results-dir ./results/TIMESTAMP

# Or run with demo data
python visualize.py --demo

4. Launch Interactive Dashboard

streamlit run app.py

Visualizations

Training Curves

Compares loss and accuracy between ResNet and Plain CNN over training epochs.


Accuracy Gap

Shows how ResNet's advantage over Plain CNN grows during training.


Gradient Flow Analysis

Shows how gradients propagate through the layers: ResNet maintains consistent gradient magnitudes, while the Plain CNN suffers from vanishing gradients.
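The per-layer statistic behind this plot can be reproduced in a few lines. A sketch on a toy stack of linear layers (the project's own script would take the same measurement on the CNNs from model.py):

```python
import torch
import torch.nn as nn

# Toy 10-layer stack standing in for a deep network
model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(10)])
loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()

# Mean absolute gradient per layer, first layer to last —
# the quantity plotted in the gradient-flow figure
grad_norms = [layer.weight.grad.abs().mean().item() for layer in model]
```

Plotting `grad_norms` against layer index makes vanishing gradients visible at a glance: in a plain network the early-layer values shrink toward zero, while skip connections keep them on a comparable scale.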


Residual Block Architecture

Visual explanation of the residual block structure and skip connections.


Summary Dashboard

Complete training analysis dashboard with all metrics.


Key Concepts Demonstrated

1. The Degradation Problem

# Plain networks degrade with depth (test accuracy)
plain_18_layer = 0.82
plain_34_layer = 0.78  # WORSE with more layers!

2. Residual Learning Solution

import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.shortcut = nn.Sequential()  # identity when shapes match
        if stride != 1 or in_channels != out_channels:
            # 1×1 projection when spatial size or channel count changes
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # THE KEY: skip connection adds the input
        return F.relu(out)

3. Gradient Flow

∂L/∂x = ∂L/∂y × (∂F(x)/∂x + 1)
                            ↑
                    Always has a path!
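Autograd confirms the "+1" term on a scalar stand-in for a block, where w plays the role of ∂F(x)/∂x (a hypothetical toy example, not project code):

```python
import torch

# Scalar stand-in: residual branch F(x) = w * x, so dF/dx = w
x = torch.tensor(2.0, requires_grad=True)
w = 0.3
y = w * x + x          # residual output y = F(x) + x
y.backward()

print(x.grad.item())   # ≈ 1.3, i.e. dF/dx + 1 — the "+1" keeps the path open
```

Even if the residual branch saturates (w → 0), the gradient never falls below 1, which is why stacking many blocks does not starve early layers of signal.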

Educational Value

This project teaches:

  1. Deep Learning Fundamentals

    • Vanishing gradient problem
    • Batch normalization
    • Modern training techniques
  2. PyTorch Best Practices

    • Model modularization
    • Custom training loops
    • Visualization hooks
  3. Research Methodology

    • Controlled experiments
    • Ablation studies
    • Results visualization
  4. Software Engineering

    • Code organization
    • Documentation
    • Reproducibility

Resume-Ready Bullet Point

Implemented and trained ResNet-18 from scratch on CIFAR-10, demonstrating a 10-percentage-point accuracy improvement over a plain CNN of equal depth through residual connections. Created comprehensive visualizations, including Grad-CAM attention maps, to explain model behavior.

Configuration

Training Parameters

Parameter        Default   Description
--epochs         50        Number of training epochs
--batch-size     128       Batch size for training
--lr             0.1       Initial learning rate
--weight-decay   5e-4      L2 regularization strength

Scheduler

Cosine annealing learning rate schedule for smooth convergence.
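A sketch of wiring the schedule to the defaults in the table above (the momentum value of 0.9 is an assumption, not taken from train.py):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)  # stand-in for ResNet-18
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

lrs = []
for epoch in range(50):
    # ... one epoch of training would go here ...
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])
# lr decays smoothly from 0.1 toward 0 over the 50 epochs
```

Calling `scheduler.step()` once per epoch (after `optimizer.step()`) traces the half-cosine from the initial learning rate down to zero, avoiding the abrupt drops of a step schedule.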

References

  1. He et al., "Deep Residual Learning for Image Recognition" (CVPR 2016)

  2. Selvaraju et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization" (ICCV 2017)

Contributing

Contributions welcome! Please feel free to submit issues and pull requests.

License

MIT License - feel free to use this code for learning and projects.


Extensions (Advanced)

  1. Medical Imaging: Replace CIFAR-10 with X-ray or MRI data
  2. Transfer Learning: Fine-tune pretrained ResNet on custom dataset
  3. Architecture Variants: Implement ResNet-34, ResNet-50, ResNeXt
  4. Efficiency Analysis: Measure FLOPs, inference time, memory usage
  5. Attention Mechanisms: Add SE blocks or CBAM modules
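A starting point for the efficiency analysis in item 4: parameter counting and a crude latency measurement need only the standard library plus PyTorch (the tiny model here is a stand-in; FLOP counting would require an extra tool such as a profiler):

```python
import time
import torch
import torch.nn as nn

# Stand-in model; for the real study, load ResNet-18 / Plain CNN from model.py
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 10),
)
n_params = sum(p.numel() for p in model.parameters())

x = torch.randn(1, 3, 32, 32)  # one CIFAR-10-sized image
model.eval()
with torch.no_grad():
    model(x)  # warm-up pass
    start = time.perf_counter()
    for _ in range(20):
        model(x)
avg_ms = (time.perf_counter() - start) / 20 * 1e3  # mean latency in ms
```

Running the same loop over both architectures gives a fair parameters-vs-latency comparison, since the table above already fixes them to the same parameter budget.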

Built for learning deep learning fundamentals
