您的位置:首页 > 理论基础 > 计算机网络

案例带你学Pytorch系列(3)——GAN生成对抗网络

2019-05-16 12:43 471 查看

典型案例详解——生成对抗网络

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 2
batch_size = 100
sample_dir = 'samples'

if not os.path.exists(sample_dir):
os.makedirs(sample_dir)
"""
建立文件夹用于存放生成的样本
exists(path):测试是否存在某路径,如不存在则建立
mkdirs(name)
"""

transform = transforms.Compose([
transforms.ToTensor(),
# transforms.Normalize(mean=(0.5,0.5,0.5),
#                   std=(0.5,0.5,0.5) )
])

mnist = torchvision.datasets.MNIST(root='../../data/',
train=True,
transform=transform,
download=False)

data_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)

D = nn.Sequential(
nn.Linear(image_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, 1),
nn.Sigmoid())

'''
1、定义判别器(Discriminator)
2、LeakyReLU(negative_slope=1e-2, inplace=False):关键参数只有一个斜率,默认值1e-2
3、分别构造了3个线性层实例
'''

G = nn.Sequential(
nn.Linear(latent_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, image_size),
nn.Tanh())

'''
1、定义生成(Generator)
'''

D = D.to(device)
G = G.to(device)

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

def denorm(x):
out = (x+1)/2
return out.clamp(0, 1)
''' out.clamp(0,1)
clamp:取(0,1)范围内的数值,即“掐头去尾”
'''
# 需要在反向传播之前清零梯度

def reset_grad():
d_optimizer.zero_grad()
g_optimizer.zero_grad()

total_step = len(data_loader)
for epoch in range(num_epochs):
for i, (images, _) in enumerate(data_loader):
images = images.reshape(batch_size, -1).to(device)
# images.shape(100,784)
real_labels = torch.ones(batch_size, 1).to(
device)  # 定义真标签,全1,shape(100,1)
fake_labels = torch.zeros(batch_size, 1).to(
device)  # 定义假标签,全0,shape(100,1)

# 训练判别器
# model(D/G) 模块对象重载了__call__运算符,可以像函数一样直接调用
# 前向传播,计算输出
outputs = D(images)
d_loss_real = criterion(outputs, real_labels)  # 计算判别器真图输出和真标签的损失
real_score = outputs

z = torch.randn(batch_size, latent_size).to(
device)  # 设置随机种子z,shape(100,64)
fake_images = G(z)  # 随机种子喂入生成器,形成假图
outputs = D(fake_images)  # 利用判别器,计算假图输出
d_loss_fake = criterion(outputs, fake_labels)  # 计算判别器假图输出和假标签损失
fake_score = outputs

d_loss = d_loss_fake+d_loss_real  # 计算损失和
reset_grad()  # 梯度归零,在反向传播之前,使用optimizer将它要更新的所有张量的梯度清零(这些张量是模型可学习的权重)
d_loss.backward()  # 梯度反向传播
d_optimizer.step()  # 单步优化,调用optimizer的step函数更新所有参数
'''
记住2+3:
2:前向传播、计算损失
3:清零梯度、反向传播、更新权重
optimizer.zero_grad()
loss.backward()
optimizer.step()
'''

# 训练生成器
z = torch.randn(batch_size, latent_size).to(device)  # 生成随机种子
fake_images = G(z)  # 随机种子喂入生成器
outputs = D(fake_images)  # 判别器判别假图输出
g_loss = criterion(outputs, real_labels)  # 判别器假图输出和真标签损失
reset_grad()  # 梯度归零
g_loss.backward()  # 反向传播
g_optimizer.step()  # 优化

if (i+1) % 200 == 0:
print('Epoch [{}/{}],step[{}/{}],d_loss:{:.4f},g_loss:{:.4f},D(x):{:.2f},D(G(z)):{:.2f}'
.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), real_score.mean(), fake_score.mean()))
# 每200个bacth,输出一次

if (epoch+1) == 1:
images = images.reshape(images.size(0), 1, 28, 28)
save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
save_image(denorm(fake_images), os.path.join(
sample_dir, 'fake_images-{}.png'.format(epoch+1)))

torch.save(G.state_dict(), 'G1.ckpt')
torch.save(D.state_dict(), 'D1.ckpt')

'''
1、状态字典(state_dict):
模型的可学习参数(即权重w和偏差b)包含在模型的 _parameters_ 中,(使用model.parameters())。
具有可学习参数的层(如卷积层、线性层等)的模型才具有 _state_dict_ 属性。
优化目标 `torch.optim` 也有 _state_dict_ 属性,它包含有关优化器的状态信息,以及使用的超参数。
2、存储和加载推断模型:
torch.save(model.state_dict(), PATH)
加载:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH)) load_state_dict()函数只接受字典对象,例如无法通过 model.load_state_dict(PATH)来加载模型。
model.eval(),用于设置 dropout 和 batch normalization 层为评估模式。

也可以保存完整模型或保存checkpoint,详情参阅:
https://pytorch.org/tutorials/beginner/saving_loading_models.html

'''

运行结果:
epoch 1:

epoch100:

运行到100次数字已接近可辨认。

案例来源:
https://github.com/yunjey/pytorch-tutorial.git
关注用案例学Pytorch:
https://github.com/houhuipeng/EasyPytorch-ByExample.git

作者:追逐一生
来源:CSDN
原文:https://blog.csdn.net/houhuipeng/article/details/90043548
版权声明:本文为博主原创文章,转载请附上博文链接!

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: