使用scikit-learn进行音乐分类
2016-06-29 20:01
399 查看
使用scikit-learn进行音乐分类
之前简要介绍了Yaafe和Essentia的安装和使用。我对这两个框架分别做了封装,实现了各个特征的提取接口,并把特征值保存到本地文件中,方便训练时使用。下面这张图(原文配图,此处缺失)列出了提取的特征数据,一共35种特征,每种特征的维度不同,例如MFCC最大值特征为13维。这里的音乐分类是按情感进行分类,一共分成7种情感:happy、sad、passion、excited、quiet、nostalgic、relax。
一、得到训练数据
def getTrainData(train_path, train_mode, mode_name, featureList):
    """Extract features for every song of every class and build a training set.

    For each class ``i``, every file returned by ``tool.getAllFile(train_path[i])``
    is loaded into ``yl`` (``startEngine``) and ``el`` (``startLoader``), and its
    feature vector is obtained from ``getFeature(yl, el, featureList)``.
    The per-class features are also written to disk through ``wf``. The selected
    ``featureList`` is recorded in ``dataSystem/<mode_name>/feature.txt``.

    Args:
        train_path: list of music folders, one per class. ``train_path[i]``
            belongs to ``train_mode[i]``.
        train_mode: list of class labels (e.g. ``'happy'``), parallel to
            ``train_path``.
        mode_name: model name; selects the ``dataSystem/<mode_name>`` directory.
        featureList: indices of the features to extract.

    Returns:
        ``(train_data, train_target)``: one feature vector per song and, for
        each vector, the integer label ``i`` (the class's index in
        ``train_mode``). A class whose folder yields no files contributes
        nothing but keeps its index, so labels stay aligned with ``train_mode``.
    """
    train_data = []
    train_target = []
    for i, label in enumerate(train_mode):
        folder = train_path[i]
        allFeature = []
        introFeature = []
        for audio_file in tool.getAllFile(folder):
            yl.startEngine(audio_file)
            el.startLoader(audio_file)
            # getFeature returns (feature vector, feature description); the
            # description of the last file is what gets written out below.
            oneFeature, introFeature = getFeature(yl, el, featureList)
            allFeature.append(oneFeature)

        # No song of this class produced features: nothing to write or learn.
        if not allFeature:
            print('无音频文件')
            continue

        wf.setMode(label)  # class name used for the output file
        # Output goes to <folder>/feature_num/<label>.txt plus readme.txt
        # (per the original author's note).
        wf.setFeaturePrintPath(folder)
        wf.writeFeature(allFeature, introFeature)
        tool.mkdir('dataSystem/' + mode_name)
        wf.writeDetial('dataSystem/' + mode_name + '/feature.txt', featureList)

        train_data.extend(allFeature)
        train_target.extend([i] * len(allFeature))
    return train_data, train_target
train_path和train_mode是一一对应的两个列表:train_path[i]是第i类音乐所在的文件夹路径,train_mode[i]是该类的标签。例如train_path[0]='/home/chenming/Music/Happy'表示happy类音乐的文件夹,train_mode[0]='happy'表示这组训练数据属于happy类。mode_name是模型名称,决定输出目录dataSystem/<mode_name>。featureList是需要提取的特征编号,例如featureList=[1,2,3]表示提取FEATURES中第1、2、3个位置的特征(见上面第一张图,原文配图缺失)。
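函数返回的train_data和train_target分别是所有歌曲的特征向量,以及每个向量对应的标签。标签是该类别在train_mode中的下标i(如上例中happy为0)。每一类的特征数据还会写入该类目录下的feature_num/<类别>.txt,例如/home/chenming/Music/Happy/feature_num/happy.txt;所选的特征列表则写入dataSystem/<mode_name>/feature.txt。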
二、训练
def train(train_data, train_target, train_mode, mode_name):
    """Standardize the training set, fit a classifier and persist the model.

    The per-feature mean/std of ``train_data`` are saved to
    ``dataSystem/<mode_name>/meanData.txt`` and ``stdData.txt`` so test data
    can be scaled identically. ``train_mode`` is written to ``modeList.txt``
    to map integer labels back to class names.

    The classifier is chosen by the module-level ``classifierType``:
    ``'knn'`` selects k-nearest neighbours and ``'tree'`` a decision tree.
    Anything else, including ``'svm'``, selects an SVC. The fitted model is
    dumped with joblib to ``svcSystem/<mode_name>/svc1.pkl``.

    Args:
        train_data: feature vectors, one per song.
        train_target: integer class labels, parallel to ``train_data``.
        train_mode: class names; ``train_mode[k]`` is the name of label ``k``.
        mode_name: model name; selects the output directories.

    Returns:
        A status message containing the model location.
    """
    # Force a float matrix: trainDataStandardNor rewrites the array in place,
    # so an integer array would silently truncate the standardized values.
    train_data = np.array(train_data, dtype=float)
    train_target = np.array(train_target)

    # Column-wise z-score (in place); keep the statistics for the test phase.
    (meanData, stdData) = tool.trainDataStandardNor(train_data)
    np.savetxt('dataSystem/' + mode_name + '/meanData.txt', meanData)
    np.savetxt('dataSystem/' + mode_name + '/stdData.txt', stdData)
    wf.writeDetial('dataSystem/' + mode_name + '/modeList.txt', train_mode)

    if classifierType == 'knn':
        clf = neighbors.KNeighborsClassifier()  # k-nearest neighbours
    elif classifierType == 'tree':
        clf = tree.DecisionTreeClassifier()  # decision tree
    else:
        clf = svm.SVC()  # support vector machine; also the fallback
    clf.fit(train_data, train_target)

    from sklearn.externals import joblib
    tool.mkdir('svcSystem/' + mode_name)
    joblib.dump(clf, 'svcSystem/' + mode_name + '/svc1.pkl')
    return '训练成功,模型' + mode_name + '存储成功,' + '位置:svcSystem/' + mode_name + '/svc1.pkl'
(meanData, stdData) = tool.trainDataStandardNor(train_data)对数据做标准化(z-score):每一维特征减去均值再除以标准差,标准差为0的维度直接置0。得到的均值和标准差保存到meanData.txt和stdData.txt中,测试时用同样的参数处理测试数据。代码如下:
def trainDataStandardNor(featureMatrix):
    """Standardize a 2-D feature matrix column by column, in place (z-score).

    Each column is centred on its mean and divided by its standard deviation.
    Columns whose standard deviation is 0 (constant features) are set to 0
    instead of dividing by zero.

    Args:
        featureMatrix: 2-D NumPy array, one row per sample and one column per
            feature dimension. It is overwritten in place, so it must have a
            floating-point dtype.

    Returns:
        ``(meanData, stdData)``: per-column mean and standard deviation
        (NumPy default ``ddof=0``), computed before scaling. Keep them to
        scale test data the same way.

    Raises:
        TypeError: if ``featureMatrix`` has an integer or boolean dtype. In
            that case in-place scaling would silently truncate the results.
    """
    if featureMatrix.dtype.kind in 'biu':
        raise TypeError('featureMatrix must be a floating-point array to be '
                        'standardized in place, got dtype %s' % featureMatrix.dtype)
    meanData = featureMatrix.mean(axis=0)
    stdData = featureMatrix.std(axis=0)
    constant = stdData == 0
    # Divide constant columns by 1 to avoid 0/0; they are zeroed just below.
    divisor = stdData.copy()
    divisor[constant] = 1
    featureMatrix -= meanData
    featureMatrix /= divisor
    featureMatrix[:, constant] = 0
    return (meanData, stdData)
http://www.tuicool.com/articles/qeIzI3F 这篇博客介绍了Scikit-Learn的使用。
三、测试
def test(test_data, test_target, test_mode, mode_name):
    """Evaluate the stored model ``mode_name`` on a labelled test set.

    Loads the following and returns the report built by ``showDetail``:
    - the training mean/std and the class-name list from
      ``dataSystem/<mode_name>/``;
    - the classifier from ``svcSystem/<mode_name>/svc1.pkl``.
    ``test_data`` is scaled with the training statistics before prediction.

    Args:
        test_data: feature vectors, one per song.
        test_target: class labels, expected to be indices into the stored
            class-name list (as used for training).
        test_mode: not used for evaluation (class names are read from
            ``modeList.txt``); kept for interface compatibility.
        mode_name: name of the trained model to load.

    Returns:
        The text report produced by ``showDetail``.
    """
    # Float dtype so scaling with the float training statistics cannot be
    # truncated if the helper below writes into this array.
    test_data = np.array(test_data, dtype=float)
    test_target = np.array(test_target)
    path1 = 'dataSystem/' + mode_name + '/meanData.txt'
    path2 = 'dataSystem/' + mode_name + '/stdData.txt'
    path3 = 'svcSystem/' + mode_name + '/svc1.pkl'
    path4 = 'dataSystem/' + mode_name + '/modeList.txt'
    print('数据路径:' + 'dataSystem/' + mode_name + '和svcSystem/' + mode_name)
    meanData = np.loadtxt(path1)
    stdData = np.loadtxt(path2)
    modeList = wf.readFileTxt(path4)
    # Return value is discarded, so this presumably scales test_data in place
    # (like trainDataStandardNor) -- TODO confirm.
    tool.testDataStandardNor(test_data, meanData, stdData)
    from sklearn.externals import joblib
    clf = joblib.load(path3)
    return showDetail(clf, test_data, test_target, modeList)
showDetail函数负责输出评估结果,包括分类报告、混淆矩阵,以及每一类的命中数和正确率:
def showDetail(clf, test_data, test_target, mode):
    """Predict ``test_data`` with ``clf`` and build a text evaluation report.

    The report contains, in order:
    - a short explanation of the metrics;
    - sklearn's classification report;
    - the confusion matrix (one row per true class, one column per predicted
      class);
    - per-class "hits/total" and the rate computed by ``tool.ratio``.
    The report is printed and also returned.

    Args:
        clf: fitted classifier exposing ``predict``.
        test_data: feature matrix to classify.
        test_target: true labels, integer indices into ``mode``.
        mode: class names; ``mode[k]`` is the name of label ``k``.

    Returns:
        The report as a string.
    """
    expected = test_target
    predicted = clf.predict(test_data)
    # Pin the label set to every class in ``mode``. Without it, sklearn only
    # reports labels present in expected/predicted, so a test set holding a
    # single class yields a smaller matrix whose rows and columns no longer
    # line up with ``mode`` (and target_names has the wrong length).
    labels = list(range(len(mode)))

    out = '''
*******准确率=被识别为该分类的正确分类记录数/被识别为该分类的记录数
*******召回率=被识别为该分类的正确分类记录数/测试集中该分类的记录总数
*******F1-score=2(准确率 * 召回率)/(准确率 + 召回率),F1-score是F-measure(又称F-score)beta=1时的特例
*******support=测试集中该分类的记录总数\n
'''
    out += str(metrics.classification_report(expected, predicted, labels=labels, target_names=mode))
    result = metrics.confusion_matrix(expected, predicted, labels=labels)

    # Header row: one column per predicted class.
    out += '\n预测数目:\t'
    out += ''.join(str(name) + '\t' for name in mode)
    out += '\n\n'

    # One row per true class; remember (diagonal hits, row total) for the summary.
    hits_totals = []
    for i, row in enumerate(result):
        out += '测试数据' + str(i + 1) + ':\t'
        out += ''.join(str(count) + '\t' for count in row)
        out += '\n'
        hits_totals.append((row[i], row.sum()))
    out += '\n'

    for name, (hit, total) in zip(mode, hits_totals):
        # Classes absent from the test set have an all-zero row; avoid 0/0.
        rate = tool.ratio(hit, total) if total else '-'
        out += (str(name) + ':\t命中/总数:' + str(hit) + '/' + str(total)
                + '\t正确率:' + str(rate) + '\n')

    print(out)
    return out
下面是测试结果(原文截图缺失):测试50首happy歌曲,47首识别正确,3首识别错误。
相关文章推荐
- 《剑指offer》-02字符串替换
- 欢迎使用CSDN-markdown编辑器
- Lua创建一个类 继承
- linux 调试利器gdb, strace, pstack, pstree, lsof
- pragma message的作用
- POJ3256:Cow Picnic
- 线程的同步(synchronized关键字)
- java线程池管理
- CentOS 配置Apache+Mysql+PHP
- Linux下多线程查看工具(pstree、ps、pstack),linux命令之-pstree使用说明, linux 查看线程状态。 不指定
- C++错误累积
- Thread
- Uva11212 编辑书稿(Editing a book,IDE算法)
- POJ 3264 线段树 ST
- (三)springMVC WebUploader分片上传
- #define a int[10] typedef int a[10]
- POJ 3264 线段树 ST
- Java 之 简单工厂模式
- BFC——块级格式化上下文
- php做一个简单的分页类