文章目录
基本概念
1、变分自编码器属于无监督学习
2、变分自编码器的主要作用是可以生成数据
3、VAE的网络结构:

Tensorflow实现
VAE实现 MNIST 手写数字识别
1、库导入:
import os
import tensorflow as tf
from tensorflow import keras
from PIL import Image
from matplotlib import pyplot as plt
from tensorflow.keras import Sequential, layers
import numpy as np
2、数据集加载:
# 数据集加载,自编码器不需要标签因为是无监督学习
(x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)
3、模型搭建:
3.1 网络模块
# 编码网络
self.vae_encoder = layers.Dense(self.units)
# 均值网络
self.vae_mean = layers.Dense(self.z_dim) # get mean prediction
# 方差网络(均值和方差是一一对应的,所以维度相同)
self.vae_variance = layers.Dense(self.z_dim) # get variance prediction
# 解码网络
self.vae_decoder = layers.Dense(self.units)
# 输出网络
self.vae_out = layers.Dense(784)
3.2 encoder传播
def encoder(self, x):
h = tf.nn.relu(self.vae_encoder(x))
#计算均值
mu = self.vae_mean(h)
#计算方差
log_var = self.vae_variance(h)
return mu, log_var
3.3 decoder传播
def decoder(self, z):
out = tf.nn.relu(self.vae_decoder(z))
out = self.vae_out(out)
return out
3.4 参数重设定
def reparameterize(self, mu, log_var):
eps = tf.random.normal(log_var.shape)
std = tf.exp(log_var) # 去掉log, 得到方差;
std = std**0.5 # 开根号,得到标准差;
z = mu + std * eps
return z
3.5 主网络结构
def call(self, inputs):
mu, log_var = self.encoder(inputs)
# reparameterizaion trick:最核心的部分
z = self.reparameterize(mu, log_var)
# decoder 进行还原
x_hat = self.decoder(z)
# Variational auto-encoder除了前向传播不同之外,还有一个额外的约束;
# 这个约束使得你的mu, var更接近正太分布;所以我们把mu, log_var返回;
return x_hat, mu, log_var
3.6 模型实例化
model = VAE(z_dim,units=128)
model.build(input_shape=(128, 784))
optimizer = keras.optimizers.Adam(lr=lr)
3.7 loss函数
# 把每个像素点当成一个二分类的问题;
rec_loss = tf.losses.binary_crossentropy(x, x_hat, from_logits=True)
rec_loss = tf.reduce_mean(rec_loss)
3.8 计算KL散度
KL散度公式:
D k l ( N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ) = 1 2 ( − l

本文介绍了变分自编码器的基本概念,强调其无监督学习特性和数据生成能力。通过详细步骤展示了如何使用Tensorflow构建VAE模型,并应用到Fashion-MNIST数据集上,进行手写数字的重构和样本生成。
及Tensorflow实现&spm=1001.2101.3001.5002&articleId=124814476&d=1&t=3&u=47e27af8340b4fd0bd14891c9f3e23fa)
8836

被折叠的 条评论
为什么被折叠?



