pytorch基本类型
import torch
"""
pytorch 基本数据类型
1)不支持内建string类型,使用one-hot ,Embedding表示;
使用one-hot表示时,如表示单词较多,则one-hot会比较长,但其中只有一个元素值为1,故整个向量表现得很稀疏。并且one-hot不能表示语义相关性,故使用Embedding表示;
2)内建类型:
CPU GPU
float torch.FloatTensor torch.cuda.FloatTensor
double torch.DoubleTensor torch.cuda.DoubleTensor
byte torch.ByteTensor torch.cuda.ByteTensor
int torch.IntTensor torch.cuda.IntTensor
long torch.LongTensor torch.cuda.LongTensor
3)类型推断:
"""
a = torch.randn(2,3)#随机初始化2row3col
print(a.type(),type(a),isinstance(a,torch.FloatTensor))#torch.FloatTensor <class 'torch.Tensor'> True
"""
4)cpu与gpu转换
"""
print(isinstance(a,torch.cuda.FloatTensor))#false
a = a.cuda()
print(isinstance(a,torch.cuda.FloatTensor))#true
"""
5)标量:主要用于计算loss时
Dimension 0 / Rank 0
a = torch.Tensor(1.0)#早期支持,最新版不支持
b = torch.Tensor(.3)
print(a,b)
标量的意思时没有方向的量,如买5个苹果,重量2.5KG,这里即为标量;
6)向量:注意这里注意数学和计算机不同表示 方式,本质上二者是一个东西。主要用于偏置,注意其是有方向的
Dimension 1 / Rank 1
"""
a = torch.FloatTensor([1.1])
b = torch.Tensor([1.1])#二者相同
c = torch.Tensor(2)#随机初始化
print(a,b,c)#tensor([1.1000]) tensor([1.1000]) tensor([1.0962e+09, 4.5874e-41])
import numpy as np
data = np.ones(2)
d = torch.from_numpy(data)#从numpy引入
print(d)#tensor([1., 1.], dtype=torch.float64)
print(torch.__version__)#1.1.0
print(d.shape,d.size())#torch.Size([2]) torch.Size([2])注意dimension和rank是一个东西,rank是数学上叫法;如[2,3]则其rank为2;size和shape指的是具体形状;
"""
Dimension2,Dimension3(适合RNN),Dimension4(适合图片)
7)
"""
a = torch.randn(2,3)
print(a.size(0),a.shape[0])#第0个元素
a = torch .randn(2,3,4)
print(a.size(0))
(adsbygoogle = window.adsbygoogle || []).push({});
- PyTorch基本数据类型(一)
- PyTorch基本用法(一)——Numpy,Torch对比
- 『PyTorch x TensorFlow』第八弹_基本nn.Module层函数
- PyTorch基本用法(五)——分类
- py的基本数据类型 12.13
- pytorch学习 基本组成、基本练习
- OC_封装、拆包基本数据类型
- 关于C和C++中的基本数据类型int、long、long long、float、double、char、string的大小及表示范围
- java初级之5基本数据类型
- PHP 八种基本的数据类型
- javascript六种数据基本类型
- 第十九节,基本数据类型,集合set
- C/C++基本数据类型所占字节数
- Delphi 基本类型转换
- android--jni--基本数据类型的使用
- 基本数据类型对象包装类
- python中的基本数据类型
- Python(二)基本数据类型和变量
- java基本数据类型
- V的machine learning(1)--原创 win10+Anaconda3-5.2.0+python3.6.5+pip3安装pytorch1.0-cuda8.0版本