📖 引言
在前面几篇里,我们从卷积的基本原理讲到 KSC 全像素预测,逐步掌握了 CNN 的使用方法。很多同学会问:
👉 “为什么有的论文用 LeNet,有的用 VGG,还有的用 ResNet?”
👉 “这些结构到底区别在哪里?”
本篇就带大家认识 CNN 的经典发展路线,并通过实战在 KSC 数据上比较它们的表现。
🧩 CNN 架构发展脉络
- LeNet(1998)
  - 最早的 CNN,用于手写数字识别。
  - 特点:卷积 + 池化交替,最后接全连接。结构浅,但奠定了基础。
- VGG(2014)
  - 提出“小卷积核堆叠”的思想,用多层 3×3 卷积堆叠代替大卷积核。
  - 特点:层数深,参数多,效果好,但计算开销大。
- ResNet(2015)
  - 引入残差连接(skip connection),缓解了深层 CNN 的梯度消失与网络退化问题。
  - 特点:可训练更深的网络,性能优异。
⚙️ 实验设置
- 数据:KSC(PCA 降维至 30 维)
- 任务:像素级分类(训练:有标签像素;预测:测试像素)
- 评估:OA / AA / Kappa + 混淆矩阵
- 对比网络:LeNet / VGG-like / ResNet-like
💻 完整代码(将 DATA_DIR 修改为本地数据路径后可直接运行)
# -*- coding: utf-8 -*-
"""
案例③-4:经典 CNN 架构篇(LeNet / VGG / ResNet)
"""
import os, numpy as np, scipy.io as sio
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, cohen_kappa_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt, matplotlib
# Use the SimHei font so the Chinese axis label and title of the final plot render as text.
matplotlib.rcParams['font.family'] = 'SimHei'
# Draw minus signs with the ASCII hyphen instead of the Unicode minus glyph.
matplotlib.rcParams['axes.unicode_minus'] = False
# ===== Parameters =====
DATA_DIR = r"your_path"  # directory holding KSC.mat and KSC_gt.mat; change to your path
PCA_DIM = 30             # number of principal components kept per pixel
TRAIN_RATIO = 0.3        # fraction of labelled pixels used for training
EPOCHS = 15              # training epochs per model
BATCH = 256              # mini-batch size for both loaders
LR = 1e-3                # Adam learning rate
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42                # shared seed for the data split, PCA and PyTorch
# Seed PyTorch as well: without this, weight initialisation and DataLoader
# shuffling change on every run even though the split and PCA are seeded.
torch.manual_seed(SEED)
# ===== 1. Data loading =====
# Hyperspectral cube (unpacked below as height x width x bands) and its label map.
X = sio.loadmat(os.path.join(DATA_DIR, "KSC.mat"))["KSC"].astype(np.float32)
Y = sio.loadmat(os.path.join(DATA_DIR, "KSC_gt.mat"))["KSC_gt"].astype(int)
H, W, B = X.shape
# Cast to a plain Python int: this value sizes the output layer of every model,
# and Y.max() on its own is a NumPy scalar.
num_classes = int(Y.max())

# Only pixels with a non-zero label take part; labels are shifted to start at 0.
coords = np.argwhere(Y != 0)
labels = Y[coords[:, 0], coords[:, 1]] - 1

# Stratified split over the labelled pixels (indices into `coords`).
train_ids, test_ids = train_test_split(
    np.arange(len(coords)), train_size=TRAIN_RATIO,
    stratify=labels, random_state=SEED
)

# Fit the scaler and PCA on training pixels only, so no statistics from the
# test pixels leak into preprocessing; then transform the whole image.
train_pixels_raw = X[coords[train_ids, 0], coords[train_ids, 1]]
scaler = StandardScaler().fit(train_pixels_raw)
pca = PCA(n_components=PCA_DIM, random_state=SEED).fit(scaler.transform(train_pixels_raw))
X_flat = X.reshape(-1, B)
X_flat_std = scaler.transform(X_flat)
X_flat_pca = pca.transform(X_flat_std)
X_pca = X_flat_pca.reshape(H, W, PCA_DIM)

# Per-pixel feature vectors (length PCA_DIM) and labels for each subset.
X_train = X_pca[coords[train_ids, 0], coords[train_ids, 1]]
y_train = labels[train_ids]
X_test = X_pca[coords[test_ids, 0], coords[test_ids, 1]]
y_test = labels[test_ids]
class HsiDataset(Dataset):
    """In-memory dataset pairing per-sample feature vectors with integer labels.

    Both NumPy arrays are converted to tensors once, at construction time.
    A singleton axis is inserted at position 1 of the features so each sample
    can be fed to ``nn.Conv1d`` as a one-channel sequence.
    """

    def __init__(self, X, y):
        features = torch.from_numpy(X).float()
        targets = torch.from_numpy(y)
        # For an [N, P] input this yields [N, 1, P].
        self.X = features[:, None]
        self.y = targets.long()

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        sample = self.X[i]
        target = self.y[i]
        return sample, target
