import torch
from torch.nn import functional as F
X = torch.tensor([[0,0],[1,0],[0,1],[1,1]], dtype = torch.float32)
torch.random.manual_seed(420)
dense = torch.nn.Linear(2,3)
zhat = dense(X)
print(zhat)
sigma = F.softmax(zhat,dim=1)
print(sigma)
#sigma = F.softmax(zhat,dim=1)
这里的dim=1还是等于0或-1.算完可以看看,总共就3类,那肯定是三个数相加等于1啊。