【深度学习】GRU的结构图及公式

本文深入探讨了GRU和LSTM两种循环神经网络结构的差异,详细讲解了GRU如何通过简化结构来解决长距离依赖问题,同时保持与LSTM相似的功能,适合对循环神经网络有兴趣的读者。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

GRU与LSTM的区别

前面说到过LSTM的出现是为了解决传统RNN无法解决的长距离依赖问题而出现的,而GRU也具有该功能,但是结构相对于LSTM来说相对简单,可以将GRU看作是LSTM的一种优化或变体。

GRU的结构图

在这里插入图片描述

前向传播公式

其中“*”代表矩阵乘法,“⋅”代表点乘(相应位置的元素乘相应位置的元素)
rt=sigmoid(Wr∗[ht−1,xt]+br) r_t = sigmoid(W_r*[h_{t-1},x_t] + b_r)rt=sigmoid(Wr[ht1,xt]+br)
zt=sigmoid(Wt∗[ht−1,xt]+bz)z_t = sigmoid(W_t*[h_{t-1},x_t] + b_z)zt=sigmoid(Wt[ht1,xt]+bz)
h~t=tanh(Wh~∗[rt⋅ht−1,xt]+bh~)\widetilde h_t=tanh(W_{\widetilde h}*[r_t⋅h_{t−1},x_t]+b_{\widetilde h})ht=tanh(Wh[rtht1,xt]+bh)
ht=(1−zt)⋅ht−1+zt⋅h~th_t = (1-z_t)⋅h_{t-1} + z_t⋅{\widetilde h_t}ht=(1zt)ht1+ztht
yt=sigmoid(Wo∗ht+by)y_t = sigmoid(W_o*h_t + b_y)yt=sigmoid(Woht+by)

### 绘制CNN-GRU模型结构图 对于绘制CNN-GRU模型结构图,可以利用`plot_model`函数来自Keras实用工具库。此功能允许以图形方式展示模型架构,有助于理解各层之间的连接关系以及每层的具体配置。 ```python from tensorflow.keras.utils import plot_model import matplotlib.pyplot as plt def visualize_cnn_gru_structure(model, file_path='model.png'): """ 可视化给定的CNN-GRU模型结构 参数: model : Keras Model 对象 要可视化的模型实例 file_path : str 存储生成图像文件的位置,默认为'model.png' 返回: None """ # 使用keras自带的方法保存图片到指定路径 plot_model(model, to_file=file_path, show_shapes=True) # 假设已经定义好了名为cnn_gru_model的模型 visualize_cnn_gru_structure(cnn_gru_model) plt.imshow(plt.imread('model.png')) plt.axis('off') plt.show() ``` ### 训练过程可视化 针对训练过程中的性能指标(如损失值、准确度等),可以通过matplotlib库来实现图表形式的表现。这不仅限于简单的折线图,还可以扩展至更复杂的多子图布局,以便同时监控多个评价标准的变化趋势[^1]。 ```python history = cnn_gru_model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(X_val, y_val)) def plot_training_history(history): """ 展示训练历史记录中的loss和accuracy变化情况 参数: history : History object fit返回的历史对象 返回: None """ fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5)) # Loss Plot ax[0].set_title('Loss') ax[0].plot(history.history['loss'], label='Train', color="blue") ax[0].plot(history.history['val_loss'], label='Validation', linestyle="--", color="orange") ax[0].legend(loc='upper right') # Accuracy Plot (if available in the metrics during training) if 'accuracy' in history.history.keys(): ax[1].set_title('Accuracy') ax[1].plot(history.history['accuracy'], label='Train', color="green") ax[1].plot(history.history['val_accuracy'], label='Validation', linestyle="--", color="red") ax[1].legend(loc='lower right') plt.tight_layout() plt.show() plot_training_history(history) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值