# Training batches are reshuffled each epoch; test batches keep a fixed order.
train_set = HsiDataset(X_train, y_train)
test_set = HsiDataset(X_test, y_test)
train_loader = DataLoader(train_set, batch_size=BATCH, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH, shuffle=False)
# ===== 2. Three classic architectures =====
class LeNet1D(nn.Module):
    """LeNet-style 1-D CNN: two conv + average-pool stages, then three linear layers.

    Args:
        num_classes: number of output logits.
        input_len: length of each input sequence, i.e. the last axis of a
            ``[batch, 1, input_len]`` tensor. When omitted, the module-level
            ``PCA_DIM`` is used, which matches the original behaviour.

    Raises:
        ValueError: if ``input_len`` is too short to survive the two
            pooling stages (fewer than 4 samples).
    """

    def __init__(self, num_classes, input_len=None):
        super().__init__()
        if input_len is None:
            input_len = PCA_DIM
        # The padded convs keep the length; each AvgPool1d(2) halves it (floor).
        pooled_len = input_len // 4
        if pooled_len < 1:
            raise ValueError(f"input_len must be at least 4, got {input_len}")
        self.conv = nn.Sequential(
            nn.Conv1d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
            nn.AvgPool1d(2),
            nn.Conv1d(6, 16, kernel_size=5, padding=2), nn.ReLU(),
            nn.AvgPool1d(2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * pooled_len, 120), nn.ReLU(),
            nn.Linear(120, 84), nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        """Map a ``[batch, 1, input_len]`` tensor to ``[batch, num_classes]`` logits."""
        return self.fc(self.conv(x))
class VGG1D(nn.Module):
    """VGG-style 1-D CNN: two stages of stacked 3-wide convs, each followed by max-pooling.

    Args:
        num_classes: number of output logits.
        input_len: length of each input sequence, i.e. the last axis of a
            ``[batch, 1, input_len]`` tensor. When omitted, the module-level
            ``PCA_DIM`` is used, which matches the original behaviour.

    Raises:
        ValueError: if ``input_len`` is too short to survive the two
            pooling stages (fewer than 4 samples).
    """

    def __init__(self, num_classes, input_len=None):
        super().__init__()
        if input_len is None:
            input_len = PCA_DIM
        # The padded convs keep the length; each MaxPool1d(2) halves it (floor).
        pooled_len = input_len // 4
        if pooled_len < 1:
            raise ValueError(f"input_len must be at least 4, got {input_len}")
        self.features = nn.Sequential(
            nn.Conv1d(1, 32, 3, padding=1), nn.ReLU(),
            nn.Conv1d(32, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv1d(64, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool1d(2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * pooled_len, 128), nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map a ``[batch, 1, input_len]`` tensor to ``[batch, num_classes]`` logits."""
        return self.fc(self.features(x))
class ResBlock1D(nn.Module):
    """Residual block: conv -> ReLU -> conv, added to the input, then ReLU.

    Both convolutions map ``ch`` channels to ``ch`` channels with kernel
    size 3 and padding 1, so the output has the same shape as the input.
    """

    def __init__(self, ch):
        super().__init__()
        layers = [
            nn.Conv1d(ch, ch, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(ch, ch, 3, padding=1),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv(x)
        # Skip connection: add the untouched input before the final activation.
        return torch.relu(residual + x)
class ResNet1D(nn.Module):
    """Small ResNet-style 1-D CNN: stem conv, one residual block, adaptive pooling, MLP head.

    The adaptive max-pool always produces 8 positions per channel, so the
    linear head sees ``32 * 8`` features whatever the input length.
    """

    def __init__(self, num_classes):
        super().__init__()
        width, pooled_len = 32, 8
        self.conv1 = nn.Conv1d(1, width, 3, padding=1)
        self.block1 = ResBlock1D(width)
        self.pool = nn.AdaptiveMaxPool1d(pooled_len)
        head = [
            nn.Flatten(),
            nn.Linear(width * pooled_len, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        ]
        self.fc = nn.Sequential(*head)

    def forward(self, x):
        stem = torch.relu(self.conv1(x))
        pooled = self.pool(self.block1(stem))
        return self.fc(pooled)
# One freshly initialised instance of each architecture, moved to DEVICE.
# Insertion order (LeNet, VGG, ResNet) is also the training and plotting order.
models = {
    label: arch(num_classes).to(DEVICE)
    for label, arch in (("LeNet", LeNet1D), ("VGG", VGG1D), ("ResNet", ResNet1D))
}
# ===== 3. Training & evaluation =====
criterion = nn.CrossEntropyLoss()


def run_epoch(model, loader, train=True, optimizer=None):
    """Run one full pass over ``loader`` and collect loss, accuracy and predictions.

    Args:
        model: network to train or evaluate; switched to train/eval mode
            according to ``train``.
        loader: iterable of ``(inputs, targets)`` batches.
        train: when True, back-propagate and step ``optimizer`` per batch;
            when False, only run the forward pass, with autograd disabled.
        optimizer: optimizer to step; required when ``train`` is True.

    Returns:
        Tuple ``(mean_loss, accuracy, y_true, y_pred)``: the sample-weighted
        mean loss, the fraction of correct predictions, and two 1-D NumPy
        arrays with the targets and predicted classes in iteration order.

    Raises:
        ValueError: if ``train`` is True without an optimizer, or if
            ``loader`` yields no samples.
    """
    if train and optimizer is None:
        raise ValueError("run_epoch(train=True) requires an optimizer")
    model.train(train)
    # Send batches to wherever the model's weights live; fall back to the
    # global DEVICE only for a model that has no parameters at all.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    total, correct, loss_sum = 0, 0, 0.0
    true_chunks, pred_chunks = [], []
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        # Skip building autograd graphs during evaluation.
        with torch.set_grad_enabled(train):
            out = model(xb)
            loss = criterion(out, yb)
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        pred = out.argmax(1)
        batch_size = yb.size(0)
        correct += (pred == yb).sum().item()
        total += batch_size
        # criterion returns the batch mean, so weight it by the batch size.
        loss_sum += loss.item() * batch_size
        true_chunks.append(yb.detach().cpu().numpy())
        pred_chunks.append(pred.detach().cpu().numpy())

    if total == 0:
        raise ValueError("run_epoch received a loader with no samples")
    return (
        loss_sum / total,
        correct / total,
        np.concatenate(true_chunks),
        np.concatenate(pred_chunks),
    )
# Train each architecture in turn and record its final-epoch test metrics.
# results maps model name -> (overall accuracy, Cohen's kappa).
results = {}
for name, model in models.items():
    optimizer = optim.Adam(model.parameters(), lr=LR)
    best_acc, best_ep = 0.0, 0
    for ep in range(1, EPOCHS + 1):
        _, trA, _, _ = run_epoch(model, train_loader, True, optimizer)
        _, teA, y_t, y_p = run_epoch(model, test_loader, False)
        # Tracked for reporting only; the metrics below use the final epoch.
        if teA > best_acc:
            best_acc, best_ep = teA, ep
        if ep % 5 == 0:
            print(f"{name} Epoch{ep}: TrainAcc {trA*100:.2f}% | TestAcc {teA*100:.2f}%")

    # y_t / y_p hold the test targets and predictions from the last epoch.
    cm = confusion_matrix(y_t, y_p)
    oa = accuracy_score(y_t, y_p)
    kappa = cohen_kappa_score(y_t, y_p)
    # AA = mean per-class recall; ignore rows for classes with no true samples.
    support = cm.sum(axis=1)
    present = support > 0
    aa = float((cm.diagonal()[present] / support[present]).mean())
    results[name] = (oa, kappa)

    print(
        f"\n{name} final: OA {oa*100:.2f}% | AA {aa*100:.2f}% | Kappa {kappa:.4f} "
        f"| best TestAcc {best_acc*100:.2f}% (epoch {best_ep})"
    )
    print(f"{name} confusion matrix:\n{cm}")
    print(f"\n{name} 分类报告:\n",classification_report(y_t,y_p,digits=4))
# ===== 4. Visual comparison =====
# Grouped bar chart: per architecture, OA on the left and Kappa (x100) on the right.
names = list(results)
oa_vals = [metrics[0] * 100 for metrics in results.values()]
kappa_vals = [metrics[1] * 100 for metrics in results.values()]
x = np.arange(len(names))

bar_width = 0.3
plt.figure(figsize=(7, 4.5))
for shift, heights, label in ((-0.15, oa_vals, "OA"), (0.15, kappa_vals, "Kappa*100")):
    plt.bar(x + shift, heights, bar_width, label=label)
plt.xticks(x, names)
plt.ylabel("指标(%)")
plt.title("经典 CNN 架构对比(KSC)")
plt.legend()
plt.tight_layout()
plt.show()
📊 结果解读
- LeNet:结构简单,参数少,训练快;在小样本任务中能快速收敛,但表达力有限。
- VGG-like:层数更深,堆叠的卷积核能提取更丰富的特征,但计算量明显增加。
- ResNet-like:残差结构让网络训练更稳定,往往在测试集准确率和 Kappa 上表现更好。

🔚 总结
- LeNet 是 CNN 的“入门课”,VGG 提升了深度与特征表达,ResNet 则缓解了深层网络的梯度消失问题。
- 在遥感高光谱任务里,ResNet 通常能在复杂数据上表现更优。
- 通过这三种架构的对比,大家能更直观地理解 CNN 的发展思路,为后续更复杂的网络(DenseNet、MobileNet 等)打下基础。
欢迎大家关注下方公众号获取更多内容!
1128

被折叠的 条评论
为什么被折叠?



