1. 程式人生 > >Tensorflow深度學習之十九:矩陣切片與連結

Tensorflow深度學習之十九:矩陣切片與連結

1、TensorFlow矩陣切片操作:tf.slice函式

函式原型:slice(input_, begin, size, name=None)
引數:
input:待切片的矩陣tensor。
begin:起始位置,表示從哪一個資料開始進行切片。這個起始位置從0開始。若input是一個n維的矩陣,則begin是一個長度為n的tensor。
size:切片的大小(尺寸),表示則起始位置開始獲取每一維上的若干資料。是一個長度與begin相同的tensor。若size中第n個的資料為-1,則表示在該維度上,從起始位置開始的所有資料均被返回。
name:該操作的名稱,是一個可選引數,預設為None。

對於一個n維的矩陣,需滿足如下關係:
0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n]

import tensorflow as tf

# Tensorflow互動式會話
tf.InteractiveSession()

# 定義5x5大小的一個矩陣變數
a = tf.Variable(tf.truncated_normal(shape=[5, 5], dtype=tf.float32))

# 進行切片操作,起始位置為[1,1](從0開始),大小[2,2]
b = tf.slice(a, [1, 1
], [2, 2]) # 同上 c = tf.Variable(tf.truncated_normal(shape=[2, 6, 5], dtype=tf.float32)) d = tf.slice(c, [0, 2, 3], [2, 3, 1]) # 全域性變數初始化 tf.global_variables_initializer().run() # 輸出 print("Example 01") print("the original matrix:\n", a.eval()) print("after being sliced:\n", b.eval()) print("Example 02"
) print("the original matrix:\n", c.eval()) print("after being sliced:\n", d.eval())

程式執行結果如下:(結果或有不同)

Example 01
the original matrix:
 [[ 1.37798977  0.27846026  0.07193759  0.44368556  0.65868556]
 [-0.57639289 -0.64335102 -0.62483543  0.38987917  0.29301718]
 [ 0.18187736  0.11397317  1.85999572 -0.26037475  0.98114467]
 [ 0.69557261  0.01183218 -0.27376401 -1.15162456  1.11336803]
 [-0.66582751 -0.04991583 -1.58189285  0.98189503 -1.11317801]]
after being sliced:
 [[-0.64335102 -0.62483543]
 [ 0.11397317  1.85999572]]
Example 02
the original matrix:
 [[[-0.44467756 -1.05340731 -0.32313645 -0.69316941  0.04659459]
  [ 0.01275753 -0.11907347  1.70015264  0.60470396 -0.23756829]
  [ 0.07424127  1.01376414 -1.15661514 -0.46597373 -1.82189155]
  [-0.66635352 -0.34318891  0.49555108  0.13062055 -0.67137426]
  [ 0.04240284  0.55397838 -0.09988129 -0.93551743  0.6810317 ]
  [ 1.06745911  0.49900523  1.0482769   0.39871195  1.23199737]]

 [[ 1.22305858 -0.839634    0.63722724 -1.39846325 -0.04114933]
  [-1.11448932  0.20783874  0.39737079  1.13769484 -0.09408376]
  [-0.66636425  0.37878662 -0.32013494 -0.26526076  1.53422773]
  [-0.55344075  0.23021726  0.10251451  0.08433547  1.19850338]
  [ 1.73070538 -0.50309545 -0.52816319 -0.41802529 -1.52679396]
  [-1.60076332  0.88759929  0.01327948 -0.7242741  -0.70737672]]]
after being sliced:
 [[[-0.46597373]
  [ 0.13062055]
  [-0.93551743]]

 [[-0.26526076]
  [ 0.08433547]
  [-0.41802529]]]

2、TensorFlow矩陣連結操作:tf.concat函式
函式原型:concat(values, axis, name=”concat”)
引數:
values:需要連結的矩陣的集合,通常可以是一個list。
axis:需要進行連結的維度,若矩陣是n維的,則axis的取值為0~n-1。
name:名稱,是一個可選引數。

