歡迎您光臨本站 註冊首頁

TensorFlow中如何確定張量的形狀例項

←手機掃碼閱讀     kyec555 @ 2020-06-24 , reply:0

我們可以使用tf.shape()獲取某張量的形狀張量。

  import tensorflow as tf  x = tf.reshape(tf.range(1000), [10, 10, 10])  sess = tf.Session()  sess.run(tf.shape(x))     Out[1]: array([10, 10, 10])

 

我們可以使用tf.shape()在計算圖中確定改變張量的形狀。

  high = tf.shape(x)[0] // 2  width = tf.shape(x)[1] * 2  x_reshape = tf.reshape(x, [high, width, -1])  sess.run(tf.shape(x_reshape))     Out: array([ 5, 20, 10])

 

我們可以使用tf.shape_n()在計算圖中得到若干個張量的形狀。

  y = tf.reshape(tf.range(504), [7,8,9])  sess.run(tf.shape_n([x, y]))     Out: [array([10, 10, 10]), array([7, 8, 9])]

 

我們可以使用tf.size()獲取張量的元素個數。

sess.run([tf.size(x), tf.size(y)])

Out: [1000, 504]

tensor.get_shape()或者tensor.shape是無法在計算圖中用於確定張量的形狀。

  In [20]: x.get_shape()  Out[20]: TensorShape([Dimension(10), Dimension(10), Dimension(10)])     In [21]: x.get_shape()[0]  Out[21]: Dimension(10)     In [22]: type(x.get_shape()[0])  Out[22]: tensorflow.python.framework.tensor_shape.Dimension     In [23]: x.get_shape()  Out[23]: TensorShape([Dimension(10), Dimension(10), Dimension(10)])     In [24]: sess.run(x.get_shape())  ---------------------------------------------------------------------------  TypeError     Traceback (most recent call last)  ~Anaconda3libsite-packages	ensorflowpythonclientsession.py in __init__(self, fetches, contraction_fn)   299  self._unique_fetches.append(ops.get_default_graph().as_graph_element(  --> 300  fetch, allow_tensor=True, allow_operation=True))   301 except TypeError as e:     ~Anaconda3libsite-packages	ensorflowpythonframeworkops.py in as_graph_element(self, obj, allow_tensor, allow_operation)   3477 with self._lock:  -> 3478 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)   3479     ~Anaconda3libsite-packages	ensorflowpythonframeworkops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)   3566 raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,  -> 3567        types_str))   3568     TypeError: Can not convert a TensorShapeV1 into a Tensor or Operation.     During handling of the above exception, another exception occurred:     TypeError     Traceback (most recent call last)in----> 1 sess.run(x.get_shape())     ~Anaconda3libsite-packages	ensorflowpythonclientsession.py in run(self, fetches, feed_dict, options, run_metadata)   927 try:   928 result = self._run(None, fetches, feed_dict, options_ptr,  --> 929    run_metadata_ptr)   930 if run_metadata:   931  proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)     ~Anaconda3libsite-packages	ensorflowpythonclientsession.py in _run(self, handle, fetches, feed_dict, options, run_metadata)   1135 # Create a fetch handler to take care of the structure of fetches.   1136 fetch_handler = _FetchHandler(  -> 1137  self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)   1138   1139 # Run request and get response.     ~Anaconda3libsite-packages	ensorflowpythonclientsession.py in __init__(self, graph, fetches, feeds, feed_handles)   469 """   470 with graph.as_default():  --> 471 self._fetch_mapper = _FetchMapper.for_fetch(fetches)   472 self._fetches = []   473 self._targets = []  ~Anaconda3libsite-packages	ensorflowpythonclientsession.py in for_fetch(fetch)   269  if isinstance(fetch, tensor_type):   270  fetches, contraction_fn = fetch_fn(fetch)  --> 271  return _ElementFetchMapper(fetches, contraction_fn)   272 # Did not find anything.   273 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,  ~Anaconda3libsite-packages	ensorflowpythonclientsession.py in __init__(self, fetches, contraction_fn)   302  raise TypeError('Fetch argument %r has invalid type %r, '   303    'must be a string or Tensor. (%s)' %  --> 304    (fetch, type(fetch), str(e)))   305 except ValueError as e:   306  raise ValueError('Fetch argument %r cannot be interpreted as a '  TypeError: Fetch argument TensorShape([Dimension(10), Dimension(10), Dimension(10)]) has invalid type, must be a string or Tensor. (Can not convert a TensorShapeV1 into a Tensor or Operation.)

 

我們可以使用tf.rank()來確定張量的秩。tf.rank()會返回一個代表張量秩的張量,可直接在計算圖中使用。

  In [25]: tf.rank(x)  Out[25]:In [26]: sess.run(tf.rank(x))  Out[26]: 3

 

補充知識:tensorflow迴圈改變tensor的值

使用tf.concat()實現4維tensor的迴圈賦值

  alist=[[[[1,1,1],[2,2,2],[3,3,3]],[[4,4,4],[5,5,5],[6,6,6]]],[[[7,7,7],[8,8,8],[9,9,9]],[[10,10,10],[11,11,11],[12,12,12]]]] #2,2,3,3-n,c,h,w  kenel=(np.asarray(alist)*2).tolist()  print(kenel)  inputs=tf.constant(alist,dtype=tf.float32)  kenel=tf.constant(kenel,dtype=tf.float32)  inputs=tf.transpose(inputs,[0,2,3,1]) #n,h,w,c  kenel=tf.transpose(kenel,[0,2,3,1]) #n,h,w,c  uints=inputs.get_shape()  h=int(uints[1])  w=int(uints[2])  encoder_output=[]  for b in range(int(uints[0])):   encoder_output_c=[]   for c in range(int(uints[-1])):    one_channel_in = inputs[b, :, :, c]    one_channel_in = tf.reshape(one_channel_in, [1, h, w, 1])    one_channel_kernel = kenel[b, :, :, c]    one_channel_kernel = tf.reshape(one_channel_kernel, [h, w, 1, 1])    encoder_output_cc = tf.nn.conv2d(input=one_channel_in, filter=one_channel_kernel, strides=[1, 1, 1, 1], padding="SAME")    if c==0:     encoder_output_c=encoder_output_cc    else:     encoder_output_c=tf.concat([encoder_output_c,encoder_output_cc],axis=3)     if b==0:    encoder_output=encoder_output_c   else:    encoder_output = tf.concat([encoder_output, encoder_output_c], axis=0)    with tf.Session() as sess:   print(sess.run(tf.transpose(encoder_output,[0,3,1,2])))   print(encoder_output.get_shape())

 

輸出:

  [[[[ 32. 48. 32.]   [ 56. 84. 56.]   [ 32. 48. 32.]]     [[ 200. 300. 200.]   [ 308. 462. 308.]   [ 200. 300. 200.]]]       [[[ 512. 768. 512.]   [ 776. 1164. 776.]   [ 512. 768. 512.]]     [[ 968. 1452. 968.]   [1460. 2190. 1460.]   [ 968. 1452. 968.]]]]  (2, 3, 3, 2)

 


[kyec555 ] TensorFlow中如何確定張量的形狀例項已經有420次圍觀

http://coctec.com/docs/program/show-post-239532.html