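### Implementing a CNN-BiLSTM Model for Regression in PyTorch
To build a CNN-BiLSTM model for a regression problem in PyTorch, you need to define the model structure and configure the training procedure. A detailed implementation follows.
#### Defining the CNN-BiLSTM model class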
```python
import torch
from torch import nn


class CNNBiLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout_prob):
        super().__init__()
        # 1D convolution extracts local patterns along the time axis;
        # max pooling then halves the sequence length.
        self.conv_layer = nn.Sequential(
            nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
        )
        # Bidirectional LSTM over the convolved sequence
        self.lstm = nn.LSTM(input_size=hidden_dim,
                            hidden_size=hidden_dim,
                            num_layers=num_layers,
                            bidirectional=True,
                            batch_first=True)
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout_prob)
        # Fully connected layer maps the concatenated forward/backward states
        # to the regression output
        self.fc = nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, x):
        # x: (batch, seq_len, input_dim); Conv1d expects (batch, channels, seq_len)
        conv_out = self.conv_layer(x.permute(0, 2, 1))
        lstm_input = conv_out.permute(0, 2, 1)
        _, (h_n, _) = self.lstm(lstm_input)
        # Final states of the last layer: h_n[-2] is forward, h_n[-1] is backward.
        # (Taking lstm_output[:, -1, :] instead would use a backward state
        # that has seen only the last time step.)
        last_hidden_state = torch.cat((h_n[-2], h_n[-1]), dim=1)
        dropped_last_hidden_state = self.dropout(last_hidden_state)
        prediction = self.fc(dropped_last_hidden_state)
        return prediction
```
This snippet creates a new class, `CNNBiLSTM`, that inherits from `nn.Module` and combines a convolutional layer, a bidirectional LSTM layer, and a fully connected layer[^2].
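To make the tensor shapes concrete, the following is a minimal sketch that traces one batch through standalone copies of the same layer types. The sizes (`batch_size`, `seq_len`, `input_dim`, `hidden_dim`) are hypothetical values chosen only for illustration, and a single LSTM layer is assumed.

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, input_dim, hidden_dim = 4, 20, 8, 16

x = torch.randn(batch_size, seq_len, input_dim)  # (batch, time, features)

conv = nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=3)
pool = nn.MaxPool1d(kernel_size=2)
lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim,
               bidirectional=True, batch_first=True)

# Conv1d expects (batch, channels, time), so swap the last two axes.
conv_out = pool(torch.relu(conv(x.permute(0, 2, 1))))
# kernel_size=3 without padding shortens 20 steps to 18; pooling by 2 gives 9.
conv_shape = tuple(conv_out.shape)   # (4, 16, 9)

# Swap back to (batch, time, channels) for the batch_first LSTM.
lstm_out, (h_n, _) = lstm(conv_out.permute(0, 2, 1))
lstm_shape = tuple(lstm_out.shape)   # (4, 9, 32): both directions concatenated
h_n_shape = tuple(h_n.shape)         # (2, 4, 16): one final state per direction

print(conv_shape, lstm_shape, h_n_shape)
```

Because the convolution uses no padding, input sequences need at least 4 time steps here so that one step remains after pooling.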
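#### Setting up the loss function and optimizer
For regression problems, mean squared error (MSE loss) is a common choice of loss function, and Adam is a widely used default among gradient-based optimizers: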
```python
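# Assumes `model` is a CNNBiLSTM instance; the learning rate is an example value.
learning_rate = 1e-3
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)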
```
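As a small illustration of what `nn.MSELoss` computes, the sketch below compares it with a manual mean of squared differences and shows why the target shape should match the prediction shape. The tensor values are made up for the example.

```python
import warnings

import torch
from torch import nn

criterion = nn.MSELoss()

predictions = torch.tensor([[1.0], [2.0], [3.0]])  # shape (3, 1), like a model output
targets = torch.tensor([1.5, 2.0, 2.0])            # shape (3,), as a dataset often yields

# Matching shapes: mean of the element-wise squared differences.
matched = criterion(predictions, targets.view(-1, 1))
manual = ((predictions - targets.view(-1, 1)) ** 2).mean()

# Mismatched shapes: (3, 1) against (3,) broadcasts to (3, 3), so every
# prediction is compared with every target. PyTorch warns about this;
# the warning is silenced here only to keep the output clean.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    mismatched = criterion(predictions, targets)

print(matched.item(), manual.item(), mismatched.item())
```

The mismatched call returns a different value without raising an error, which is why reshaping the targets to `(batch, 1)` matters when the model outputs one value per sample.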
#### Training loop
Next comes the complete training process: the forward pass, loss computation, and backpropagation to update the weights:
```python
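# Assumes model (already moved to device), criterion, optimizer,
# train_loader, num_epochs, and device are defined.
for epoch in range(num_epochs):
    model.train()  # enable dropout
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs = inputs.to(device).float()
        # Reshape targets to (batch, 1) to match the model output (assumes output_dim == 1)
        labels = labels.to(device).float().view(-1, 1)

        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(inputs)            # forward pass
        loss = criterion(outputs, labels)  # MSE for this batch
        loss.backward()                    # backpropagation
        optimizer.step()                   # update the weights
        running_loss += loss.item()

    # Report the average of the per-batch losses for this epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')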
```
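The code above processes the input data in batches and calls `.to(device)` so that each batch sits on the same device (CPU or GPU) as the model; the model itself must also be moved there with `model.to(device)`. It also prints the average loss at the end of each epoch so that training progress can be monitored[^1].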
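A training loop is usually paired with an evaluation pass on held-out data. The following is a minimal sketch of one, not part of the original answer: the `evaluate` helper is hypothetical, and a zero-initialized `nn.Linear` stands in for the trained model so the example runs on its own.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, data_loader, device):
    """Return (mse, rmse) over all samples in data_loader (hypothetical helper)."""
    model.eval()  # disable dropout during evaluation
    squared_error_sum = 0.0
    num_values = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, labels in data_loader:
            inputs = inputs.to(device).float()
            labels = labels.to(device).float().view(-1, 1)
            outputs = model(inputs)
            # Accumulate sums so an uneven final batch is weighted correctly
            squared_error_sum += torch.sum((outputs - labels) ** 2).item()
            num_values += labels.numel()
    mse = squared_error_sum / num_values
    return mse, mse ** 0.5


# Stand-in model and data, for illustration only: the model always predicts 0
# and every target is 2, so each squared error is 4.
device = torch.device("cpu")
stand_in = nn.Linear(3, 1)
nn.init.zeros_(stand_in.weight)
nn.init.zeros_(stand_in.bias)

features = torch.randn(10, 3)
targets = torch.full((10,), 2.0)
loader = DataLoader(TensorDataset(features, targets), batch_size=4)

mse, rmse = evaluate(stand_in, loader, device)
print(mse, rmse)
```

Reporting RMSE alongside MSE can be convenient because it is expressed in the same units as the target variable.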