import torch
help((torch.cat)) cat
cat(seq,dim,out=None)
seq , , :seq=(a,b), a,b
dim ,dim=0,
dim=1,
# :
#dim=0 :
import torch
n_data = torch.ones((100,2))
x0_data = torch.normal(2*n_data,1)
y0_data = torch.zeros((100,1))
x1_data = torch.normal(-2*n_data,1)
y1_data = torch.ones((100,1))
x_data = torch.cat((x0_data,x1_data),0).type(torch.FloatTensor)
y_data = torch.cat((y0_data,y1_data),0).type(torch.LongTensor)
print('x_data :',x_data.shape)
print("y_data :",y_data.shape)
result:
x_data : torch.Size([200, 2])
y_data : torch.Size([200, 1])
# :
#dim=1 :
import torch
n_data = torch.ones((100,2))
x0_data = torch.normal(2*n_data,1)
y0_data = torch.zeros((100,1))
x1_data = torch.normal(-2*n_data,1)
y1_data = torch.ones((100,1))
x_data = torch.cat((x0_data,x1_data),1).type(torch.FloatTensor)
y_data = torch.cat((y0_data,y1_data),1).type(torch.LongTensor)
print('x_data :',x_data.shape)
print("y_data :",y_data.shape)
result:
x_data : torch.Size([100, 4])
y_data : torch.Size([100, 2])