1. 程式人生 > >tf.cond()函式的用法

tf.cond()函式的用法

 

這個函式跟if...else...的功能很像,主要控制tensorflow中計算圖的張量的流向。官網中有對函式引數的解釋如下:

tf.cond(
    pred,
    true_fn=None,
    false_fn=None,
    strict=False,
    name=None,
    fn1=None,
    fn2=None
)
  • pred: A scalar determining whether to return the result of true_fn
     or false_fn.
  • true_fn: The callable to be performed if pred is true.
  • false_fn: The callable to be performed if pred is false.
  • strict: A boolean that enables/disables 'strict' mode; see above.
  • name: Optional name prefix for the returned tensors.

看的有點糊塗,用一個例子來解釋下吧 !

import tensorflow as tf
x=tf.constant(2)
y=tf.constant(5)
flag=tf.constant(True)
op=tf.cond(flag,lambda :tf.add(x,y),lambda : tf.multiply(x,y))
with tf.Session() as sess:
    result=sess.run(op)
    print(result)

當flag為True時,執行‘加’操作,結果為7,當flag為Flase時,執行“乘”操作,結果為10