import torch
import torch.nn as nn
import torch.nn.functional as F
class AudioClassifier(nn.Module):
    """Small 3-stage CNN classifier over 2-D (image/spectrogram-like) inputs.

    Three Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(2x2, stride 2)
    stages each halve the spatial size. A two-layer MLP head then maps the
    flattened features to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits per sample.
        in_channels: Channels of the input tensor. Defaults to 2, the value
            that was previously hard-coded in the first convolution.
        input_hw: ``(height, width)`` of the input. It is used only to size
            the first linear layer. The default ``(200, 176)`` reproduces the
            previous ``128 * 25 * 22`` flattened feature count.
    """

    def __init__(self, num_classes, in_channels=2, input_hw=(200, 176)):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
        )
        height, width = input_hw
        # Padded 3x3 convs keep the size. Each of the three 2x2/stride-2 pools
        # floors the size by half, so floor(floor(floor(h/2)/2)/2) == h // 8.
        flat_features = 128 * (height // 8) * (width // 8)
        self.classifier = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        ``x`` is ``(batch, in_channels, H, W)``. An unbatched
        ``(in_channels, H, W)`` tensor is treated as a batch of one.
        """
        if x.dim() == 3:
            # Keep the old behaviour: view(-1, ...) turned a single sample
            # into a batch of one.
            x = x.unsqueeze(0)
        x = self.features(x)
        # Flatten per sample, keeping the batch dimension explicit. The old
        # view(-1, 128*25*22) silently re-batched inputs of the wrong spatial
        # size; now the Linear layer raises a shape error instead.
        x = torch.flatten(x, 1)
        return self.classifier(x)


if __name__ == '__main__':
    net = AudioClassifier(num_classes=5)
    # The first convolution expects 2 input channels by default. The previous
    # demo passed 1 channel and crashed.
    temp = torch.randn(32, 2, 200, 176)
    out = net(temp)
    print(out)
    print('out', out.shape)
# Main function (training script):
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from conv_audio import AudioClassifier
from dataset import WavDataset
from tqdm import tqdm, trange
def _train_one_epoch(model, loader, criterion, optimizer, device, desc=''):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, weighted by batch
        size. Both are ``nan`` if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total = 0
    correct = 0
    # tqdm wraps the iterable so the bar actually advances per batch.
    for x, label in tqdm(loader, desc=desc, leave=False):
        # Per the original comments: x is (b, 2, 200, 176) and label is (b,).
        x, label = x.to(device), label.to(device)
        logits = model(x)  # (b, num_classes)
        loss = criterion(logits, label)  # scalar, mean over the batch
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch = label.size(0)
        # Undo the per-batch mean so the epoch loss is a true sample mean.
        total_loss += loss.item() * batch
        total += batch
        correct += (logits.argmax(dim=1) == label).sum().item()
    if total == 0:
        return float('nan'), float('nan')
    return total_loss / total, correct / total


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return classification accuracy of ``model`` on ``loader`` (``nan`` if empty)."""
    model.eval()
    total = 0
    correct = 0
    for x, label in loader:
        x, label = x.to(device), label.to(device)
        correct += (model(x).argmax(dim=1) == label).sum().item()
        total += label.size(0)
    return correct / total if total else float('nan')


def main():
    """Train ``AudioClassifier`` on the '2s_wav_mini' WAV dataset for 100 epochs.

    Each epoch prints the training accuracy, the mean training loss and the
    accuracy on ``test_Data``.
    """
    batchsz = 32
    num_classes = 5
    epochs = 100
    train_Data = DataLoader(WavDataset('2s_wav_mini'), batch_size=batchsz, shuffle=True)
    # NOTE(review): the test loader reads the same directory as the training
    # loader, so "test" accuracy is not a held-out measure. Point it at a
    # separate split if one exists. Shuffling is unnecessary for evaluation.
    test_Data = DataLoader(WavDataset('2s_wav_mini'), batch_size=batchsz, shuffle=False)
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AudioClassifier(num_classes).to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(
            model, train_Data, criteon, optimizer, device, desc=f'epoch {epoch}'
        )
        test_acc = _evaluate(model, test_Data, device)
        print('epoch:', epoch, 'train_correct_rate:', train_acc, 'train_loss:', train_loss,
              'test_correct_rate:', test_acc)


if __name__ == '__main__':
    main()