torch の stack() の使用例

2377 ワード

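axis パラメータは、以下の torch の例に示すように、axis=0 で積み重ねたときに先頭にできる新しい次元（長さ＝積み重ねるテンソルの個数）を、axis で指定した位置へ移動したものと理解できる。たとえば形状 (2, 4, 5) のテンソル 3 つを stack すると、長さ 3 の新しい次元が axis の位置に挿入され、axis=2 なら結果は (2, 4, 3, 5) になる。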
import numpy as np
import torch

# Three (2, 4, 5) integer arrays holding the consecutive values 1-40, 41-80
# and 81-120, plus torch tensors created from them with torch.from_numpy.
a = np.arange(1, 41).reshape(2, 4, 5)
b = np.arange(41, 81).reshape(2, 4, 5)
c = np.arange(81, 121).reshape(2, 4, 5)
ta = torch.from_numpy(a)
tb = torch.from_numpy(b)
tc = torch.from_numpy(c)

# torch.stack inserts a new dimension of length 3 (the number of stacked
# tensors) at position `dim`; the other dimensions keep their order:
#   dim=0 -> (3, 2, 4, 5)    dim=1 -> (2, 3, 4, 5)
#   dim=2 -> (2, 4, 3, 5)    dim=3 -> (2, 4, 5, 3)
# `dim` is torch.stack's documented keyword name.
for dim in range(4):
    print(torch.stack((ta, tb, tc), dim=dim).shape)

# np.stack does the same thing with its `axis` argument.
for axis in range(4):
    print(np.stack((a, b, c), axis=axis).shape)

# Passing a single ndarray stacks its sub-arrays along the first axis, i.e.
# the two (4, 5) slices a[0] and a[1]; the new length-2 axis is placed at
# `axis`, giving (2, 4, 5), (4, 2, 5) and (4, 5, 2).
for axis in range(3):
    stacked = np.stack(a, axis=axis)
    print(stacked.shape)
    print(stacked)

print(a)
# vstack joins the two (4, 5) slices row-wise -> (8, 5);
# hstack joins them column-wise -> (4, 10).
print(np.vstack(a).shape)
print(np.vstack(a))
print(np.hstack(a).shape)
print(np.hstack(a))

# `va` must be defined before use; a (3, 4) array of 1-12 reproduces the
# recorded output. vstack keeps its three length-4 rows as rows -> (3, 4);
# hstack concatenates those 1-D rows end to end -> (12,).
va = np.arange(1, 13).reshape(3, 4)
print(np.vstack(va).shape)
print(np.vstack(va))
print(np.hstack(va).shape)
print(np.hstack(va))



/**
torch.Size([3, 2, 4, 5])
torch.Size([2, 3, 4, 5])
torch.Size([2, 4, 3, 5])
torch.Size([2, 4, 5, 3])
(3, 2, 4, 5)
(2, 3, 4, 5)
(2, 4, 3, 5)
(2, 4, 5, 3)
(2, 4, 5)
[[[ 1  2  3  4  5]
  [ 6  7  8  9 10]
  [11 12 13 14 15]
  [16 17 18 19 20]]

 [[21 22 23 24 25]
  [26 27 28 29 30]
  [31 32 33 34 35]
  [36 37 38 39 40]]]
(4, 2, 5)
[[[ 1  2  3  4  5]
  [21 22 23 24 25]]

 [[ 6  7  8  9 10]
  [26 27 28 29 30]]

 [[11 12 13 14 15]
  [31 32 33 34 35]]

 [[16 17 18 19 20]
  [36 37 38 39 40]]]
(4, 5, 2)
[[[ 1 21]
  [ 2 22]
  [ 3 23]
  [ 4 24]
  [ 5 25]]

 [[ 6 26]
  [ 7 27]
  [ 8 28]
  [ 9 29]
  [10 30]]

 [[11 31]
  [12 32]
  [13 33]
  [14 34]
  [15 35]]

 [[16 36]
  [17 37]
  [18 38]
  [19 39]
  [20 40]]]
[[[ 1  2  3  4  5]
  [ 6  7  8  9 10]
  [11 12 13 14 15]
  [16 17 18 19 20]]

 [[21 22 23 24 25]
  [26 27 28 29 30]
  [31 32 33 34 35]
  [36 37 38 39 40]]]
(8, 5)
[[ 1  2  3  4  5]
 [ 6  7  8  9 10]
 [11 12 13 14 15]
 [16 17 18 19 20]
 [21 22 23 24 25]
 [26 27 28 29 30]
 [31 32 33 34 35]
 [36 37 38 39 40]]
(4, 10)
[[ 1  2  3  4  5 21 22 23 24 25]
 [ 6  7  8  9 10 26 27 28 29 30]
 [11 12 13 14 15 31 32 33 34 35]
 [16 17 18 19 20 36 37 38 39 40]]
(3, 4)
[[ 1  2  3  4]
 [ 5  6  7  8]
 [ 9 10 11 12]]
(12,)
[ 1  2  3  4  5  6  7  8  9 10 11 12]
**/