您的位置：首页 > 其它

划分训练集、验证集和测试集

2018-01-08 19:14　查看 288 次
# Split a labelled corpus (neg.txt -> label 0, pos.txt -> label 1) into
# training, validation (dev) and test sets.
import os
import codecs
import random

# Fixed seed so every run performs the same shuffles and therefore produces
# identical splits.
random.seed(1229)

# data[0] holds the lines of neg.txt, each prefixed with '0 ';
# data[1] holds the lines of pos.txt, each prefixed with '1 '.
data = []
for path, label in (('neg.txt', '0'), ('pos.txt', '1')):
    # errors='ignore' silently drops bytes that are not valid UTF-8 (a
    # deliberate best-effort read). The built-in open() splits only on line
    # endings and normalises '\r\n' to '\n', whereas codecs' readlines() uses
    # str.splitlines() and would also break a line at characters such as
    # '\u2028' or '\x0b', turning one sample into several.
    with open(path, 'r', encoding='utf-8', errors='ignore') as fdata:
        lines = fdata.readlines()
    # readlines() keeps each line's '\n' except possibly on the file's last
    # line. Add it there, so that sample is not glued onto the next one when
    # the samples are later written back out line by line.
    data.append([label + ' ' + (line if line.endswith('\n') else line + '\n')
                 for line in lines])

def get_test(data, n, x):
    """Return the x-th of n contiguous folds of *data* (the held-out slice).

    Fold x spans ``len(data) * x // n`` up to ``len(data) * (x + 1) // n``.
    As x runs over ``range(n)``, the n folds partition *data* exactly, in
    order, and their sizes differ by at most one.

    Args:
        data: A sliceable sequence (list, tuple, str, ...).
        n: Number of folds; must be positive.
        x: Fold index, ``0 <= x < n``.

    Returns:
        ``data[start:end]`` for fold x, of the same type as *data*.

    Raises:
        ValueError: If x is not in ``range(n)``. Without this check, an
            out-of-range index would silently select a wrong or empty slice.
    """
    if not 0 <= x < n:
        raise ValueError('fold index x=%r must satisfy 0 <= x < n=%r' % (x, n))
    size = len(data)
    start, end = size * x // n, size * (x + 1) // n
    return data[start:end]

def get_train(data, n, x):
    """Return *data* with its x-th of n contiguous folds removed.

    This is the complement of ``get_test(data, n, x)``. The fold spanning
    ``len(data) * x // n`` up to ``len(data) * (x + 1) // n`` is cut out, and
    the parts before and after it are concatenated in their original order.

    Args:
        data: A sliceable sequence that supports ``+`` (list, tuple, str, ...).
        n: Number of folds; must be positive.
        x: Fold index, ``0 <= x < n``.

    Returns:
        ``data[:start] + data[end:]``, of the same type as *data*.

    Raises:
        ValueError: If x is not in ``range(n)``. Without this check, a negative
            x would produce overlapping slices that duplicate items, and
            ``x >= n`` would return all of *data*.
    """
    if not 0 <= x < n:
        raise ValueError('fold index x=%r must satisfy 0 <= x < n=%r' % (x, n))
    size = len(data)
    start, end = size * x // n, size * (x + 1) // n
    return data[:start] + data[end:]

# Build 10 cross-validation folds and write each one to the directory mr<i>/
# as train.txt, dev.txt and test.txt.
for i in range(10):
    # Each class list in `data` is folded separately, so every split keeps
    # roughly the same label proportions as the whole corpus.
    train_ori = [get_train(item, 10, i) for item in data]
    test_ori = [get_test(item, 10, i) for item in data]

    train = []
    dev = []
    test = []
    # Iterate over however many classes `data` holds, instead of a
    # hard-coded count.
    for cls_train, cls_test in zip(train_ori, test_ori):
        # Shuffle before cutting, so the dev set is a random 10% of this
        # class's non-test lines.
        random.shuffle(cls_train)
        x = len(cls_train) * 9 // 10
        train += cls_train[:x]
        dev += cls_train[x:]
        test += cls_test
    # Interleave the classes, which were concatenated class by class above.
    random.shuffle(train)
    random.shuffle(dev)
    random.shuffle(test)

    # makedirs avoids spawning a shell and does not fail if the directory
    # already exists from an earlier run.
    out_dir = 'mr%s' % i
    os.makedirs(out_dir, exist_ok=True)
    for fname, lines in (('train.txt', train), ('dev.txt', dev),
                         ('test.txt', test)):
        # Write UTF-8 explicitly, matching the input encoding; the platform
        # default may not be able to encode the text. `with` guarantees the
        # file is flushed and closed.
        with open(os.path.join(out_dir, fname), 'w', encoding='utf-8') as fout:
            fout.writelines(lines)
内容来自用户分享和网络整理，不保证内容的准确性。如有侵权内容，可联系管理员处理。点击这里给我发消息
标签: