tensorflow--张量变换--3--矩阵提取( tf.gather )

函数概况

tf.gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)

作用:根据indeces收集在params上的值。如图:

二、例子

import tensorflow as tf

temp = tf.range(0,10)*10 + tf.constant(1,shape=[10])
#收集下标1、5、9处的值
temp2 = tf.gather(temp,[1,5,9])

with tf.Session() as sess:

    print(sess.run(temp))
    print(sess.run(temp2))

输出

[ 1 11 21 31 41 51 61 71 81 91] 
[11 51 91]
药企,独角兽,苏州。团队长期招人,感兴趣的都可以发邮件聊聊:tiehan@sina.cn
个人公众号,比较懒,很少更新,可以在上面提问题,如果回复不及时,可发邮件给我: tiehan@sina.cn