您的位置:首页 > 其它

文本分类:对单一数据集进行训练集,测试集和验证集的划分

2019-08-16 18:00 323 查看
版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。 本文链接:https://blog.csdn.net/qq_38473254/article/details/99417019

文本分类:训练集、验证集和测试集的划分

文本分类CNN源程序 github 地址:https://github.com/gaussic/text-classification-cnn-rnn
数据集链接:https://pan.baidu.com/s/1UhraDB3MCittdK6p1rfE6A 提取码:2fa2

在进行中文文本分类的过程中,克隆文件后进行测试,记录下我所遇到的问题:

1.关于运行run_cnn.py(同理于run_rnn.py)

如果观察run_cnn,py的代码就会发现,要想执行训练程序,需要在命令行中输入:

python run_cnn.py train

同理,要想执行测试程序,在命令行中输入:

python run_cnn.py test

2.对数据集进行划分

如果要想对自己的数据集进行测试,只有一个txt文件,那就需要对其进行训练集,验证集和测试集的划分。
为什么要对数据集进行训练集,验证集和测试集的划分
这篇博客写的很清楚,建议读一读,然后开始我们对数据集的划分(这里train:val:test = 6:2:2)

首先导入模块,初始化训练集、验证集和测试集的列表为空

import os
import random

L_train = []
L_val = []
L_test = []

定义函数ReadFileDatas()和WriteDatasToFile(),可以方便我们读取txt的内容和将内容保存到txt文件中去,列表是不能使用write()函数的,需要先将其转换为string类型

# 读取文件中的内容,并写入列表FileNameList
def ReadFileDatas(original_filename):
FileNameList = []
file = open(original_filename, 'r+', encoding='utf-8-sig')
for line in file:
FileNameList.append(line)  # 写入文件内容到列表中去
print('数据集总量:', len(FileNameList))
file.close()
return FileNameList

# 将获取的列表中的内容转为 str ,再写入到txt文件中去
# listInfo为 ReadFileDatas 的列表
def WriteDatasToFile(listInfo, new_filename):
file_handle = open(new_filename, mode='a', encoding='utf-8-sig')
for idx in range(len(listInfo)):
str = listInfo[idx]  # 列表指针
str_Result = str
file_handle.write(str_Result)
file_handle.close()
print('写入 %s 文件成功.' % new_filename)

对数据集进行train:val:test = 6:2:2划分,再定义数据保存的格式

"""
将划分数据集用函数表示
划分数据集(train, val, test)的区间,(new.txt) 为随机打乱好的文件数据集
数据集列表集合
打开文件引用上一函数保存的文件
"""
def TrainValTestFile(new_filename):
# L_train = []
# L_val = []
# L_test = []
i = 0    # counter
j = 9352 # all lines
file_divide = open(new_filename, 'r', encoding='utf-8-sig')
lines = file_divide.readlines()
for line in lines:
if i < (j *0.6):
i += 1
L_train.append(line)
elif i < (j*0.8):
i += 1
L_val.append(line)
elif i < j:
i += 1
L_test.append(line)
print("总数据量:%d , 此时创建train, val, test数据集" % i)
return L_train, L_val, L_test

# 保存数据集(train, val, test)
def text_save(filename, data):  #filename为写入CSV文件的路径,data为要写入数据列表
file = open(filename, 'a', encoding='utf-8-sig')
for i in range(len(data)):
s = str(data[i]).replace('[','').replace(']','')  #去除[],这两行按数据不同,可以选择
# s = s.replace("'",'').replace(',','') +'\n'   #去除单引号,逗号,每行末尾追加换行符
file.write(s)
file.close()
print("保存数据集(路径)成功:%s" % filename)

最后调用函数,完成训练集,测试集和验证集的划分,并保存在指定目录

# 调用函数
if __name__ == "__main__":
listFileInfo = ReadFileDatas('data.txt')            # 读取文件
random.shuffle(listFileInfo)                         # 打乱顺序
WriteDatasToFile(listFileInfo,'new_data.txt')       # 保存新的文件

# 划分数据集并保存
TrainValTestFile('new_data.txt')
text_save('./data/data_train.txt', L_train)
text_save('./data/data_val.txt', L_val)
text_save('./data/data_test.txt', L_test)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: