您的位置:首页 > 其它

pandas中one-hot编码的神坑

2018-01-28 17:57 295 查看
机器学习中,经常会用到one-hot编码。pandas中已经提供了这一函数。

但是这里有一个神坑,得到的one-hot编码数据类型是uint8,进行数值计算时会溢出!!!

import pandas as pd
import numpy as np
a = [1, 2, 3, 1]
one_hot = pd.get_dummies(a)
print(one_hot.dtypes)
print(one_hot)
print(-one_hot)


1    uint8
2    uint8
3    uint8
dtype: object
1  2  3
0  1  0  0
1  0  1  0
2  0  0  1
3  1  0  0
1    2    3
0  255    0    0
1    0  255    0
2    0    0  255
3  255    0    0


正确的做法是,将其转换成浮点:


one_hot = one_hot.astype('float')
print(-one_hot)


1    2    3
0 -1.0 -0.0 -0.0
1 -0.0 -1.0 -0.0
2 -0.0 -0.0 -1.0
3 -1.0 -0.0 -0.0
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息