Fine-Grained Image Classification

Last Updated : 23 Jul, 2025

Traditional image classification sorts images into broad, generic classes (e.g., cats vs. dogs). Fine-grained image classification (FGIC), on the other hand, distinguishes between visually similar subcategories of the same general class, such as different dog breeds or different car models.

Fine-grained image classification is the task of assigning images to subcategories that share very similar visual characteristics.

Examples are:

  • Identifying bird species from images (e.g., pigeon vs. sparrow).
  • Identifying car models (e.g., Tesla Model 3 vs. Tesla Model S).
  • Differentiating plant species for agriculture or conservation.

Techniques for Fine-Grained Image Classification

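Fine-grained classification is hard because subcategories look very similar to each other, while a single subcategory can vary widely in pose, lighting, and background. To address these difficulties, researchers employ various strategies: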

1. Part-Based Models: These models locate and analyze specific parts of the object (e.g., a bird's head, wings, and tail), which helps them detect subtle differences between subcategories.

2. Attention Mechanisms: Attention modules help the model focus on the most discriminative regions of the image, i.e., the regions most useful for telling subcategories apart.
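A minimal sketch of one simple form of spatial attention follows, assuming it replaces the global average pooling at the end of a CNN such as ResNet18. `SpatialAttentionPool` is a hypothetical module name; published FGIC attention models are considerably more elaborate.

```python
import torch
import torch.nn as nn


class SpatialAttentionPool(nn.Module):
    """Minimal spatial attention pooling: score each location of a feature map,
    softmax the scores over all locations, and take the weighted average."""

    def __init__(self, channels: int):
        super().__init__()
        self.score = nn.Conv2d(channels, 1, kernel_size=1)  # one scalar score per location

    def forward(self, fmap: torch.Tensor):
        b, c, h, w = fmap.shape
        weights = torch.softmax(self.score(fmap).reshape(b, h * w), dim=1)  # (B, H*W), rows sum to 1
        pooled = torch.bmm(fmap.reshape(b, c, h * w), weights.unsqueeze(2)).squeeze(2)  # (B, C)
        return pooled, weights.reshape(b, h, w)  # attention map can be visualized


pool = SpatialAttentionPool(channels=512)
fmap = torch.randn(2, 512, 7, 7)  # shape of ResNet18's last feature map for 224x224 inputs
pooled, attn = pool(fmap)
print(pooled.shape, attn.shape)  # torch.Size([2, 512]) torch.Size([2, 7, 7])
```

Because the scoring layer is trained end-to-end with the classifier, it can learn to put more weight on discriminative regions; the returned map can be upsampled and overlaid on the image for inspection.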

3. Metric Learning: Rather than predicting a class label directly, metric learning trains the model to learn an embedding space in which images of the same subcategory lie close together and images of different subcategories lie far apart.
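One common way to learn such a space is a triplet loss, sketched below with PyTorch's `nn.TripletMarginLoss`. The embedding head, its sizes, and the random stand-in features are assumptions for illustration; in practice the features would come from a CNN backbone and the triplets from labeled images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical embedding head: 512-d backbone features -> 128-d embedding
embed = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 128))
triplet_loss = nn.TripletMarginLoss(margin=0.2)  # mean of max(d(a,p) - d(a,n) + margin, 0)
embed_optimizer = torch.optim.Adam(embed.parameters(), lr=1e-3)

# Random stand-in features: each positive is a slightly perturbed copy of its anchor
# (same subcategory); each negative is unrelated (different subcategory)
anchor_feat = torch.randn(16, 512)
positive_feat = anchor_feat + 0.1 * torch.randn(16, 512)
negative_feat = torch.randn(16, 512)

losses = []
for step in range(50):
    a = F.normalize(embed(anchor_feat), dim=1)  # unit-length embeddings
    p = F.normalize(embed(positive_feat), dim=1)
    n = F.normalize(embed(negative_feat), dim=1)
    loss = triplet_loss(a, p, n)
    embed_optimizer.zero_grad()
    loss.backward()
    embed_optimizer.step()
    losses.append(loss.item())
```

Once trained, a query image can be classified by comparing its embedding with those of labeled images (for example, nearest neighbor or nearest class-mean embedding).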

4. Data Augmentation: Advanced augmentation methods such as mixup, CutMix, and pose-based augmentations are used to improve generalization.
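The sketch below shows mixup and CutMix applied to a batch of image tensors. The function names and the default `alpha` values are assumptions chosen for illustration. Both functions return soft (mixed) targets instead of integer labels.

```python
import numpy as np
import torch
import torch.nn.functional as F


def mixup(images, labels, num_classes, alpha=0.2):
    """mixup: blend each image with a shuffled partner; mix one-hot labels with the same weight."""
    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(images.size(0))
    mixed = lam * images + (1 - lam) * images[perm]
    targets = F.one_hot(labels, num_classes).float()
    return mixed, lam * targets + (1 - lam) * targets[perm]


def cutmix(images, labels, num_classes, alpha=1.0):
    """CutMix: paste a random box from a shuffled partner; mix labels by the pasted area."""
    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(images.size(0))
    _, _, h, w = images.shape
    cut_h, cut_w = int(h * np.sqrt(1 - lam)), int(w * np.sqrt(1 - lam))
    cy, cx = np.random.randint(h), np.random.randint(w)  # box center
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    mixed = images.clone()
    mixed[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    lam_adj = 1 - (y2 - y1) * (x2 - x1) / (h * w)  # fraction of each original image kept
    targets = F.one_hot(labels, num_classes).float()
    return mixed, lam_adj * targets + (1 - lam_adj) * targets[perm]


# Usage inside a training loop (probability targets need PyTorch >= 1.10):
# mixed, soft_targets = cutmix(images, labels, num_classes=10)
# loss = F.cross_entropy(model(mixed), soft_targets)
```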

5. Transfer Learning: Models pre-trained on large datasets such as ImageNet (e.g., ResNet, EfficientNet) are fine-tuned on fine-grained datasets, reusing the low-level and high-level features they have already learned.

Datasets for Fine-Grained Image Classification

Commonly used benchmark datasets include:

  • Caltech-UCSD Birds-200 (CUB-200): 200 bird species with over 11,000 images.
  • Stanford Cars: 16,000 images of 196 car models.
  • Oxford Flowers 102: 102 categories of flowers.
  • iNaturalist: Large-scale dataset for species classification.

Implementation

1. Install Required Libraries

Ensure you have the necessary dependencies installed:

pip install torch torchvision matplotlib

2. Load the Dataset

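We'll use the CIFAR-10 dataset to keep the example simple. Note that CIFAR-10 is a general-purpose (coarse-grained) benchmark; the same pipeline applies to fine-grained datasets such as CUB-200 or Stanford Cars.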

  • Images are resized to 224x224, converted to tensors, and normalized with the ImageNet mean and standard deviation.
  • The dataset is loaded with torchvision.datasets.CIFAR10 and wrapped in a DataLoader for batched training.
Python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torchvision import models

# Device config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transforms (resize to 224x224 for ResNet)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Match the 224x224 resolution used to pre-train ResNet on ImageNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# Class names
class_names = train_dataset.classes
print("Classes:", class_names)

Output:

Classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

3. Define and Train the Model

  • A ResNet18 model pre-trained on ImageNet is loaded, and its final fully connected layer is replaced with a new layer that outputs 10 classes.
  • We specify the loss function (CrossEntropyLoss) and optimizer (Adam).
  • The model is trained for 3 epochs; at each step, the optimizer updates the weights to reduce the loss, and the average loss and training accuracy are reported after each epoch.
Python
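# Load ResNet18 with ImageNet weights (the `weights` argument replaces the deprecated `pretrained=True`)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)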

# Replace the final layer for CIFAR-10 (10 classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
def train_model(model, train_loader, criterion, optimizer, num_epochs=3):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100*correct/total:.2f}%")

# Run training
train_model(model, train_loader, criterion, optimizer, num_epochs=3)

Output:

Epoch [1/3] - Loss: 0.7729, Accuracy: 74.02%

Epoch [2/3] - Loss: 0.6525, Accuracy: 77.59%

Epoch [3/3] - Loss: 0.6322, Accuracy: 78.23%

4. Evaluate the Model

  • Once trained, we set the model to evaluation mode, so batch normalization layers use their running statistics instead of per-batch statistics.
  • We evaluate the model on unseen test images and compute the accuracy by comparing predicted and actual labels.
Python
# Evaluation
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

# Run evaluation
evaluate_model(model, test_loader)

Output:

Test Accuracy: 80.53%
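For fine-grained tasks, overall accuracy hides which subcategories are confused with each other. The sketch below collects predictions, builds a confusion matrix with `torch.bincount`, and lists the most frequently confused class pairs. The helper names are hypothetical. The commented usage reuses `model`, `test_loader`, `device`, and `train_dataset` from the steps above.

```python
import torch


@torch.no_grad()
def collect_predictions(model, loader, device):
    """Return (labels, predictions) for every sample in the loader."""
    model.eval()
    all_labels, all_preds = [], []
    for images, labels in loader:
        outputs = model(images.to(device))
        all_preds.append(outputs.argmax(dim=1).cpu())
        all_labels.append(labels)
    return torch.cat(all_labels), torch.cat(all_preds)


def confusion_matrix(labels, preds, num_classes):
    """Rows = actual class, columns = predicted class."""
    idx = labels * num_classes + preds
    return torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def per_class_accuracy(cm):
    return cm.diag().float() / cm.sum(dim=1).clamp(min=1).float()


def most_confused(cm, class_names, k=3):
    """Top-k (actual, predicted, count) off-diagonal entries."""
    off = cm.clone()
    off.fill_diagonal_(0)
    values, flat = off.reshape(-1).topk(k)
    n = cm.size(0)
    return [(class_names[i // n], class_names[i % n], v) for v, i in zip(values.tolist(), flat.tolist())]


# Usage with the model trained above:
# labels, preds = collect_predictions(model, test_loader, device)
# cm = confusion_matrix(labels, preds, num_classes=len(train_dataset.classes))
# print(most_confused(cm, train_dataset.classes))
```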

5. Visualize Predictions

  • We take a batch of images from the test set.
  • We display nine of them with matplotlib, each titled with the model's predicted and actual labels.
Python
# Function to visualize some predictions
def imshow(img):
    img = img.numpy().transpose((1, 2, 0))
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])  # Unnormalize
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    plt.axis("off")

# Display predictions for the first 9 images of a test batch
images, labels = next(iter(test_loader))
images, labels = images.to(device), labels.to(device)

model.eval()
with torch.no_grad():  # no gradients needed for inference
    outputs = model(images)
_, preds = torch.max(outputs, 1)

fig = plt.figure(figsize=(10, 10))
for i in range(9):
    ax = fig.add_subplot(3, 3, i+1)
    imshow(images[i].cpu())
    ax.set_title(f"Pred: {class_names[preds[i].item()]}\nActual: {class_names[labels[i].item()]}")
plt.tight_layout()
plt.show()

Output:

[Figure: Predictions]
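Because fine-grained subcategories are easy to confuse, it can help to inspect more than the single top prediction. The sketch below returns the top-k classes and their softmax probabilities for one image. `top_k_predictions` is a hypothetical helper name, and the image must be on the same device as the model.

```python
import torch


@torch.no_grad()
def top_k_predictions(model, image, class_names, k=3):
    """Return [(class_name, probability), ...] for the k most likely classes."""
    model.eval()
    probs = torch.softmax(model(image.unsqueeze(0)), dim=1)[0]  # add batch dim, then drop it
    values, indices = probs.topk(k)
    return [(class_names[i], p) for p, i in zip(values.tolist(), indices.tolist())]


# Usage with the batch displayed above:
# print(top_k_predictions(model, images[0], train_dataset.classes, k=3))
```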


Applications

Fine-grained image classification has real-world applications across many industries:

1. Wildlife Conservation

  • Helps in species identification for ecological studies.
  • Used in endangered species monitoring via camera traps.
  • AI models classify bird species, insects, or marine animals from images.

2. Medical Diagnosis

  • Distinguishes tumor subtypes in histopathology images.
  • Identifies different stages of diabetic retinopathy.
  • Differentiates between skin lesion types for early cancer detection.

3. E-commerce and Fashion

  • Recognizes clothing attributes (e.g., dress patterns, sleeve types).
  • Helps in product recommendation systems based on visual similarity.
  • Distinguishes counterfeit from genuine branded items (e.g., Nike, Adidas).

4. Autonomous Vehicles

  • Recognizes different car models and makes.
  • Helps detect traffic signs with fine-grained details.
  • Improves pedestrian recognition based on clothing attributes.

Challenges in Fine-Grained Classification

Fine-grained classification is challenging because:

  1. Small inter-class variance: Subcategories have very similar features.
  2. Large intra-class variance: The same subcategory can appear in different conditions (e.g., lighting, pose).
  3. Data annotation difficulty: Requires expert knowledge (e.g., botanists for plant classification).
  4. Background distractions: Non-relevant background elements can confuse models.