
MNIST Digit Classification with Neural Networks

A comparative study of different optimizers (SGD, Adam, RMSprop) for handwritten digit classification using deep neural networks on the MNIST dataset.

Overview

This project implements and compares three neural network models with different optimizers to classify handwritten digits from the MNIST dataset. The goal is to analyze the performance differences between SGD, Adam, and RMSprop optimizers on the same network architecture.

Dataset

  • Source: MNIST dataset via scikit-learn's fetch_openml
  • Size: 70,000 images (28x28 pixels)
  • Classes: 10 digits (0-9)
  • Split: 80% training, 20% testing
  • Preprocessing: Pixel values normalized to [0, 1] range

Model Architecture

The project implements two different neural network architectures:

Initial Architecture (3 Hidden Layers)

  • Input Layer: 784 neurons (28×28 flattened pixels)
  • Hidden Layer 1: 50 neurons, ReLU activation
  • Hidden Layer 2: 60 neurons, ReLU activation
  • Hidden Layer 3: 30 neurons, ReLU activation
  • Output Layer: 10 neurons, Softmax activation

Enhanced Architecture (4 Hidden Layers)

  • Input Layer: 784 neurons (28×28 flattened pixels)
  • Hidden Layer 1: 100 neurons, ReLU activation
  • Hidden Layer 2: 80 neurons, ReLU activation
  • Hidden Layer 3: 50 neurons, ReLU activation
  • Hidden Layer 4: 30 neurons, ReLU activation
  • Output Layer: 10 neurons, Softmax activation

Requirements

pip install tensorflow
pip install scikit-learn
pip install numpy
pip install matplotlib

Dependencies

  • tensorflow - Deep learning framework
  • scikit-learn - Dataset loading and preprocessing
  • numpy - Numerical computations
  • matplotlib - Data visualization and plotting

Usage

1. Data Preparation

The script automatically:

  • Downloads the MNIST dataset
  • Normalizes pixel values to [0, 1]
  • Splits data into training/testing sets (80/20)
  • Converts labels to one-hot encoding
  • Reshapes data for neural network input
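
A minimal sketch of these steps (fetch_openml is the loader named under Dataset above; variable names and the random_state are illustrative):

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical

# Download MNIST: 70,000 flattened 28x28 images
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)

X = X.astype('float32') / 255.0        # normalize pixel values to [0, 1]
y = to_categorical(y.astype(int), 10)  # one-hot encode the 10 digit classes

# 80/20 train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)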

2. Model Training and Evaluation

The project includes two experimental setups:

Experiment 1: Basic Optimizer Comparison

Three models are trained separately, each compiled with a different optimizer:

# SGD Optimizer
dnnModel.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

# Adam Optimizer  
dnnModel_adam.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# RMSprop Optimizer
dnnModel_rmsprop.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
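
Each compiled model is then trained with the parameters listed under Training Parameters below, producing the history objects (h_sgd, h_adam, h_rms) referenced later. A sketch of the shared call:

# Same fit call for all three models; h_adam and h_rms follow the same pattern
h_sgd = dnnModel.fit(X_train, y_train,
                     epochs=25,
                     batch_size=64,
                     validation_split=0.1)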

Experiment 2: Comprehensive Optimizer Study

Automated comparison of five optimizers on the enhanced 4-hidden-layer architecture (a loop sketch follows the list):

  • SGD
  • RMSprop
  • Adam
  • Nadam
  • Adagrad
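
One plausible shape for this loop, reusing the build_model helper sketched above (the actual loop in main.py may differ):

optimizers = ['sgd', 'rmsprop', 'adam', 'nadam', 'adagrad']
histories = {}

for opt in optimizers:
    model = build_model([100, 80, 50, 30])  # enhanced architecture
    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
    histories[opt] = model.fit(X_train, y_train, epochs=25, batch_size=64,
                               validation_split=0.1, verbose=0)
    test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
    print(f'{opt} test accuracy: {test_acc:.4f}')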

3. Results and Visualization

The project provides comprehensive performance analysis:

  • Training Metrics: Final training accuracy for each optimizer
  • Test Evaluation: Test loss and accuracy on unseen data
  • Visual Comparison (sketched in code after this list):
    • Training accuracy curves across epochs
    • Training loss curves across epochs
    • Side-by-side plots for easy comparison
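
A matplotlib sketch of the side-by-side layout, using the history object names from this README:

import matplotlib.pyplot as plt

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))
for name, h in [('SGD', h_sgd), ('Adam', h_adam), ('RMSprop', h_rms)]:
    ax_acc.plot(h.history['accuracy'], label=name)
    ax_loss.plot(h.history['loss'], label=name)
ax_acc.set(title='Training accuracy', xlabel='Epoch', ylabel='Accuracy')
ax_loss.set(title='Training loss', xlabel='Epoch', ylabel='Loss')
ax_acc.legend()
ax_loss.legend()
plt.show()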

4. Training Parameters

  • Epochs: 25
  • Batch Size: 64
  • Validation Split: 10% of training data
  • Loss Function: Categorical Crossentropy
  • Metrics: Accuracy

File Structure

project/
├── main.py                 # Main training script
├── README.md              # This file
└── requirements.txt       # Dependencies

Expected Results

The project outputs comprehensive performance metrics:

Training Results

  • Final training accuracy for each optimizer
  • Training history objects containing epoch-by-epoch metrics

Test Evaluation

  • Test loss and accuracy for each model
  • Performance comparison on unseen data

Visualizations

  • Training accuracy curves comparing all optimizers
  • Training loss curves showing convergence patterns
  • Side-by-side subplot layout for easy comparison

Sample Output

SGD Final training accuracy: 0.9234
Adam Final training accuracy: 0.9567
RMSprop Final training accuracy: 0.9445

SGD Test accuracy: 0.9123
Adam Test accuracy: 0.9421
RMSprop Test accuracy: 0.9334

Key Features

  • Data Preprocessing: Automatic normalization and reshaping
  • One-Hot Encoding: Converts integer labels to categorical format
  • Comprehensive Optimizer Comparison: Evaluation of 5 optimizers (SGD, Adam, RMSprop, Nadam, Adagrad)
  • Performance Evaluation: Both training and test set accuracy/loss metrics
  • Visualization: Training curves for accuracy and loss comparison
  • Two Model Architectures: Initial 3-layer and enhanced 4-layer networks
  • Validation Monitoring: 10% validation split for performance tracking

Model Performance Tracking

Each model training returns a history object (h_sgd, h_adam, h_rms) containing:

  • Training accuracy per epoch
  • Training loss per epoch
  • Validation accuracy per epoch
  • Validation loss per epoch
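
The per-epoch metrics live in each object's .history dictionary; with metrics=['accuracy'] and a validation split, Keras produces these keys:

# Inspecting one history object
print(h_sgd.history.keys())
# dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])

final_val_acc = h_sgd.history['val_accuracy'][-1]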

Future Enhancements

  • Add confusion matrix analysis
  • Compare training time across optimizers
  • Experiment with different architectures
  • Add early stopping and learning rate scheduling
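
For the last item, a sketch of how Keras callbacks could be wired in (illustrative monitor and patience values):

from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

callbacks = [
    EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2),
]
model.fit(X_train, y_train, epochs=25, batch_size=64,
          validation_split=0.1, callbacks=callbacks)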

Notes

  • The models use dense (fully connected) layers rather than convolutional layers
  • Input images are flattened from 28×28 to 784-dimensional vectors
  • All models share the same random weight initialization for fair comparison (one way to arrange this is sketched below)
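
One way to pin the initialization, fixing the global seeds before any model is built (a sketch; not necessarily how main.py achieves it):

import numpy as np
import tensorflow as tf

np.random.seed(42)      # seed NumPy (shuffling, splits)
tf.random.set_seed(42)  # seed TensorFlow weight initialization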

Contributing

Feel free to experiment with:

  • Different optimizer parameters (learning rates, momentum)
  • Alternative network architectures
  • Additional regularization techniques
  • Different activation functions
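
For example, swapping the string 'sgd' for a configured optimizer object (illustrative hyperparameter values):

from tensorflow.keras.optimizers import SGD

opt = SGD(learning_rate=0.01, momentum=0.9)
dnnModel.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])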

License

This project is open source and available under the MIT License.
