
[TensorFlow] Counting Model Parameters: How to calculate the number of parameters in my model?

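import logging
import tensorflow as tf  # TF 1.x API: tf.trainable_variables(), Dimension.value

# config.logger holds the log-file path and comes from the author's own config module
logging.basicConfig(level=logging.INFO, format='%(message)s', filemode='w', filename=config.logger)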

def _params_usage():
    total = 0
    prompt = []
    for v in tf.trainable_variables():
        shape = v.get_shape()
        # a variable holds as many parameters as the product of its dimensions
        cnt = 1
        for dim in shape:
            cnt *= dim.value
        prompt.append('{} with shape {} has {}'.format(v.name, shape, cnt))
        logging.info(prompt[-1])
        total += cnt
    # final line: the sum over all trainable variables
    prompt.append('totaling {}'.format(total))
    logging.info(prompt[-1])
    return '\n'.join(prompt)

shape, as returned by v.get_shape(), is a TensorShape. It is iterable, and each element is a Dimension whose .value attribute gives the dimension as a plain Python integer (or None if that dimension is unknown).
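The counting rule itself does not depend on TensorFlow: each variable contributes the product of its dimensions, and the model total is the sum over all variables. Below is a minimal stdlib-only sketch of that arithmetic, assuming shapes are given as plain tuples of integers. The function name count_params and the layer names and shapes are hypothetical, chosen only for illustration; math.prod needs Python 3.8 or later.

```python
import math
from typing import Dict, List, Optional, Sequence, Tuple


def count_params(shapes: Dict[str, Sequence[Optional[int]]]) -> Tuple[List[str], int]:
    """Return one report line per variable plus the overall total.

    `shapes` maps a (hypothetical) variable name to its shape; a None entry
    marks an unknown dimension, mirroring Dimension.value being None.
    """
    lines = []
    total = 0
    for name, shape in shapes.items():
        # An unknown dimension makes the count undefined, so refuse it explicitly
        if any(d is None for d in shape):
            raise ValueError('{} has an unknown dimension: {}'.format(name, shape))
        cnt = math.prod(shape)  # product of all dimensions; 1 for a scalar shape ()
        lines.append('{} with shape {} has {}'.format(name, tuple(shape), cnt))
        total += cnt
    lines.append('totaling {}'.format(total))
    return lines, total


if __name__ == '__main__':
    # Hypothetical two-layer dense network: 784 -> 256 -> 10
    layers = {
        'dense/kernel:0': (784, 256),
        'dense/bias:0': (256,),
        'dense_1/kernel:0': (256, 10),
        'dense_1/bias:0': (10,),
    }
    report, total = count_params(layers)
    print('\n'.join(report))
```

For this hypothetical network the script ends with "totaling 203530" (200704 + 256 + 2560 + 10).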

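The function _params_usage() above logs one line per trainable variable, followed by the total, through the logging configuration set up earlier, and also returns the same lines joined into a single string. The intent is to record the summary in the log file and, by printing the returned string, on stdout at the same time.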
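An alternative to printing the returned string is to let the standard logging module send every message to both destinations by attaching a FileHandler and a StreamHandler to one logger. This is a minimal sketch, not the original post's approach; the helper make_dual_logger, the logger name "params" and the path "params.log" are hypothetical.

```python
import logging
import sys


def make_dual_logger(log_path: str, name: str = 'params') -> logging.Logger:
    """Build a logger that writes bare messages to both a file and stdout."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep messages from also reaching the root logger
    fmt = logging.Formatter('%(message)s')  # same format as the basicConfig call above

    # Drop (and close) handlers from a previous call so lines are not duplicated
    for old in list(logger.handlers):
        logger.removeHandler(old)
        old.close()

    file_handler = logging.FileHandler(log_path, mode='w')  # mode='w' mirrors filemode='w'
    file_handler.setFormatter(fmt)
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(fmt)

    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger


if __name__ == '__main__':
    log = make_dual_logger('params.log')
    log.info('dense/kernel:0 with shape (784, 256) has 200704')
    log.info('totaling 200704')
```

With this setup, each logging call inside a counting function would appear in the file and on stdout without a separate print.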