RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #4 'mat1'
2018-10-08 21:14
Running the following program:
```python
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt


class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        out = self.linear(x)
        return out


x_train = np.array([[3.3], [4.4], [5.5], [6.710], [6.93], [4.168], [9.779], [6.182],
                    [7.59], [2.167], [7.042], [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573], [3.366], [2.596],
                    [2.53], [1.221], [2.827], [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
num_epochs = 1000

if torch.cuda.is_available():
    print("GPU1")
    model = LinearRegression().cuda()
else:
    print("CPU1")
    model = LinearRegression()

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    if torch.cuda.is_available():
        print('GPU2')
        inputs = Variable(x_train).cuda()
        target = Variable(y_train).cuda()
    else:
        print("CPU2")
        inputs = Variable(x_train)
        target = Variable(y_train)

    # forward
    out = model(inputs)
    loss = criterion(out, target)
    # backward
    optimizer.zero_grad()  # zero the gradients
    loss.backward()        # backpropagation
    optimizer.step()       # update the parameters

    # if (epoch + 1) % 20 == 0:
    #     print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))

model.eval()
predict = model(Variable(x_train))
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
plt.plot(x_train.numpy(), predict, label='Fitting Line')
plt.legend()
plt.show()
```
The following line raises an error:
predict = model(Variable(x_train))
RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #4 'mat1'
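One way to confirm what the message is complaining about is to compare the device of the model's weights with the device of the input tensor. The helper `same_device` below is a hypothetical name used only for illustration, not part of PyTorch; it assumes all of the model's parameters sit on one device, which holds after a single `model.cuda()` call.

```python
import torch
from torch import nn


def same_device(model: nn.Module, x: torch.Tensor) -> bool:
    """Return True if x lives on the same device as the model's parameters."""
    # Assumption: every parameter is on one device (true after one model.cuda() / model.to(...)).
    param_device = next(model.parameters()).device
    return x.device == param_device


model = nn.Linear(1, 1)  # stays on the CPU here
x = torch.zeros(2, 1)    # CPU tensor, like x_train built with torch.from_numpy
print(next(model.parameters()).device, x.device)  # cpu cpu
print(same_device(model, x))  # True
```

In the failing program, the model was moved with `.cuda()` while `x_train` was not, so the same check would report the parameters on `cuda:0` and the input on `cpu`.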
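This means the expected type was torch.cuda.FloatTensor, but the data actually passed in was a torch.FloatTensor. Because the model was moved to the GPU with .cuda(), its weights are CUDA tensors, while x_train itself is still an ordinary CPU tensor (only the copies used inside the training loop were moved with .cuda()). The input therefore needs .cuda() as well. Change
predict = model(Variable(x_train))
to
predict = model(Variable(x_train.cuda()))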
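The fix above hard-codes `.cuda()`, so it would fail on a machine without a GPU, where the program otherwise runs on the CPU. A minimal device-agnostic sketch, assuming a recent PyTorch where `Variable` is no longer needed, moves the input to whatever device already holds the model's weights. `predict_on_model_device` is a hypothetical helper name chosen for illustration, and `nn.Linear(1, 1)` stands in for the article's `LinearRegression`, which wraps exactly that layer.

```python
import numpy as np
import torch
from torch import nn


def predict_on_model_device(model: nn.Module, x: torch.Tensor) -> np.ndarray:
    """Run a forward pass with x moved to the device that holds the model's weights.

    Works whether the model was moved to the GPU with .cuda() / .to("cuda") or left on the CPU.
    """
    device = next(model.parameters()).device  # e.g. cuda:0 or cpu
    with torch.no_grad():  # inference only, so no autograd graph is built
        out = model(x.to(device))
    # Copy the result back to host memory so NumPy / matplotlib can use it.
    return out.cpu().numpy()


if __name__ == "__main__":
    # Choose the device once instead of scattering .cuda() calls through the code.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(1, 1).to(device)  # stand-in for LinearRegression()
    x_train = torch.from_numpy(np.array([[3.3], [4.4], [5.5]], dtype=np.float32))
    model.eval()
    predict = predict_on_model_device(model, x_train)
    print(predict.shape)  # (3, 1)
```

Reading the device from the model's parameters, rather than from a global flag, keeps the prediction step correct even if the model is later moved to a different device.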