案例带你学Pytorch系列(3)——GAN生成对抗网络
2019-05-16 12:43
471 查看
典型案例详解——生成对抗网络
import os import torch import torchvision import torch.nn as nn from torchvision import transforms from torchvision.utils import save_image device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') latent_size = 64 hidden_size = 256 image_size = 784 num_epochs = 2 batch_size = 100 sample_dir = 'samples' if not os.path.exists(sample_dir): os.makedirs(sample_dir) """ 建立文件夹用于存放生成的样本 exists(path):测试是否存在某路径,如不存在则建立 mkdirs(name) """ transform = transforms.Compose([ transforms.ToTensor(), # transforms.Normalize(mean=(0.5,0.5,0.5), # std=(0.5,0.5,0.5) ) ]) mnist = torchvision.datasets.MNIST(root='../../data/', train=True, transform=transform, download=False) data_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True) D = nn.Sequential( nn.Linear(image_size, hidden_size), nn.LeakyReLU(0.2), nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(0.2), nn.Linear(hidden_size, 1), nn.Sigmoid()) ''' 1、定义判别器(Discriminator) 2、LeakyReLU(negative_slope=1e-2, inplace=False):关键参数只有一个斜率,默认值1e-2 3、分别构造了3个线性层实例 ''' G = nn.Sequential( nn.Linear(latent_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, image_size), nn.Tanh()) ''' 1、定义生成(Generator) ''' D = D.to(device) G = G.to(device) criterion = nn.BCELoss() d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002) g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002) def denorm(x): out = (x+1)/2 return out.clamp(0, 1) ''' out.clamp(0,1) clamp:取(0,1)范围内的数值,即“掐头去尾” ''' # 需要在反向传播之前清零梯度 def reset_grad(): d_optimizer.zero_grad() g_optimizer.zero_grad() total_step = len(data_loader) for epoch in range(num_epochs): for i, (images, _) in enumerate(data_loader): images = images.reshape(batch_size, -1).to(device) # images.shape(100,784) real_labels = torch.ones(batch_size, 1).to( device) # 定义真标签,全1,shape(100,1) fake_labels = torch.zeros(batch_size, 1).to( device) # 定义假标签,全0,shape(100,1) # 训练判别器 # model(D/G) 模块对象重载了__call__运算符,可以像函数一样直接调用 # 前向传播,计算输出 outputs = D(images) d_loss_real = criterion(outputs, real_labels) # 计算判别器真图输出和真标签的损失 real_score = outputs z = torch.randn(batch_size, latent_size).to( device) # 设置随机种子z,shape(100,64) fake_images = G(z) # 随机种子喂入生成器,形成假图 outputs = D(fake_images) # 利用判别器,计算假图输出 d_loss_fake = criterion(outputs, fake_labels) # 计算判别器假图输出和假标签损失 fake_score = outputs d_loss = d_loss_fake+d_loss_real # 计算损失和 reset_grad() # 梯度归零,在反向传播之前,使用optimizer将它要更新的所有张量的梯度清零(这些张量是模型可学习的权重) d_loss.backward() # 梯度反向传播 d_optimizer.step() # 单步优化,调用optimizer的step函数更新所有参数 ''' 记住2+3: 2:前向传播、计算损失 3:清零梯度、反向传播、更新权重 optimizer.zero_grad() loss.backward() optimizer.step() ''' # 训练生成器 z = torch.randn(batch_size, latent_size).to(device) # 生成随机种子 fake_images = G(z) # 随机种子喂入生成器 outputs = D(fake_images) # 判别器判别假图输出 g_loss = criterion(outputs, real_labels) # 判别器假图输出和真标签损失 reset_grad() # 梯度归零 g_loss.backward() # 反向传播 g_optimizer.step() # 优化 if (i+1) % 200 == 0: print('Epoch [{}/{}],step[{}/{}],d_loss:{:.4f},g_loss:{:.4f},D(x):{:.2f},D(G(z)):{:.2f}' .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), real_score.mean(), fake_score.mean())) # 每200个bacth,输出一次 if (epoch+1) == 1: images = images.reshape(images.size(0), 1, 28, 28) save_image(denorm(images), os.path.join(sample_dir, 'real_images.png')) fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28) save_image(denorm(fake_images), os.path.join( sample_dir, 'fake_images-{}.png'.format(epoch+1))) torch.save(G.state_dict(), 'G1.ckpt') torch.save(D.state_dict(), 'D1.ckpt') ''' 1、状态字典(state_dict): 模型的可学习参数(即权重w和偏差b)包含在模型的 _parameters_ 中,(使用model.parameters())。 具有可学习参数的层(如卷积层、线性层等)的模型才具有 _state_dict_ 属性。 优化目标 `torch.optim` 也有 _state_dict_ 属性,它包含有关优化器的状态信息,以及使用的超参数。 2、存储和加载推断模型: torch.save(model.state_dict(), PATH) 加载: model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) load_state_dict()函数只接受字典对象,例如无法通过 model.load_state_dict(PATH)来加载模型。 model.eval(),用于设置 dropout 和 batch normalization 层为评估模式。 也可以保存完整模型或保存checkpoint,详情参阅: https://pytorch.org/tutorials/beginner/saving_loading_models.html '''
运行结果:
epoch 1:
epoch100:
运行到100次数字已接近可辨认。
案例来源:
https://github.com/yunjey/pytorch-tutorial.git
关注用案例学Pytorch:
https://github.com/houhuipeng/EasyPytorch-ByExample.git
作者:追逐一生
来源:CSDN
原文:https://blog.csdn.net/houhuipeng/article/details/90043548
版权声明:本文为博主原创文章,转载请附上博文链接!
相关文章推荐
- 生成对抗网络(GAN)的理论与应用完整入门介绍
- 生成对抗网络GAN
- 生成对抗网络GAN的数学公式的前因后果
- ICCV2017 | 一文详解GAN之父Ian Goodfellow 演讲《生成对抗网络的原理与应用》(附完整PPT)
- 生成对抗网络(GAN)初探
- 火热的生成对抗网络(GAN),你究竟好在哪里
- 生成对抗网络(GAN)相比传统训练方法有什么优势?(一)
- [TensorFlow]生成对抗网络(GAN)介绍与实践
- 生成对抗网络(GAN)的一些知识整理(课件)
- 洞见 | 生成对抗网络GAN最近在NLP领域有哪些应用
- 【深度学习理论】通俗理解生成对抗网络GAN
- (转)机器学习系列直播--使用对抗神经网络(GANs)生成猫
- 浅谈GAN生成对抗网络
- <视频教程-2>生成对抗网络GAN视频教程part6-完整版
- Gan 生成对抗网络
- <模型汇总_5>生成对抗网络GAN及其变体SGAN_WGAN_CGAN_DCGAN_InfoGAN_StackGAN
- GAN相关(二):DCGAN / 深度卷积对抗生成网络
- 生成对抗网络(GAN,Generative Adversarial Networks) 学习笔记
- [生成对抗网络] GAN
- 生成对抗网络学习笔记4----GAN(Generative Adversarial Nets)的实现