生成对抗网络(GAN)
了解 GAN 如何通过生成逼真图像、增强数据以及推动医疗保健、游戏等领域的创新,彻底改变人工智能。
生成对抗网络(GAN)是一类功能强大的机器学习(ML)框架,由Ian Goodfellow 及其同事于 2014 年首次提出。它们属于生成式人工智能领域,侧重于创建与给定训练数据集相似的新数据。GANs 背后的核心理念涉及两个神经网络(NN)--生成器和判别器--之间的竞争博弈。这种对抗过程促使系统产生高度逼真的合成输出,如图像、音乐或文本。
GANS 如何工作
GAN 架构由两个同时进行训练的主要组件组成:
- 生成器:该网络将随机噪音(通常是从高斯分布中采样的随机数向量)作为输入,并尝试将其转换为模拟真实数据分布的数据。例如,它可能会生成与训练数据集中的图像相似的猫的合成图像。它的目标是生成与真实数据无异的输出,从而有效地欺骗判别器。
- 判别器:该网络充当二元分类器。它同时接收真实数据样本(来自实际数据集)和虚假数据样本(由生成器创建)。它的任务是确定每个输入样本是真实的还是虚假的。它通过标准的监督学习技术来学习这一点,目的是对真实样本和生成样本进行正确分类。
对抗式培训过程
GAN 的训练是一个动态的过程,在这个过程中,生成器和判别器相互竞争,共同进步:
- 生成器生成一批合成数据。
- 判别器在包含真实数据和生成器合成数据的批次上进行训练,学习如何区分它们。反向传播法用于根据分类准确率更新权重。
- 然后根据判别器的输出对生成器进行训练。它的目标是生成被判别器错误地归类为真实的数据。梯度流回(暂时固定的)判别器,以更新生成器的权重。
如此循环往复,最终达到理想的平衡状态,即生成器生成的数据非常逼真,鉴别器只能随机猜测(准确率为 50%)样本的真假。此时,"生成器 "已学会近似训练集的基本数据分布。
GANS 与其他模式的比较
必须将 GAN 与其他类型的模型区分开来:
- 判别模型:大多数标准分类和回归模型(如用于图像分类或标准物体检测的模型)都是判别模型。它们根据输入特征学习决策边界来区分不同类别或预测值。相比之下,GAN 是生成型模型--它们学习数据本身的底层概率分布来创建新样本。
- 扩散模型 扩散模型是另一种功能强大的生成模型,近来备受瞩目,在图像生成方面往往能达到最先进的效果。它们的工作原理是逐渐向数据中添加噪声,然后学习逆转这一过程。与 GAN 相比,扩散模型有时能生成保真度更高的图像,并能提供更稳定的训练,但在推理过程中计算量更大。
挑战与进步
由于以下问题,训练 GANs 的难度可想而知:
为了应对这些挑战,研究人员开发了许多 GAN 变体,例如提高稳定性的 Wasserstein GAN(WGAN)和允许根据特定属性生成数据(例如生成特定数字的图像)的条件 GAN(cGAN)。PyTorch和TensorFlow等框架提供了各种工具和库,为GANs 的实施和训练提供了便利。