import tensorflow as tf

# Tensorflow互動式會話
tf.InteractiveSession()

# 定義兩個矩陣,大小為2x3x4
a = tf.Variable(tf.truncated_normal(shape=[2,3,4], dtype=tf.float32))

b = tf.Variable(tf.truncated_normal(shape=[2,3,4], dtype=tf.float32))

# 按照維度0連結
c1 = tf.concat([a, b], axis=0)

# 按照維度1連結
c2 = tf.concat([a, b], axis=1)

# 按照維度2連結
c3 = tf.concat([a, b], axis=2)

# 初始化變數
tf.global_variables_initializer().run()

# 輸出
print("01")

print(c1)

print(c1.eval())

print("02")

print(c2)

print(c2.eval())

print("03")

print(c3)

print(c3.eval())

程式執行結果如下:

01
Tensor("concat:0", shape=(4, 3, 4), dtype=float32)
[[[-0.08826777  1.92810595 -0.79408133 -0.34322619]
  [-1.71443737  0.70375884 -0.78194672 -0.41254947]
  [ 0.89348751 -0.08941202  0.70108914  0.64701825]]

 [[ 1.50688016  0.45680258 -1.08100998  0.24127837]
  [ 0.58221173 -1.41846514 -1.63450527 -0.41922286]
  [ 0.48436531 -1.20013559  0.95647675 -0.03131635]]

 [[-0.03254275 -1.8339541  -0.81978613 -1.25303519]
  [-1.55067682 -0.37825376 -0.63578284 -0.83120823]
  [ 0.09672505 -0.43550658 -0.31754431 -0.37109831]]

 [[ 1.59722102 -0.32856748 -1.33017409  1.43195128]
  [-0.58259052 -1.60538054  0.07504115  0.8916716 ]
  [-1.23682356 -0.24931362  1.19812703 -0.81907171]]]
02
Tensor("concat_1:0", shape=(2, 6, 4), dtype=float32)
[[[-0.08826777  1.92810595 -0.79408133 -0.34322619]
  [-1.71443737  0.70375884 -0.78194672 -0.41254947]
  [ 0.89348751 -0.08941202  0.70108914  0.64701825]
  [-0.03254275 -1.8339541  -0.81978613 -1.25303519]
  [-1.55067682 -0.37825376 -0.63578284 -0.83120823]
  [ 0.09672505 -0.43550658 -0.31754431 -0.37109831]]

 [[ 1.50688016  0.45680258 -1.08100998  0.24127837]
  [ 0.58221173 -1.41846514 -1.63450527 -0.41922286]
  [ 0.48436531 -1.20013559  0.95647675 -0.03131635]
  [ 1.59722102 -0.32856748 -1.33017409  1.43195128]
  [-0.58259052 -1.60538054  0.07504115  0.8916716 ]
  [-1.23682356 -0.24931362  1.19812703 -0.81907171]]]
03
Tensor("concat_2:0", shape=(2, 3, 8), dtype=float32)
[[[-0.08826777  1.92810595 -0.79408133 -0.34322619 -0.03254275 -1.8339541
   -0.81978613 -1.25303519]
  [-1.71443737  0.70375884 -0.78194672 -0.41254947 -1.55067682 -0.37825376
   -0.63578284 -0.83120823]
  [ 0.89348751 -0.08941202  0.70108914  0.64701825  0.09672505 -0.43550658
   -0.31754431 -0.37109831]]

 [[ 1.50688016  0.45680258 -1.08100998  0.24127837  1.59722102 -0.32856748
   -1.33017409  1.43195128]
  [ 0.58221173 -1.41846514 -1.63450527 -0.41922286 -0.58259052 -1.60538054
    0.07504115  0.8916716 ]
  [ 0.48436531 -1.20013559  0.95647675 -0.03131635 -1.23682356 -0.24931362
    1.19812703 -0.81907171]]]