自顶向下分析一个简单的语音识别系统(五)
2017-04-03 17:04
204 查看
本回我们主要分析run_model中的configuration过程的相关函数。
其中,n_input=26表示MFCC倒谱系数为26位,n_context=9表示当前25ms声音片段往前和往后分别9个声音片段做输入。MFCC将在后面详细分析。
由前面可知,neural_network.ini中network_type=BIRNN,下回我们将详细分析该网络。
后面将详细分析CTC损失函数。
本回简要分析了网络的configuration过程,下回将仔细分析网络的基本结构。
1.run_model函数
第二回我们简单介绍了run_model函数的结构,现在我们贴出代码如下所示:def run_model(self): self.graph = tf.Graph() with self.graph.as_default(), tf.device('/cpu:0'): with tf.device(self.tf_device): # Run multiple functions on the specificed tf_device # tf_device GPU set in configs, but is overridden if not available # __init__函数中调用gpu_tool.check_if_gpu_available函数,如果设备中有gpu,则self.tf_device=/gpu:0 self.setup_network_and_graph() self.load_placeholder_into_network() self.setup_loss_function() self.setup_optimizer() self.setup_decoder() self.setup_summary_statistics() # create the configuration for the session tf_config = tf.ConfigProto() tf_config.allow_soft_placement = True tf_config.gpu_options.per_process_gpu_memory_fraction = \ (1.0 / self.simultaneous_users_count) #设置gpu中的内存最大占用率,self.simultaneous_users_count=4 # create the session self.sess = tf.Session(config=tf_config) # initialize the summary writer self.writer = tf.summary.FileWriter( self.SUMMARY_DIR, graph=self.sess.graph) # Add ops to save and restore all the variables self.saver = tf.train.Saver() # For printing out section headers section = '\n{0:=^40}\n' # If there is a model_path declared, then restore the model #前述self.model_path=None if self.model_path is not None: self.saver.restore(self.sess, self.model_path) # If there is NOT a model_path declared, build the model from scratch else: # Op to initialize the variables init_op = tf.global_variables_initializer() # Initializate the weights and biases self.sess.run(init_op) # MAIN LOGIC for running the training epochs logger.info(section.format('Run training epoch')) self.run_training_epochs() logger.info(section.format('Decoding test data')) # make the assumption for working on the test data, that the epoch here is the last epoch _, self.test_ler = self.run_batches(self.data_sets.test, is_training=False, decode=True, write_to_file=False, epoch=self.epochs) # Add the final test data to the summary writer # (single point on the graph for end of training run) summary_line = self.sess.run( self.test_ler_op, {self.ler_placeholder: self.test_ler}) self.writer.add_summary(summary_line, self.epochs) logger.info('Test Label Error Rate: {}'.format(self.test_ler)) # save train summaries to disk self.writer.flush() self.sess.close()
2.setup_network_and_graph函数
本函数主要定义网络模型的输入输出placeholder,代码如下:def setup_network_and_graph(self): # e.g: log filter bank or MFCC features # shape = [batch_size, max_stepsize, n_input + (2 * n_input * n_context)] # the batch_size and max_stepsize can vary along each step self.input_tensor = tf.placeholder( tf.float32, [None, None, self.n_input + (2 * self.n_input * self.n_context)], name='input') # Use sparse_placeholder; will generate a SparseTensor, required by ctc_loss op. self.targets = tf.sparse_placeholder(tf.int32, name='targets') # 1d array of size [batch_size] self.seq_length = tf.placeholder(tf.int32, [None], name='seq_length')
其中,n_input=26表示MFCC倒谱系数为26位,n_context=9表示当前25ms声音片段往前和往后分别9个声音片段做输入。MFCC将在后面详细分析。
3.load_placeholder_into_network函数
该函数调用rnn.py中的SimpleLSTM/BiRNN函数构建网路的基本结构,代码如下:def load_placeholder_into_network(self): # logits is the non-normalized output/activations from the last layer. # logits will be input for the loss function. # nn_model is from the import statement in the load_model function # summary_op variables are for tensorboard if self.network_type == 'SimpleLSTM': self.logits, summary_op = SimpleLSTM_model( self.conf_path, self.input_tensor, tf.to_int64(self.seq_length) ) elif self.network_type == 'BiRNN': self.logits, summary_op = BiRNN_model( self.conf_path, self.input_tensor, tf.to_int64(self.seq_length), self.n_input, self.n_context ) else: raise ValueError('network_type must be SimpleLSTM or BiRNN') self.summary_op = tf.summary.merge([summary_op])
由前面可知,neural_network.ini中network_type=BIRNN,下回我们将详细分析该网络。
4.setup_loss_function函数
本函数设置语音识别模型的loss函数为ctc_loss,代码如下:def setup_loss_function(self): with tf.name_scope("loss"): self.total_loss = ctc_ops.ctc_loss( self.targets, self.logits, self.seq_length) self.avg_loss = tf.reduce_mean(self.total_loss) self.loss_summary = tf.summary.scalar("avg_loss", self.avg_loss) self.cost_placeholder = tf.placeholder(dtype=tf.float32, shape=[]) self.train_cost_op = tf.summary.scalar( "train_avg_loss", self.cost_placeholder)
后面将详细分析CTC损失函数。
5.setup_optimizer函数
本函数调用utils.py中的create_optimizer函数,使用AdamOptimizer对网络进行优化,代码如下:def setup_optimizer(self): # Note: The optimizer is created in models/RNN/utils.py with tf.name_scope("train"): self.optimizer = create_optimizer() self.optimizer = self.optimizer.minimize(self.avg_loss)
6.setup_decoder函数
本函数使用ctc中的两种策略对输出结果进行解码,代码如下:def setup_decoder(self): with tf.name_scope("decode"): if self.beam_search_decoder == 'default': self.decoded, self.log_prob = ctc_ops.ctc_beam_search_decoder( self.logits, self.seq_length, merge_repeated=False) elif self.beam_search_decoder == 'greedy': self.decoded, self.log_prob = ctc_ops.ctc_greedy_decoder( self.logits, self.seq_length, merge_repeated=False) else: logging.warning("Invalid beam search decoder option selected!")
7.setup_summary_statistics函数
本函数主要用于设置运行过程中产生的summary的收集点,代码如下:def setup_summary_statistics(self): # Create a placholder for the summary statistics with tf.name_scope("accuracy"): # Compute the edit (Levenshtein) distance of the top path distance = tf.edit_distance( tf.cast(self.decoded[0], tf.int32), self.targets) # Compute the label error rate (accuracy) self.ler = tf.reduce_mean(distance, name='label_error_rate') self.ler_placeholder = tf.placeholder(dtype=tf.float32, shape=[]) self.train_ler_op = tf.summary.scalar( "train_label_error_rate", self.ler_placeholder) self.dev_ler_op = tf.summary.scalar( "validation_label_error_rate", self.ler_placeholder) self.test_ler_op = tf.summary.scalar( "test_label_error_rate", self.ler_placeholder)
本回简要分析了网络的configuration过程,下回将仔细分析网络的基本结构。
相关文章推荐
- 自顶向下分析一个简单的语音识别系统(八)
- 自顶向下分析一个简单的语音识别系统(九)
- 自顶向下分析一个简单的语音识别系统(四)
- 自顶向下分析一个简单的语音识别系统(七)
- 自顶向下分析一个简单的语音识别系统(一)
- 自顶向下分析一个简单的语音识别系统(二)
- 自顶向下分析一个简单的语音识别系统(六)
- 自顶向下分析一个简单的语音识别系统(十)
- 自顶向下分析一个简单的语音识别系统(三)
- 一个简单的自顶向下语法分析(表达式求值)
- 一个简单存储过程的性能分析
- 对一个挂马网页的简单分析
- 一个简单的词法分析程序
- 一个简单的python代理服务器源码分析
- AutoTRACE是分析SQL的执行计划,执行效率的一个非常简单方便的工具
- 对一个简单递归的 时间复杂度的分析
- SysAuto病毒简单分析(一个盗QQ木马)
- 一个简单的PDF文件结构的分析
- 一个简单的词法分析程序
- 一个简单的ThreadPool分析