PyTorch的可视化工具

张量的可视化

显示单张图片

1
2
3
4
5
6
7
8
9
import torch as t
from torchvision import transforms
from matplotlib import pyplot as plt

to_pil = transforms.ToPILImage()
img = to_pil(t.randn(3, 64, 64)) # 随机噪声

plt.imshow(img)
plt.show()

将多张图片拼在一起

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch as t
from torch import nn
from torchvision import models
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt

#========================================================
# 数据加工
#========================================================

transform = transforms.Compose([
transforms.Resize(256), # 缩放图片, 保持长宽比不变, 最短边256
transforms.CenterCrop(224), # 从图片中间切出224x224的图片
transforms.ToTensor(), # 将图片(Image)转化为Tensor, 归一化到[0,1]
transforms.Normalize(mean=[0.5], std=[0.5]) # 标准化到[-1,1]
])

dataset = datasets.MNIST('data/', download=True, train=False, transform=transform)

dataloader = DataLoader(dataset, shuffle=True, batch_size=16)

#========================================================
# 模型迁移
#========================================================

# 预训练模型
resnet34 = models.resnet34(pretrained=True, num_classes=1000)
# 修改最后的全连接层为10分类问题
resnet34.fc = nn.Linear(512, 10)

#========================================================
# 可视化
#========================================================

from torchvision.utils import make_grid, save_image

dataiter = iter(dataloader)
img = make_grid(next(dataiter)[0], 4) # 拼成4x4网格图片
save_image(img, 'test.png') # 将tensor保存成png图片

to_img = transforms.ToPILImage()
img = to_img(img)

plt.imshow(img)
plt.show()

训练过程的可视化

1、TensorBoard

tensorboard_logger将TensorBoard的功能抽取出来,使得非TensorFlow用户也能用它进行可视化
第一步是安装,非常简单

1
pip install tensorboard_logger

第二步,用如下命令启动TensorBoard

1
tensorboard --logdir <your_dir:这里用D:/> --port <your_bind_port:这里用6006>

第三步,记录日志

1
2
3
4
5
6
7
8
9
from tensorboard_logger import Logger

logger = Logger(logdir='D:/', # 日志保存路径
flush_secs=2 # 刷新同步间隔
)

for i in range(100):
logger.log_value('loss:', 10-i**0.5, step=i)
logger.log_value('accuracy:', i**0.5/10)

第四步,网页查看
打开http://localhost:6006/就可以看到日志记录的内容了

2、visdom

visdom是Facebook专门为PyTorch开发的一款可视化工具。
第一步是安装,非常简单

1
pip install visdom

第二步,用如下命令启动visdom服务(默认绑定8097端口)

1
python -m visdom.server

使用如下命令可以将visdom服务放至后台运行

1
nohup python -m visdom.server &

第三步,记录日志

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import visdom
import torch as t

# 新建一个连接客户端
# 指定env='test1', 默认端口8097, host是'localhost'
vis = visdom.Visdom(env='test1')

x = t.arange(1, 30, 0.01)
y = t.sin(x)
# line绘线, histgram可视化分布, scatter散点图, bar柱状图, pie饼图, image可视化图片, text用于记录日志等文字信息...
vis.line(X=x,
Y=y,
win='sinx', # 指定pane名字
opts={'title': 'y=sin(x)'} # 可视化配置, 用于设置pane的显示格式
)

第四步,网页查看
打开http://localhost:8097/就可以看到日志记录的内容了

pytorch常用代码段整理