1. 从零实现Wasserstein生成对抗网络(WGAN)的完整指南
生成对抗网络(GAN)近年来在图像生成领域取得了显著进展,但传统GAN训练过程中存在模式崩溃、训练不稳定等难题。2017年提出的Wasserstein GAN通过引入Wasserstein距离和一系列创新改进,显著提升了训练稳定性和生成质量。本文将手把手带你实现一个完整的WGAN模型,生成MNIST手写数字"7"。
1.1 WGAN的核心创新
WGAN与标准GAN的关键区别在于:
- 用Critic(评论家)替代Discriminator(判别器),输出实数评分而非概率
- 采用Wasserstein距离(Earth-Mover距离)衡量真实与生成分布差异
- 通过权重裁剪强制满足Lipschitz连续性条件
- 损失函数直接反映生成质量,训练更稳定
理论证明显示,当Critic达到最优时,Wasserstein距离提供了比JS散度更平滑的梯度信号。这意味着即使两个分布没有重叠(这在训练初期很常见),WGAN仍能提供有效的学习信号。
关键理解:Wasserstein距离可以理解为"将一个分布搬移到另一个分布所需的最小工作量",这种几何直觉使其比传统GAN使用的JS散度更具优势。
2. WGAN实现细节解析
2.1 Critic模型架构设计
Critic采用卷积神经网络结构,输入28×28灰度图像,输出一个实数评分。与DCGAN相比有几点关键修改:
def define_critic(in_shape=(28,28,1)):
init = RandomNormal(stddev=0.02)
const = ClipConstraint(0.01) # 权重裁剪约束
model = Sequential()
# 下采样至14x14
model.add(Conv2D(64, (4,4), strides=(2,2), padding='same',
kernel_initializer=init, kernel_constraint=const,
input_shape=in_shape))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
# 下采样至7x7
model.add(Conv2D(64, (4,4), strides=(2,2), padding='same',
kernel_initializer=init, kernel_constraint=const))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
# 线性激活输出层
model.add(Flatten())
model.add(Dense(1)) # 无激活函数
opt = RMSprop(lr=0.00005)
model.compile(loss=wasserstein_loss, optimizer=opt)
return model
关键实现细节: </


1614

被折叠的 条评论
为什么被折叠?



