keras的LSTM文本分类预处理及代码简单实现
2018-03-11 11:45
447 查看
最近公司在做文本分类处理这一块,自己也没接触,过于是在茫茫的博客中苦苦寻找,给大家推荐一篇比较不错的实在github上找到的,网址贴上。这篇文章的代码实现还是比较简单的。但是存在两个问题:一是训练集。这个训练集是直接从这个包中直接加载并训练的,对于实际的应用来说做文本的类的人士需要将自己的训练集和测试集替换掉它。二是自动分类的预处理过程。它没有介绍在训练之前数据的预处理以及数据的卷积化(应该是这样说,也就是将数据经过切分、去重、排序(可选)等操作后转化为一串对应类别的序列。不论是处理成one-hot还是word-embidding)。所以这篇文章我会介绍一下数据预处理的过程,而后面循环训练的部分具体各函数的具体功能以及运算方法我还没深入了解,就不在大家面前献丑了。还有就小编如果哪里有不正确或者不合适的地方请不吝指正或者联系小编进行修改。(qq:1549990441 加时请详细备注,谢谢!)
小编自己理解的文本分类的基本过程:分类-->根据训练模型构建数据的shape-->训练-->结果保存
详解如下:
1、分类:
这一块的话小编在实现的时候是手动分类的,因为公司最近项目在游戏方面,所以分类的层级结构大致是这样的:
2、分好类之后开始构建所需要的shape。数据shape格式如下:
这里小编要提醒一下:1表示这篇文章属于该类别,0表示不属于,可以存在全为1的一行,即一篇文章对应对多个类别,但是不能出现全为0的情况,因为这样训练没意义。所以在构建shape是需要对全为0的进行剔除,小编后面的代码中也会实现。并且在构建shape的时候需要每一篇文章与shape对应。我是如上图这样理解的:行代表每一篇文章,列代表你锁分的类别,而在最后训练的结果中中也是这样的shape,三十数据却变成了0--1之间的一个数,即预测值,哪一列 接近于1说明越接近哪一个类别。
3、shape构建完之后开始构建模型。有两种方法:一是向layer添加list的方式,二是通过.add()方式一层层添加,一个add为一层。小编制这里用的是 .add()方法。
model = Sequential()
model.add(Dense(25, input_dim=784)) # 这里的25需要与你构建的shape对应
model.add(Activation('relu'))
每一个函数里面的参数小编在这里就不一一介绍了,详细见keras文档,中文文档。
4、模型构建完开始编译。设置优化器,损失函数,评估模型的指标等,一般默认就可以,后期优化可以在详细修改。
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
5、模型训练。其他参数需要在添加修改。
fit(self, x_train, y_train, batch_size=32,
validation_data=(x_test, y_test))
fit函数返回一个History的对象,其History.history属性记录了损失函数和其他指标的数值随epoch变化的情况,如果有验证集的话,也包含了验证集的这些指标变化情况。
x_train:输入训练数据。模型只有一个输入时x的类型是numpy array,模型多个输入时x的类型应当为list,list的元素是对应于各个输入的numpy array
# y_train:标签,numpy array
# batch_size:指定进行梯度下降时每个batch包含的样本数。
<
8e4f
p># validation_data:形式为(X,y)的tuple,是指定的验证集。此参数将覆盖validation_spilt。
5、测试、预测数据。
evaluate(self, x_test, y_test, batch_size=32, verbose=1)
#本函数按batch计算在某些输入数据上模型的误差,其参数有:
#x_test:输入数据,与fit一样,是numpy array或numpy array的list
#y_test:标签,numpy array
#batch_size:整数,含义同fit的同名参数
#verbose:含义同fit的同名参数,但只能取0或1
predict_proba(self, x_test, batch_size=32, verbose=1)
本函数按batch产生输入数据属于各个类别的概率,函数的返回值是类别概率的numpy array
6、结果保存。
小编在调试及正确性上在优化代码,所以延迟更新,需要的小伙伴可先关注等几天,小编会尽快补全。
小编自己理解的文本分类的基本过程:分类-->根据训练模型构建数据的shape-->训练-->结果保存
详解如下:
1、分类:
这一块的话小编在实现的时候是手动分类的,因为公司最近项目在游戏方面,所以分类的层级结构大致是这样的:
2、分好类之后开始构建所需要的shape。数据shape格式如下:
这里小编要提醒一下:1表示这篇文章属于该类别,0表示不属于,可以存在全为1的一行,即一篇文章对应对多个类别,但是不能出现全为0的情况,因为这样训练没意义。所以在构建shape是需要对全为0的进行剔除,小编后面的代码中也会实现。并且在构建shape的时候需要每一篇文章与shape对应。我是如上图这样理解的:行代表每一篇文章,列代表你锁分的类别,而在最后训练的结果中中也是这样的shape,三十数据却变成了0--1之间的一个数,即预测值,哪一列 接近于1说明越接近哪一个类别。
3、shape构建完之后开始构建模型。有两种方法:一是向layer添加list的方式,二是通过.add()方式一层层添加,一个add为一层。小编制这里用的是 .add()方法。
model = Sequential()
model.add(Dense(25, input_dim=784)) # 这里的25需要与你构建的shape对应
model.add(Activation('relu'))
每一个函数里面的参数小编在这里就不一一介绍了,详细见keras文档,中文文档。
4、模型构建完开始编译。设置优化器,损失函数,评估模型的指标等,一般默认就可以,后期优化可以在详细修改。
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
5、模型训练。其他参数需要在添加修改。
fit(self, x_train, y_train, batch_size=32,
validation_data=(x_test, y_test))
fit函数返回一个History的对象,其History.history属性记录了损失函数和其他指标的数值随epoch变化的情况,如果有验证集的话,也包含了验证集的这些指标变化情况。
x_train:输入训练数据。模型只有一个输入时x的类型是numpy array,模型多个输入时x的类型应当为list,list的元素是对应于各个输入的numpy array
# y_train:标签,numpy array
# batch_size:指定进行梯度下降时每个batch包含的样本数。
<
8e4f
p># validation_data:形式为(X,y)的tuple,是指定的验证集。此参数将覆盖validation_spilt。
5、测试、预测数据。
evaluate(self, x_test, y_test, batch_size=32, verbose=1)
#本函数按batch计算在某些输入数据上模型的误差,其参数有:
#x_test:输入数据,与fit一样,是numpy array或numpy array的list
#y_test:标签,numpy array
#batch_size:整数,含义同fit的同名参数
#verbose:含义同fit的同名参数,但只能取0或1
predict_proba(self, x_test, batch_size=32, verbose=1)
本函数按batch产生输入数据属于各个类别的概率,函数的返回值是类别概率的numpy array
6、结果保存。
小编在调试及正确性上在优化代码,所以延迟更新,需要的小伙伴可先关注等几天,小编会尽快补全。
相关文章推荐
- Android实现下载工具的简单代码
- [20180313智慧餐厅推荐系统02]基于python的socket编程代码,实现PC与服务器的简单通信
- AjaxPanel自定义控件实现页面无刷新数据交互(做了个示例程序, 效果确实比较Cool, 用法非常简单! )(示例代码下载)
- AJAX实现简单的注册页面异步请求实例代码
- 一段多个access表汇总的简单样例 (备忘 根据情况修改相应代码可实现excel多表入access汇总)
- python kmeans聚类简单介绍和实现代码
- 基于jQuery实现Div窗口震动特效代码-代码简单
- JS实现简单的运行代码 & 侧边广告
- 实现简单的队和栈结构,附代码,图
- JQuery简单实现锚点链接的平滑滚动(一段代码控制所有锚点)
- jquery实现简单瀑布流代码
- 代码片段 - Golang 实现简单的 Web 服务器
- 一个简单的时间片轮转多道程序内核代码 的实现
- iOS 实现简单的移动UIView代码实例
- javascript简单拖拽实现代码(鼠标事件 mousedown mousemove mouseup)
- Java实现“年-月-日 上午/下午时:分:秒”的简单代码
- 几种简单的负载均衡算法及其Java代码实现
- jquer之ajaxQueue简单实现代码
- js上下视差滚动简单实现代码