GAN生成对抗网络合集(六):GAN-cls –具有匹配感知的判别器(附代码)

本文介绍了GAN-cls技术,它通过改进判别器,使其不仅能判断图片真伪,还能判断匹配真伪。与ACGAN相比,GAN-cls通过将图像与标签结合输入判别器,简化了生成匹配标签样本的过程,无需改动生成器。文中提供了代码示例展示如何从LSGAN转换为GAN-cls,展示了训练和损失函数的修改。

1 GAN-cls原理

       这是一种GAN网络增强技术----具有匹配感知的判别器。前面讲过,在InfoGAN中,使用了ACGAN的方式进行指导模拟数据与生成数据的对应关系(分类)。在GAN-cls中该效果会以更简单的方式来实现,即增强判别器的功能,令其不仅能判断图片真伪,还能判断匹配真伪

(个人理解)没啥实质性改变,时间并未缩短,技术也没有怎么简化甚至变得复杂了。就是思想上的一个转变,原本ACGan是模拟样本+正确分类信息输入进去/真实样本+正确分类信息输入进D去。现在的GAN-cls变为输入真实样本和真实标签、虚拟样本和真实标签、虚拟标签和真实样本的三种组合形式(无对应图片的随机标签

       GAN-cls的具体做法是,在原有的GAN网络上,将判别器的输入变为图片与对应标签的连接数据。这样判别器的输入特征中就会有生成图像的特征与对应标签的特征。然后用这样的判别器分别对真实标签与真实图片、假标签与真实图片、真实标签与假图片进行判断,预期的结果依次为真、假、假,在训练的过程中沿着这个方向收敛即可。而对于生成器,则不需要做任何改动。这样简单的一步就完成了生成根据标签匹配的模拟数据功能。

在这里插入图片描述

2 代码

直接修改上一篇 GAN生成对抗网络合集(五):LSGan-最小二乘GAN(附代码) 代码,将其改成GAN-cls。

  1. 修改判别器D
    将判别器的输入改成x与y,新增加的y代表输入的样本标签(真、假);在内部处理中,先通过全连接网络将y变为与图片一样维度的映射,并调整为图片相同的形状,使用concat将二者连接到一起统一处理。后续的处理过程是一样的,两个卷积后再接两个全连接,最后一层输出disc。该部分代码如下:
# def discriminator(x, num_classes=10, num_cont=2):
def discriminator(x, y):  # 判别器函数 : x两次卷积,再接两次全连接; y代表输入的样本标签
    reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
    # print (reuse)
    # print (x.get_shape())
    with tf.variable_scope('discriminator', reuse=reuse):

        y = slim.fully_connected(y, num_outputs=n_input, activation_fn=leaky_relu)  # 将y变为与图片一样维度的映射
        y = tf.reshape(y, shape=[-1, 28, 28, 1])    # 将y统一成图片格式

        x = tf.reshape(x, shape=[-1, 28, 28, 1])

        # 将二者连接到一起,统一处理
        x = tf.concat(axis=3, values=[x, y])  # x.shape = [-1, 28, 28, 2]

        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        # print ("conv2d",x.get_shape())
        x = slim.flatten(x)  # 输入扁平化
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=leaky_relu)
        # recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn=leaky_relu)

        # 生成的数据可以分别连接不同的输出层产生不同的结果
        # 1维的输出层产生判别结果1或是0
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
        disc = tf.squeeze(disc, -1)
        # print ("disc",disc.get_shape()) # 0 or 1

        # 10维的输出层产生分类结果 (样本标签)
        # recog_cat = slim.fully_connected(recog_shared, num_outputs=num_classes, activation_fn=None)

        # 2维输出层产生重构造的隐含维度信息
        # recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid)
    return disc  # recog_cat, recog_cont
  1. 添加错误标签输入符,构建网络结构
    添加错误标签misy,同时在判别器中分别将真实样本与真实标签、生成的图像gen与真实标签、真实样本与错误标签组成的输入传入判别器中。去掉隐含信息z_con部分。

注:这里是将3种输入的x与y分别按照batch_size维度连接变为判别器的一个输入的。生成结果后再使用split函数将其裁成3个结果disc_real、disc_fake和disc_mis,分别代表真实样本与真实标签、生成的图像gen与真实标签、真实样本与错误标签所对应的判别值。这么写会使代码看上去简洁一些,当然也可以一个一个地输入x、y,然后调用三次判别器,效果是一样的。

##################################################################
#  3.定义网络模型 : 定义 参数/输入/输出/中间过程(经过G/D)的输入输出
##################################################################
batch_size = 10  # 获取样本的批次大小32
classes_dim = 10  # 10 classes
con_dim = 2  # 隐含信息变量的维度, 应节点为z_con
rand_dim = 38  # 一般噪声的维度, 应节点为z_rand, 二者都是符合标准高斯分布的随机数。
n_input = 784  # 28 * 28

x = tf.placeholder(tf.float32, [None, n_input])  # x为输入真实图片images
y = tf.placeholder(tf.int32, [None])  # y为真实标签labels
misy = tf.placeholder(tf.int32, [None])  # 错误标签

# z_con = tf.random_normal((batch_size, con_dim))  # 2列
z_rand = tf.random_normal((batch_size, rand_dim))  # 38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_rand])  # 50列 shape = (10, 50)
gen = generator(z)  # shape = (10, 28, 28, 1)
genout = tf.squeeze(gen, -1)  # shape = (10, 28, 28)

# labels for discriminator
# y_real = tf.ones(batch_size)  # 真
# y_fake = tf.zeros(batch_size)  # 假

# 判别器D
xin = tf.concat(
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值