
A Top-Down Analysis of a Simple Speech Recognition System (Part 5)

In this installment we analyze the functions involved in the configuration phase of run_model.

1. The run_model function

In part two we briefly introduced the structure of the run_model function; its full code is shown below:

    def run_model(self):
        self.graph = tf.Graph()
        with self.graph.as_default(), tf.device('/cpu:0'):

            with tf.device(self.tf_device):
                # Run multiple functions on the specified tf_device.
                # tf_device is set to a GPU in the configs, but is overridden if none is available:
                # __init__ calls gpu_tool.check_if_gpu_available, and if a GPU is present, self.tf_device = '/gpu:0'
                self.setup_network_and_graph()
                self.load_placeholder_into_network()
                self.setup_loss_function()
                self.setup_optimizer()
                self.setup_decoder()
                self.setup_summary_statistics()

            # create the configuration for the session
            tf_config = tf.ConfigProto()
            tf_config.allow_soft_placement = True
            # cap the per-process GPU memory usage; self.simultaneous_users_count = 4
            tf_config.gpu_options.per_process_gpu_memory_fraction = \
                (1.0 / self.simultaneous_users_count)

            # create the session
            self.sess = tf.Session(config=tf_config)

            # initialize the summary writer
            self.writer = tf.summary.FileWriter(
                self.SUMMARY_DIR, graph=self.sess.graph)

            # Add ops to save and restore all the variables
            self.saver = tf.train.Saver()

            # For printing out section headers
            section = '\n{0:=^40}\n'

            # If there is a model_path declared, then restore the model
            # (as noted earlier, self.model_path is None here)
            if self.model_path is not None:
                self.saver.restore(self.sess, self.model_path)
            # If there is NOT a model_path declared, build the model from scratch
            else:
                # Op to initialize the variables
                init_op = tf.global_variables_initializer()
                # Initialize the weights and biases
                self.sess.run(init_op)

            # MAIN LOGIC for running the training epochs
            logger.info(section.format('Run training epoch'))
            self.run_training_epochs()

            logger.info(section.format('Decoding test data'))
            # make the assumption for working on the test data,
            # that the epoch here is the last epoch
            _, self.test_ler = self.run_batches(
                self.data_sets.test,
                is_training=False,
                decode=True,
                write_to_file=False,
                epoch=self.epochs)

            # Add the final test data to the summary writer
            # (single point on the graph for end of training run)
            summary_line = self.sess.run(
                self.test_ler_op, {self.ler_placeholder: self.test_ler})
            self.writer.add_summary(summary_line, self.epochs)

            logger.info('Test Label Error Rate: {}'.format(self.test_ler))

            # save train summaries to disk
            self.writer.flush()
            self.sess.close()
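
The session configuration above can be tried in isolation. The sketch below is only an illustration, assuming four processes share one GPU (simultaneous_users_count = 4); it is not part of the project code.

    import tensorflow as tf

    simultaneous_users_count = 4  # assumed value, matching the comment above

    tf_config = tf.ConfigProto()
    # fall back to the CPU for ops without a GPU kernel instead of raising an error
    tf_config.allow_soft_placement = True
    # each process may claim at most 1/4 of the GPU memory
    tf_config.gpu_options.per_process_gpu_memory_fraction = 1.0 / simultaneous_users_count

    sess = tf.Session(config=tf_config)
    sess.close()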

2. The setup_network_and_graph function

This function defines the input and output placeholders of the network model. The code is as follows:

    def setup_network_and_graph(self):
        # e.g: log filter bank or MFCC features
        # shape = [batch_size, max_stepsize, n_input + (2 * n_input * n_context)]
        # the batch_size and max_stepsize can vary along each step
        self.input_tensor = tf.placeholder(
            tf.float32, [None, None, self.n_input + (2 * self.n_input * self.n_context)], name='input')

        # Use sparse_placeholder; will generate a SparseTensor, required by ctc_loss op.
        self.targets = tf.sparse_placeholder(tf.int32, name='targets')
        # 1d array of size [batch_size]
        self.seq_length = tf.placeholder(tf.int32, [None], name='seq_length')

Here, n_input=26 means 26 MFCC cepstral coefficients are used, and n_context=9 means each current 25 ms audio frame is fed in together with the 9 frames before it and the 9 frames after it. MFCC will be analyzed in detail later.
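
As a quick sanity check of those numbers, the width of each feature vector is n_input + 2 * n_input * n_context = 26 + 2 * 26 * 9 = 494. A small illustrative sketch (not project code) that builds a dummy batch of that shape:

    import numpy as np

    n_input, n_context = 26, 9
    feature_width = n_input + 2 * n_input * n_context   # 26 + 2 * 26 * 9 = 494

    batch_size, max_stepsize = 8, 100                    # arbitrary example sizes
    dummy_batch = np.zeros((batch_size, max_stepsize, feature_width), dtype=np.float32)
    print(dummy_batch.shape)                             # (8, 100, 494)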

3. The load_placeholder_into_network function

This function calls SimpleLSTM_model / BiRNN_model in rnn.py to build the basic structure of the network. The code is as follows:

    def load_placeholder_into_network(self):
        # logits is the non-normalized output/activations from the last layer.
        # logits will be input for the loss function.
        # nn_model is from the import statement in the load_model function
        # summary_op variables are for tensorboard
        if self.network_type == 'SimpleLSTM':
            self.logits, summary_op = SimpleLSTM_model(
                self.conf_path,
                self.input_tensor,
                tf.to_int64(self.seq_length)
            )
        elif self.network_type == 'BiRNN':
            self.logits, summary_op = BiRNN_model(
                self.conf_path,
                self.input_tensor,
                tf.to_int64(self.seq_length),
                self.n_input,
                self.n_context
            )
        else:
            raise ValueError('network_type must be SimpleLSTM or BiRNN')
        self.summary_op = tf.summary.merge([summary_op])

As noted earlier, network_type=BiRNN in neural_network.ini; we will analyze this network in detail in the next installment.
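
For context, network_type comes from neural_network.ini. A rough sketch of how such a value could be read with Python's ConfigParser is shown below; the section name 'nn' is only an assumption here, not necessarily the project's actual config layout.

    try:
        from configparser import ConfigParser   # Python 3
    except ImportError:
        from ConfigParser import ConfigParser   # Python 2

    parser = ConfigParser()
    parser.read('neural_network.ini')
    # 'nn' is a hypothetical section name; adjust to the actual config file
    network_type = parser.get('nn', 'network_type')     # e.g. 'BiRNN' or 'SimpleLSTM'
    print(network_type)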

4. The setup_loss_function function

This function sets the speech recognition model's loss function to ctc_loss. The code is as follows:

    def setup_loss_function(self):
        with tf.name_scope("loss"):
            self.total_loss = ctc_ops.ctc_loss(
                self.targets, self.logits, self.seq_length)
            self.avg_loss = tf.reduce_mean(self.total_loss)
            self.loss_summary = tf.summary.scalar("avg_loss", self.avg_loss)

            self.cost_placeholder = tf.placeholder(dtype=tf.float32, shape=[])

            self.train_cost_op = tf.summary.scalar(
                "train_avg_loss", self.cost_placeholder)

The CTC loss function will be analyzed in detail later.
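
Before that detailed analysis, here is a minimal standalone sketch of how ctc_loss consumes the three tensors defined earlier: time-major logits of shape [max_time, batch_size, num_classes], a SparseTensor of target label indices, and the per-example sequence lengths. The sizes and the letter-index encoding below are illustrative assumptions; tf.nn.ctc_loss is the public name for ctc_ops.ctc_loss.

    import numpy as np
    import tensorflow as tf

    max_time, batch_size, num_classes = 50, 2, 29    # 29 = 26 letters + space + apostrophe + CTC blank

    # ctc_loss expects time-major logits: [max_time, batch_size, num_classes]
    logits = tf.placeholder(tf.float32, [max_time, batch_size, num_classes])
    targets = tf.sparse_placeholder(tf.int32)
    seq_length = tf.placeholder(tf.int32, [batch_size])

    loss = tf.nn.ctc_loss(targets, logits, seq_length)
    avg_loss = tf.reduce_mean(loss)

    # Encode two label sequences as a SparseTensorValue (indices, values, dense_shape)
    labels = [[2, 0, 19], [7, 8]]                    # "cat" -> [2, 0, 19], "hi" -> [7, 8], with a=0 ... z=25
    indices = [(b, t) for b, seq in enumerate(labels) for t in range(len(seq))]
    values = [v for seq in labels for v in seq]
    dense_shape = [len(labels), max(len(seq) for seq in labels)]

    with tf.Session() as sess:
        feed = {
            logits: np.random.randn(max_time, batch_size, num_classes).astype(np.float32),
            targets: tf.SparseTensorValue(indices, values, dense_shape),
            seq_length: [max_time, max_time],
        }
        print(sess.run(avg_loss, feed))              # average CTC loss over the batch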

5. The setup_optimizer function

This function calls create_optimizer in utils.py, which uses AdamOptimizer to optimize the network. The code is as follows:

    def setup_optimizer(self):
        # Note: The optimizer is created in models/RNN/utils.py
        with tf.name_scope("train"):
            self.optimizer = create_optimizer()
            self.optimizer = self.optimizer.minimize(self.avg_loss)
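
create_optimizer itself lives in models/RNN/utils.py and is analyzed separately; as a rough, hedged sketch of what such a helper typically looks like (the learning rate and Adam hyperparameters below are illustrative defaults, not the project's configured values):

    import tensorflow as tf

    def create_optimizer(learning_rate=1e-4, beta1=0.9, beta2=0.999, epsilon=1e-8):
        # Build an Adam optimizer; the caller then applies .minimize(loss) to it
        return tf.train.AdamOptimizer(learning_rate=learning_rate,
                                      beta1=beta1,
                                      beta2=beta2,
                                      epsilon=epsilon)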

6. The setup_decoder function

This function decodes the network output using one of two CTC decoding strategies (beam search or greedy). The code is as follows:

    def setup_decoder(self):
        with tf.name_scope("decode"):
            if self.beam_search_decoder == 'default':
                self.decoded, self.log_prob = ctc_ops.ctc_beam_search_decoder(
                    self.logits, self.seq_length, merge_repeated=False)
            elif self.beam_search_decoder == 'greedy':
                self.decoded, self.log_prob = ctc_ops.ctc_greedy_decoder(
                    self.logits, self.seq_length, merge_repeated=False)
            else:
                logging.warning("Invalid beam search decoder option selected!")
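
Both decoder ops return a list of SparseTensor candidates, with decoded[0] being the best path. A small hedged sketch of turning that best path back into text is shown below; the index-to-character mapping is an assumption for illustration and may differ from the project's actual alphabet.

    # Hypothetical index -> character map: a=0 ... z=25, then space and apostrophe.
    index_to_char = {i: chr(ord('a') + i) for i in range(26)}
    index_to_char[26] = ' '
    index_to_char[27] = "'"

    def sparse_to_texts(sparse_value, batch_size):
        """Convert a tf.SparseTensorValue returned by sess.run(decoded[0]) into strings."""
        texts = [''] * batch_size
        for (batch_idx, _), label in zip(sparse_value.indices, sparse_value.values):
            texts[batch_idx] += index_to_char.get(int(label), '?')
        return texts

    # Usage (inside a session):
    #   best_path = sess.run(self.decoded[0], feed_dict=...)
    #   print(sparse_to_texts(best_path, batch_size))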

7. The setup_summary_statistics function

This function sets up the collection points for the summaries produced during a run. The code is as follows:

    def setup_summary_statistics(self):
        # Create a placeholder for the summary statistics
        with tf.name_scope("accuracy"):
            # Compute the edit (Levenshtein) distance of the top path
            distance = tf.edit_distance(
                tf.cast(self.decoded[0], tf.int32), self.targets)

            # Compute the label error rate (accuracy)
            self.ler = tf.reduce_mean(distance, name='label_error_rate')
            self.ler_placeholder = tf.placeholder(dtype=tf.float32, shape=[])
            self.train_ler_op = tf.summary.scalar(
                "train_label_error_rate", self.ler_placeholder)
            self.dev_ler_op = tf.summary.scalar(
                "validation_label_error_rate", self.ler_placeholder)
            self.test_ler_op = tf.summary.scalar(
                "test_label_error_rate", self.ler_placeholder)

This installment briefly analyzed the network configuration process; in the next one we will examine the basic structure of the network in detail.