1. 程式人生 > >從 SRGAN(TensorFlow) 匯出生成網路(generator)資料

從 SRGAN(TensorFlow) 匯出生成網路(generator)資料

按《tensorflow2caffe(2) : 如何在tensorflow中取出模型引數》一文的程式碼原理: 把下面的程式碼放到 main.py 的 generator 部分:

        #-------------------------------------------------------------

        # 這裡應該是global_variables,如果trainable_variables,則會缺少一些引數
        # all_vars = tf.trainable_variables()

        all_vars = tf.global_variables()
        fp = open('SRGAN_generator_model.txt', 'w')
        for v in all_vars:

            name = v.name

            fname = name + '.prototxt'

            fname = fname.replace('/','_')

            print (fname)
            fp.write(fname)
            fp.write('\n')

            v_4d = np.array(sess.run(v))
            if v_4d.ndim == 4:

                #v_4d.shape [ H, W, I, O ]        

                v_4d = np.swapaxes(v_4d, 0, 2) # swap H, I

                v_4d = np.swapaxes(v_4d, 1, 3) # swap W, O

                v_4d = np.swapaxes(v_4d, 0, 1) # swap I, O

                #v_4d.shape [ O, I, H, W ]


                vshape = v_4d.shape[:]

                v_1d = v_4d.reshape(v_4d.shape[0]*v_4d.shape[1]*v_4d.shape[2]*v_4d.shape[3])

                fp.write('  blobs {\n')

                for vv in v_1d:

                    fp.write('    data: %8f' % vv)

                    fp.write('\n')

                fp.write('    shape {\n')

                for s in vshape:

                    fp.write('      dim: ' + str(s))#print dims

                    fp.write('\n')

                fp.write('    }\n')

                fp.write('  }\n')
            elif v_4d.ndim == 1 :#do not swap


                fp.write('  blobs {\n')

                for vv in v_4d:

                    fp.write('    data: %.8f' % vv)

                    fp.write('\n')

                fp.write('    shape {\n')

                fp.write('      dim: ' + str(v_4d.shape[0]))#print dims

                fp.write('\n')

                fp.write('    }\n')

                fp.write('  }\n')

        fp.close()
        #-------------------------------------------------------------

然後執行就匯出了一個文字方式的資料

SRGAN_generator_model.txt:

generator_generator_unit_input_stage_conv_Conv_weights:0.prototxt
  blobs {
    data: -0.022789
    data: -0.008191
    data: -0.001650
    ...省略
    data: 0.007882
    data: 0.007484
    shape {
      dim: 64
      dim: 3
      dim: 9
      dim: 9
    }
  }
generator_generator_unit_input_stage_conv_Conv_biases:0.prototxt
  blobs {
    data: -0.09940426
    ...省略
    data: -0.06667865
    shape {
      dim: 64
    }
  }
generator_generator_unit_input_stage_Prelu_alpha:0.prototxt
  blobs {
    ...省略
    shape {
      dim: 64
    }
  }
generator_generator_unit_resblock_1_conv_1_Conv_weights:0.prototxt
  blobs {
    ...省略
    shape {
      dim: 64
      dim: 64
      dim: 3
      dim: 3
    }
  }
generator_generator_unit_resblock_1_BatchNorm_beta:0.prototxt
...省略

這樣就可以和 caffe_srgan-master中的資料對比一下異同。