[tensorflow] How to obtain parameter information from a TensorFlow .pb file
阿新 • Published: 2018-11-04
Problem setting: I need to compare model complexity against state-of-the-art methods, so I need the number of model parameters. However, the provided model file is not a TensorFlow checkpoint (ckpt) but a .pb file, and I found that after importing the graph, tf.trainable_variables() returns an empty list.
This answer gives the solution.
In my case:
# TensorFlow 1.x API
import tensorflow as tf


def print_params(constant_values):
    """Print the element count of each kept constant and return the report lines."""
    total = 0
    prompt = []
    # 1-D constants whose last name component contains one of these
    # fragments are treated as shape/stack helpers, not parameters
    forbidden = ['shape', 'stack']
    for k, v in constant_values.items():
        # filter some entries by checking ndim and name
        if v.ndim < 1:
            continue  # skip scalars
        if v.ndim == 1:
            token = k.split('/')[-1]
            if any(word in token for word in forbidden):
                continue
        shape = v.shape
        cnt = 1
        for dim in shape:
            cnt *= dim
        prompt.append('{} with shape {} has {}'.format(k, shape, cnt))
        print(prompt[-1])
        total += cnt
    prompt.append('totaling {}'.format(total))
    print(prompt[-1])
    return prompt


# import graph
with open('spmc_120_160_4x3f.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# frames_lr is defined elsewhere in the original script (not shown here)
output = tf.import_graph_def(graph_def, input_map={'Placeholder:0': frames_lr}, return_elements=['output:0'])
output = output[0]

# ... other code

# obtain variables: after import, every weight is a Const op
constant_values = {}
with tf.Session() as sess:
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
    for constant_op in constant_ops:
        constant_values[constant_op.name] = sess.run(constant_op.outputs[0])

# printing variables
print_params(constant_values)
As mentioned in the answer, every node imported from the .pb file is a constant, so true constants and former variables are mixed together, and only the latter are what I need. I therefore have to filter out the useless ones manually by checking their ndim and name.
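The filtering rule can also be tried in isolation, without TensorFlow. The following is only a minimal sketch, assuming the constants have already been reduced to a name-to-shape mapping. The function name `count_filtered_params` and the sample node names and shapes are made up for illustration and are not taken from the actual model.

```python
from math import prod

# Same name fragments that print_params treats as non-parameters.
FORBIDDEN = ('shape', 'stack')


def count_filtered_params(shapes, forbidden=FORBIDDEN):
    """Return (kept, total) for a {node_name: shape_tuple} mapping.

    Mirrors the heuristic from the post: scalars are dropped, and 1-D
    constants are dropped when the last component of their name contains
    a forbidden fragment.
    """
    kept = {}
    for name, shape in shapes.items():
        ndim = len(shape)
        if ndim < 1:
            continue  # scalar constant
        if ndim == 1:
            token = name.split('/')[-1]
            if any(word in token for word in forbidden):
                continue  # shape/stack helper, not a weight
        kept[name] = prod(shape)
    return kept, sum(kept.values())


# Hypothetical node names and shapes, for illustration only.
example_shapes = {
    'import/conv1/weights': (3, 3, 3, 64),
    'import/conv1/biases': (64,),
    'import/Reshape/shape': (4,),
    'import/stack_1': (2,),
    'import/mul/y': (),
}

kept, total = count_filtered_params(example_shapes)
for name, cnt in kept.items():
    print('{} has {}'.format(name, cnt))
print('totaling {}'.format(total))
```

This is still a heuristic: a 1-D constant that is neither a weight nor caught by the forbidden fragments (for example, a list of axes) would be counted, so the total is best checked against the printed per-node list.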