您的位置:首页 > 编程语言 > Python开发

pytorch 中 numpy 类型与 torch 类型共享存储的问题

2018-03-23 19:11 513 查看
从今天起,总结学习 pytorch 过程中遇到的一些日后可能出错的小问题。
首先就是 pytorch 官网 tutorial 第一章讲的,numpy 类型与 torch 类型共享存储,并且还给出样例: http://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html#tensors 在文章中,作者举例,当 torch 类型转换为 numpy类型时,对其中一个操作就相当于对另一个操作:
a = torch.ones(5)
print(a)
Out:
1
1
1
1
1
[torch.FloatTensor of size 5]
b = a.numpy()
print(b)
Out:
[ 1.  1.  1.  1.  1.]
然后执行:
a.add_(1)
print(a)
print(b)
Out:
2
2
2
2
2
[torch.FloatTensor of size 5]

[ 2.  2.  2.  2.  2.]
但是,我试着将代码中的 a.add_(1) 替换为 a = a + 1,结果就不是这样的:
2
 2
 2
 2
 2
[torch.FloatTensor of size (5,)]

[1. 1. 1. 1. 1.]
可以看到,这个时候 a 变了,但是 b 并没有变。

#########################################################################

同理,反过来,当 numpy 类型转换为 torch 类型的时候,作者举例如下:
import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)
Out:
[ 2.  2.  2.  2.  2.]

2
2
2
2
2
[torch.DoubleTensor of size 5]
如果我把代码中的 np.add(a, 1, out=a) 替换为 a = a + 1 的话,就又不共享存储了:
[[2. 2.]
 [2. 2.]
 [2. 2.]]

 1  1
 1  1
 1  1

[torch.DoubleTensor of size (3,2)]

具体为什么目前还没查到,先记在这里,日后发现为什么了再补上。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: