[AI Project] LSTM-Based Data Classification Experiment
This experiment distinguishes tree species from data collected in a CSV file: the feature columns of each sample are fed into an LSTM model, which outputs a predicted class.

Imports
# Imports
import numpy as np
import pandas as pd
import glob
import os
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
train_path = "./"
print(os.listdir(train_path))
['.ipynb_checkpoints', 'code.ipynb', 'data.csv', 'lstm.h5', 'plant_totaldatat.xlsx']
Read the data
# Read the CSV file
data = pd.read_csv("data.csv")
data
| | Filename | 172.538 | 173.141 | 173.744 | 174.348 | 174.951 | 175.554 | 176.157 | 176.76 | 177.363 | ... | 1165.846 | 1166.373 | 1166.9 | 1167.427 | 1167.954 | 1168.481 | 1169.008 | 1169.535 | 1170.061 | Label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 芭蕉0001.ROH | 414.421417 | 445.234558 | 482.571625 | 378.288757 | 483.976776 | 476.850617 | 423.253845 | 445.033813 | 477.653564 | ... | 487.088196 | 513.986938 | 532.956604 | 545.502625 | 504.853424 | 568.687744 | 547.811096 | 584.947449 | 564.773376 | 1 |
| 1 | 芭蕉0002.ROH | 469.523712 | 450.353333 | 447.543030 | 457.880981 | 467.616699 | 456.375458 | 483.575287 | 447.543030 | 415.224365 | ... | 560.357178 | 511.477722 | 473.337708 | 613.151001 | 513.384766 | 495.418793 | 618.771606 | 618.570923 | 495.619507 | 1 |
| 2 | 芭蕉0003.ROH | 508.265930 | 502.946411 | 522.317505 | 471.932556 | 512.682129 | 503.950104 | 498.429840 | 487.891144 | 465.910461 | ... | 597.694275 | 552.126953 | 540.885681 | 661.327881 | 553.030273 | 540.183106 | 650.889526 | 659.521240 | 550.219971 | 1 |
| 3 | 芭蕉0004.ROH | 490.801819 | 514.789917 | 529.945557 | 463.501617 | 527.536682 | 525.027466 | 489.898499 | 514.288025 | 503.247528 | ... | 567.784424 | 576.416138 | 573.906921 | 625.596680 | 548.915161 | 621.280823 | 632.221008 | 652.595825 | 599.199768 | 1 |
| 4 | 芭蕉0005.ROH | 431.383697 | 433.290680 | 436.703217 | 408.901154 | 459.386505 | 461.694977 | 453.264008 | 435.900269 | 438.810974 | ... | 535.867249 | 499.232788 | 503.849731 | 569.691467 | 518.704285 | 512.381043 | 577.921692 | 573.605835 | 520.812012 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 2422 | 樟树10096.ROH | 376.682861 | 396.656189 | 391.637787 | 371.363342 | 434.796234 | 390.634094 | 378.489502 | 410.406677 | 394.147003 | ... | 478.155395 | 441.520904 | 413.317383 | 508.265930 | 447.844116 | 424.558624 | 497.927978 | 522.618652 | 449.951874 | 7 |
| 2423 | 樟树10097.ROH | 312.647797 | 359.419495 | 336.836578 | 315.056641 | 381.299835 | 351.891876 | 333.925903 | 377.586182 | 353.397400 | ... | 434.495117 | 420.242798 | 342.758331 | 476.047668 | 397.258423 | 369.054871 | 446.639709 | 460.390167 | 384.913086 | 7 |
| 2424 | 樟树10098.ROH | 383.809052 | 372.166290 | 419.941681 | 371.363342 | 412.112946 | 411.912201 | 399.566894 | 382.905731 | 405.287903 | ... | 438.208740 | 460.089081 | 427.469330 | 478.556885 | 463.300873 | 468.620392 | 485.181183 | 525.328613 | 500.035736 | 7 |
| 2425 | 樟树10099.ROH | 327.100861 | 333.725159 | 347.676392 | 332.621124 | 376.181030 | 364.538300 | 361.727966 | 377.786926 | 347.274902 | ... | 417.934326 | 411.410370 | 377.184723 | 433.190338 | 413.919586 | 395.752899 | 432.989593 | 445.636017 | 425.562317 | 7 |
| 2426 | 樟树10100.ROH | 380.697601 | 424.859741 | 441.119446 | 388.526367 | 448.446320 | 433.089966 | 416.428803 | 431.383697 | 450.654449 | ... | 500.838684 | 497.526520 | 458.182068 | 537.673889 | 493.210663 | 465.709717 | 551.625122 | 561.059753 | 480.263153 | 7 |
2427 rows × 1757 columns
# Preview the first 5 rows
data.head()
| | Filename | 172.538 | 173.141 | 173.744 | 174.348 | 174.951 | 175.554 | 176.157 | 176.76 | 177.363 | ... | 1165.846 | 1166.373 | 1166.9 | 1167.427 | 1167.954 | 1168.481 | 1169.008 | 1169.535 | 1170.061 | Label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 芭蕉0001.ROH | 414.421417 | 445.234558 | 482.571625 | 378.288757 | 483.976776 | 476.850617 | 423.253845 | 445.033813 | 477.653564 | ... | 487.088196 | 513.986938 | 532.956604 | 545.502625 | 504.853424 | 568.687744 | 547.811096 | 584.947449 | 564.773376 | 1 |
| 1 | 芭蕉0002.ROH | 469.523712 | 450.353333 | 447.543030 | 457.880981 | 467.616699 | 456.375458 | 483.575287 | 447.543030 | 415.224365 | ... | 560.357178 | 511.477722 | 473.337708 | 613.151001 | 513.384766 | 495.418793 | 618.771606 | 618.570923 | 495.619507 | 1 |
| 2 | 芭蕉0003.ROH | 508.265930 | 502.946411 | 522.317505 | 471.932556 | 512.682129 | 503.950104 | 498.429840 | 487.891144 | 465.910461 | ... | 597.694275 | 552.126953 | 540.885681 | 661.327881 | 553.030273 | 540.183106 | 650.889526 | 659.521240 | 550.219971 | 1 |
| 3 | 芭蕉0004.ROH | 490.801819 | 514.789917 | 529.945557 | 463.501617 | 527.536682 | 525.027466 | 489.898499 | 514.288025 | 503.247528 | ... | 567.784424 | 576.416138 | 573.906921 | 625.596680 | 548.915161 | 621.280823 | 632.221008 | 652.595825 | 599.199768 | 1 |
| 4 | 芭蕉0005.ROH | 431.383697 | 433.290680 | 436.703217 | 408.901154 | 459.386505 | 461.694977 | 453.264008 | 435.900269 | 438.810974 | ... | 535.867249 | 499.232788 | 503.849731 | 569.691467 | 518.704285 | 512.381043 | 577.921692 | 573.605835 | 520.812012 | 1 |
5 rows × 1757 columns
Data analysis
data.index
RangeIndex(start=0, stop=2427, step=1)
print(data.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2427 entries, 0 to 2426
Columns: 1757 entries, Filename to Label
dtypes: float64(1755), int64(1), object(1)
memory usage: 32.5+ MB
None
# Drop rows with any missing values
data.dropna(axis=0, how='any', inplace=True)
data
| | Filename | 172.538 | 173.141 | 173.744 | 174.348 | 174.951 | 175.554 | 176.157 | 176.76 | 177.363 | ... | 1165.846 | 1166.373 | 1166.9 | 1167.427 | 1167.954 | 1168.481 | 1169.008 | 1169.535 | 1170.061 | Label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 芭蕉0001.ROH | 414.421417 | 445.234558 | 482.571625 | 378.288757 | 483.976776 | 476.850617 | 423.253845 | 445.033813 | 477.653564 | ... | 487.088196 | 513.986938 | 532.956604 | 545.502625 | 504.853424 | 568.687744 | 547.811096 | 584.947449 | 564.773376 | 1 |
| 1 | 芭蕉0002.ROH | 469.523712 | 450.353333 | 447.543030 | 457.880981 | 467.616699 | 456.375458 | 483.575287 | 447.543030 | 415.224365 | ... | 560.357178 | 511.477722 | 473.337708 | 613.151001 | 513.384766 | 495.418793 | 618.771606 | 618.570923 | 495.619507 | 1 |
| 2 | 芭蕉0003.ROH | 508.265930 | 502.946411 | 522.317505 | 471.932556 | 512.682129 | 503.950104 | 498.429840 | 487.891144 | 465.910461 | ... | 597.694275 | 552.126953 | 540.885681 | 661.327881 | 553.030273 | 540.183106 | 650.889526 | 659.521240 | 550.219971 | 1 |
| 3 | 芭蕉0004.ROH | 490.801819 | 514.789917 | 529.945557 | 463.501617 | 527.536682 | 525.027466 | 489.898499 | 514.288025 | 503.247528 | ... | 567.784424 | 576.416138 | 573.906921 | 625.596680 | 548.915161 | 621.280823 | 632.221008 | 652.595825 | 599.199768 | 1 |
| 4 | 芭蕉0005.ROH | 431.383697 | 433.290680 | 436.703217 | 408.901154 | 459.386505 | 461.694977 | 453.264008 | 435.900269 | 438.810974 | ... | 535.867249 | 499.232788 | 503.849731 | 569.691467 | 518.704285 | 512.381043 | 577.921692 | 573.605835 | 520.812012 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 2422 | 樟树10096.ROH | 376.682861 | 396.656189 | 391.637787 | 371.363342 | 434.796234 | 390.634094 | 378.489502 | 410.406677 | 394.147003 | ... | 478.155395 | 441.520904 | 413.317383 | 508.265930 | 447.844116 | 424.558624 | 497.927978 | 522.618652 | 449.951874 | 7 |
| 2423 | 樟树10097.ROH | 312.647797 | 359.419495 | 336.836578 | 315.056641 | 381.299835 | 351.891876 | 333.925903 | 377.586182 | 353.397400 | ... | 434.495117 | 420.242798 | 342.758331 | 476.047668 | 397.258423 | 369.054871 | 446.639709 | 460.390167 | 384.913086 | 7 |
| 2424 | 樟树10098.ROH | 383.809052 | 372.166290 | 419.941681 | 371.363342 | 412.112946 | 411.912201 | 399.566894 | 382.905731 | 405.287903 | ... | 438.208740 | 460.089081 | 427.469330 | 478.556885 | 463.300873 | 468.620392 | 485.181183 | 525.328613 | 500.035736 | 7 |
| 2425 | 樟树10099.ROH | 327.100861 | 333.725159 | 347.676392 | 332.621124 | 376.181030 | 364.538300 | 361.727966 | 377.786926 | 347.274902 | ... | 417.934326 | 411.410370 | 377.184723 | 433.190338 | 413.919586 | 395.752899 | 432.989593 | 445.636017 | 425.562317 | 7 |
| 2426 | 樟树10100.ROH | 380.697601 | 424.859741 | 441.119446 | 388.526367 | 448.446320 | 433.089966 | 416.428803 | 431.383697 | 450.654449 | ... | 500.838684 | 497.526520 | 458.182068 | 537.673889 | 493.210663 | 465.709717 | 551.625122 | 561.059753 | 480.263153 | 7 |
2427 rows × 1757 columns
# Drop the Filename column (it is an identifier, not a feature)
data = data.drop(['Filename'], axis=1)
data
| | 172.538 | 173.141 | 173.744 | 174.348 | 174.951 | 175.554 | 176.157 | 176.76 | 177.363 | 177.966 | ... | 1165.846 | 1166.373 | 1166.9 | 1167.427 | 1167.954 | 1168.481 | 1169.008 | 1169.535 | 1170.061 | Label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 414.421417 | 445.234558 | 482.571625 | 378.288757 | 483.976776 | 476.850617 | 423.253845 | 445.033813 | 477.653564 | 595.285400 | ... | 487.088196 | 513.986938 | 532.956604 | 545.502625 | 504.853424 | 568.687744 | 547.811096 | 584.947449 | 564.773376 | 1 |
| 1 | 469.523712 | 450.353333 | 447.543030 | 457.880981 | 467.616699 | 456.375458 | 483.575287 | 447.543030 | 415.224365 | 601.006409 | ... | 560.357178 | 511.477722 | 473.337708 | 613.151001 | 513.384766 | 495.418793 | 618.771606 | 618.570923 | 495.619507 | 1 |
| 2 | 508.265930 | 502.946411 | 522.317505 | 471.932556 | 512.682129 | 503.950104 | 498.429840 | 487.891144 | 465.910461 | 655.907959 | ... | 597.694275 | 552.126953 | 540.885681 | 661.327881 | 553.030273 | 540.183106 | 650.889526 | 659.521240 | 550.219971 | 1 |
| 3 | 490.801819 | 514.789917 | 529.945557 | 463.501617 | 527.536682 | 525.027466 | 489.898499 | 514.288025 | 503.247528 | 661.628967 | ... | 567.784424 | 576.416138 | 573.906921 | 625.596680 | 548.915161 | 621.280823 | 632.221008 | 652.595825 | 599.199768 | 1 |
| 4 | 431.383697 | 433.290680 | 436.703217 | 408.901154 | 459.386505 | 461.694977 | 453.264008 | 435.900269 | 438.810974 | 592.675842 | ... | 535.867249 | 499.232788 | 503.849731 | 569.691467 | 518.704285 | 512.381043 | 577.921692 | 573.605835 | 520.812012 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 2422 | 376.682861 | 396.656189 | 391.637787 | 371.363342 | 434.796234 | 390.634094 | 378.489502 | 410.406677 | 394.147003 | 510.775147 | ... | 478.155395 | 441.520904 | 413.317383 | 508.265930 | 447.844116 | 424.558624 | 497.927978 | 522.618652 | 449.951874 | 7 |
| 2423 | 312.647797 | 359.419495 | 336.836578 | 315.056641 | 381.299835 | 351.891876 | 333.925903 | 377.586182 | 353.397400 | 424.960113 | ... | 434.495117 | 420.242798 | 342.758331 | 476.047668 | 397.258423 | 369.054871 | 446.639709 | 460.390167 | 384.913086 | 7 |
| 2424 | 383.809052 | 372.166290 | 419.941681 | 371.363342 | 412.112946 | 411.912201 | 399.566894 | 382.905731 | 405.287903 | 531.149963 | ... | 438.208740 | 460.089081 | 427.469330 | 478.556885 | 463.300873 | 468.620392 | 485.181183 | 525.328613 | 500.035736 | 7 |
| 2425 | 327.100861 | 333.725159 | 347.676392 | 332.621124 | 376.181030 | 364.538300 | 361.727966 | 377.786926 | 347.274902 | 443.227173 | ... | 417.934326 | 411.410370 | 377.184723 | 433.190338 | 413.919586 | 395.752899 | 432.989593 | 445.636017 | 425.562317 | 7 |
| 2426 | 380.697601 | 424.859741 | 441.119446 | 388.526367 | 448.446320 | 433.089966 | 416.428803 | 431.383697 | 450.654449 | 533.458435 | ... | 500.838684 | 497.526520 | 458.182068 | 537.673889 | 493.210663 | 465.709717 | 551.625122 | 561.059753 | 480.263153 | 7 |
2427 rows × 1756 columns
# Class distribution
# Plot how many samples each label has
sns.countplot(data["Label"])
plt.xlabel("Label")
plt.title("Number of samples per label")
Text(0.5, 1.0, 'Number of samples per label')
![Label distribution countplot](https://i-blog.csdnimg.cn/blog_migrate/95ced7a31c9d59e747abe8890b1fa2a2.png)
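As an aside, recent seaborn releases (0.12+) prefer naming the column explicitly rather than passing a Series positionally; a minimal equivalent, assuming such a version is installed:
# Explicit-keyword form for newer seaborn
sns.countplot(x="Label", data=data)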
# Shuffle the rows and reset the index
df = data.sample(frac=1).reset_index(drop=True)
df
| | 172.538 | 173.141 | 173.744 | 174.348 | 174.951 | 175.554 | 176.157 | 176.76 | 177.363 | 177.966 | ... | 1165.846 | 1166.373 | 1166.9 | 1167.427 | 1167.954 | 1168.481 | 1169.008 | 1169.535 | 1170.061 | Label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 429.978546 | 448.345978 | 447.342285 | 430.380005 | 473.839569 | 442.323853 | 457.178406 | 452.862549 | 429.175598 | 582.739380 | ... | 494.415100 | 501.942718 | 467.616699 | 526.934509 | 482.170136 | 509.470367 | 526.432617 | 595.385803 | 530.246643 | 6 |
| 1 | 281.834656 | 293.979248 | 335.431427 | 288.057526 | 310.238953 | 317.365112 | 305.822723 | 321.179108 | 327.502319 | 398.462830 | ... | 329.911163 | 375.578827 | 351.691132 | 373.270355 | 349.984863 | 390.132263 | 369.255615 | 398.061371 | 381.199463 | 7 |
| 2 | 440.316498 | 426.164520 | 453.665497 | 430.480377 | 450.052216 | 461.594605 | 456.174713 | 440.517212 | 444.130493 | 607.429993 | ... | 515.291748 | 490.400360 | 489.998871 | 569.992554 | 481.567932 | 517.198731 | 559.353516 | 549.617737 | 548.614075 | 3 |
| 3 | 285.247192 | 309.737091 | 289.362305 | 302.008728 | 340.750977 | 327.000488 | 323.688324 | 345.769379 | 316.963623 | 418.034698 | ... | 338.442474 | 373.169983 | 362.530914 | 384.511627 | 340.750977 | 376.983978 | 375.578827 | 407.997833 | 391.236298 | 1 |
| 4 | 458.081696 | 475.345093 | 447.743744 | 437.606537 | 476.950989 | 457.278748 | 469.523712 | 437.305420 | 436.803589 | 594.683228 | ... | 486.887451 | 487.991516 | 502.544952 | 553.532104 | 483.374573 | 506.158203 | 554.033997 | 552.327698 | 540.785339 | 4 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 2422 | 441.822022 | 464.706024 | 461.996063 | 427.368958 | 481.266815 | 490.299988 | 465.207855 | 473.939911 | 446.037476 | 610.240295 | ... | 523.421570 | 524.927124 | 519.406860 | 577.720947 | 499.232788 | 537.473144 | 581.434570 | 584.646362 | 551.625122 | 2 |
| 2423 | 290.767456 | 264.270203 | 298.395477 | 291.670776 | 303.915741 | 331.416687 | 325.394592 | 316.762909 | 328.204895 | 404.384583 | ... | 352.795166 | 357.010651 | 372.969238 | 393.544769 | 368.553040 | 391.236298 | 399.767639 | 412.915894 | 403.180145 | 4 |
| 2424 | 303.915741 | 286.652344 | 305.722382 | 293.176300 | 329.810791 | 303.815369 | 317.967316 | 328.907471 | 294.179993 | 419.239105 | ... | 344.866058 | 354.802551 | 357.512512 | 400.972046 | 342.557587 | 383.006103 | 390.634094 | 388.024506 | 372.668152 | 1 |
| 2425 | 428.473022 | 434.394745 | 487.991516 | 408.098206 | 464.003449 | 489.396667 | 456.174713 | 431.684814 | 458.583557 | 590.066223 | ... | 528.038513 | 561.059753 | 473.237335 | 578.624268 | 537.673889 | 496.221741 | 553.732849 | 616.162048 | 509.169250 | 2 |
| 2426 | 454.468445 | 441.621277 | 450.754822 | 424.357910 | 459.286133 | 460.289825 | 469.222595 | 459.787964 | 443.729004 | 592.274353 | ... | 494.916931 | 484.779724 | 499.533875 | 555.840576 | 474.441772 | 526.231873 | 546.707092 | 568.988892 | 528.239258 | 4 |
2427 rows × 1756 columns
# Check for remaining null values
df[df.isnull().values==True]
(empty result — no null values remain)
0 rows × 1756 columns
# Split into features x and labels y
x = df.iloc[:,:-1]
y = df.iloc[:,-1]
Split the dataset
# Train/test split (80/20)
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2)
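The split above is unseeded and unstratified, so class proportions in the test set drift between runs. A minimal variant, assuming reproducibility and per-class balance are wanted:
# Hypothetical seeded, stratified split
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=42, stratify=y)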
# Standardize the features (fit the scaler on the training set only to avoid leakage)
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
scaler.fit(x_train)
X_train = scaler.transform(x_train)
X_test = scaler.transform(x_test)
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)
(1941, 1755)
(1941,)
(486, 1755)
(486,)
# Reshape to (samples, timesteps, features) and one-hot encode the labels
from keras.utils import np_utils
X_train = X_train.reshape((-1,1,1755))
Y_train = np_utils.to_categorical(y_train)
X_test = X_test.reshape((-1,1,1755))
Y_test = np_utils.to_categorical(y_test)
print(X_train.shape)
print(Y_train.shape)
print(X_test.shape)
print(Y_test.shape)
(1941, 1, 1755)
(1941, 8)
(486, 1, 1755)
(486, 8)
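`np_utils` was removed from standalone Keras in newer releases; with the `tensorflow` package already imported above, an equivalent one-hot encoding under TF 2.x would be:
# Equivalent one-hot encoding without keras.utils.np_utils (TF 2.x)
Y_train = tf.keras.utils.to_categorical(y_train)
Y_test = tf.keras.utils.to_categorical(y_test)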
Model
from keras import Sequential
from keras.layers import LSTM, Dense, Dropout
model = Sequential()
model.add(LSTM(units=50,return_sequences=True,input_shape=(1,1755)))
model.add(Dropout(0.2))
model.add(LSTM(units=50,return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(units=50,return_sequences=True))
model.add(Dropout(0.2))
# model.add(LSTM(units=50,return_sequences=True))
# model.add(Dropout(0.2))
model.add(LSTM(units=50))
model.add(Dropout(0.2))
# model.add(Dense(units=256))
# model.add(Dropout(0.2))
model.add(Dense(units=128))
model.add(Dropout(0.2))
model.add(Dense(units=64))
model.add(Dropout(0.2))
model.add(Dense(units=16))
model.add(Dropout(0.2))
model.add(Dense(units=8,activation="softmax"))
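Note that with input_shape=(1, 1755), every sample is a length-1 sequence, so the stacked LSTMs never actually step through time. A hypothetical variant (not what was trained here) would fold the 1755 feature columns into a longer sequence, e.g. 45 steps of 39 values, since 1755 = 45 × 39:
# Hypothetical reshape that gives the LSTM a real temporal axis
X_train_seq = X_train.reshape((-1, 45, 39))
X_test_seq = X_test.reshape((-1, 45, 39))
# The first LSTM layer would then take input_shape=(45, 39)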
# Callbacks: checkpointing, early stopping, and learning-rate decay on plateau
from keras.callbacks import EarlyStopping,ReduceLROnPlateau,ModelCheckpoint,LearningRateScheduler
checkpoint = ModelCheckpoint("lstm.h5",
                             monitor="val_loss",
                             mode="min",
                             save_best_only=True,
                             verbose=1)
earlystop = EarlyStopping(monitor='val_loss',
                          min_delta=0,
                          patience=5,
                          verbose=1,
                          restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                              factor=0.2,
                              patience=3,
                              verbose=1)
# min_delta=0.00001 could also be passed to ReduceLROnPlateau
callbacks = [earlystop, checkpoint, reduce_lr]
model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy'])
history_fit = model.fit(x=X_train,
                        y=Y_train,
                        batch_size=8,
                        epochs=30,
                        verbose=1,
                        validation_data=(X_test, Y_test),
                        callbacks=callbacks)
Train on 1941 samples, validate on 486 samples
Epoch 1/30
1941/1941 [==============================] - 6s 3ms/step - loss: 1.0300 - accuracy: 0.6188 - val_loss: 0.5473 - val_accuracy: 0.8313
Epoch 00001: val_loss improved from inf to 0.54729, saving model to lstm.h5
Epoch 2/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.6064 - accuracy: 0.7836 - val_loss: 0.3829 - val_accuracy: 0.8374
Epoch 00002: val_loss improved from 0.54729 to 0.38287, saving model to lstm.h5
Epoch 3/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.4797 - accuracy: 0.8089 - val_loss: 0.3595 - val_accuracy: 0.8272
Epoch 00003: val_loss improved from 0.38287 to 0.35947, saving model to lstm.h5
Epoch 4/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.4672 - accuracy: 0.8083 - val_loss: 0.2970 - val_accuracy: 0.8354
Epoch 00004: val_loss improved from 0.35947 to 0.29702, saving model to lstm.h5
Epoch 5/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3946 - accuracy: 0.8557 - val_loss: 0.2658 - val_accuracy: 0.9033
Epoch 00005: val_loss improved from 0.29702 to 0.26579, saving model to lstm.h5
Epoch 6/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3519 - accuracy: 0.8712 - val_loss: 0.2217 - val_accuracy: 0.8909
Epoch 00006: val_loss improved from 0.26579 to 0.22171, saving model to lstm.h5
Epoch 7/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3287 - accuracy: 0.8743 - val_loss: 0.2439 - val_accuracy: 0.8683
Epoch 00007: val_loss did not improve from 0.22171
Epoch 8/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3400 - accuracy: 0.8635 - val_loss: 0.2036 - val_accuracy: 0.9259
Epoch 00008: val_loss improved from 0.22171 to 0.20360, saving model to lstm.h5
Epoch 9/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3541 - accuracy: 0.8666 - val_loss: 0.2087 - val_accuracy: 0.9321
Epoch 00009: val_loss did not improve from 0.20360
Epoch 10/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3227 - accuracy: 0.8691 - val_loss: 0.2141 - val_accuracy: 0.9362
Epoch 00010: val_loss did not improve from 0.20360
Epoch 11/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2842 - accuracy: 0.8851 - val_loss: 0.1821 - val_accuracy: 0.9506
Epoch 00011: val_loss improved from 0.20360 to 0.18205, saving model to lstm.h5
Epoch 12/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3343 - accuracy: 0.8712 - val_loss: 0.2297 - val_accuracy: 0.8951
Epoch 00012: val_loss did not improve from 0.18205
Epoch 13/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3082 - accuracy: 0.8800 - val_loss: 0.2213 - val_accuracy: 0.9321
Epoch 00013: val_loss did not improve from 0.18205
Epoch 14/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2550 - accuracy: 0.9052 - val_loss: 0.1765 - val_accuracy: 0.9444
Epoch 00014: val_loss improved from 0.18205 to 0.17651, saving model to lstm.h5
Epoch 15/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.3290 - accuracy: 0.8856 - val_loss: 0.2044 - val_accuracy: 0.9383
Epoch 00015: val_loss did not improve from 0.17651
Epoch 16/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2812 - accuracy: 0.9031 - val_loss: 0.1578 - val_accuracy: 0.9465
Epoch 00016: val_loss improved from 0.17651 to 0.15778, saving model to lstm.h5
Epoch 17/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2332 - accuracy: 0.9145 - val_loss: 0.1287 - val_accuracy: 0.9547
Epoch 00017: val_loss improved from 0.15778 to 0.12870, saving model to lstm.h5
Epoch 18/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2597 - accuracy: 0.9114 - val_loss: 0.1607 - val_accuracy: 0.9280
Epoch 00018: val_loss did not improve from 0.12870
Epoch 19/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2570 - accuracy: 0.9052 - val_loss: 0.1230 - val_accuracy: 0.9671
Epoch 00019: val_loss improved from 0.12870 to 0.12305, saving model to lstm.h5
Epoch 20/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2401 - accuracy: 0.9129 - val_loss: 0.1639 - val_accuracy: 0.9588
Epoch 00020: val_loss did not improve from 0.12305
Epoch 21/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2233 - accuracy: 0.9155 - val_loss: 0.1172 - val_accuracy: 0.9671
Epoch 00021: val_loss improved from 0.12305 to 0.11718, saving model to lstm.h5
Epoch 22/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2524 - accuracy: 0.9088 - val_loss: 0.1627 - val_accuracy: 0.9588
Epoch 00022: val_loss did not improve from 0.11718
Epoch 23/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2185 - accuracy: 0.9176 - val_loss: 0.1313 - val_accuracy: 0.9342
Epoch 00023: val_loss did not improve from 0.11718
Epoch 24/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.2344 - accuracy: 0.9160 - val_loss: 0.1223 - val_accuracy: 0.9527
Epoch 00024: val_loss did not improve from 0.11718
Epoch 00024: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 25/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1890 - accuracy: 0.9274 - val_loss: 0.0862 - val_accuracy: 0.9691
Epoch 00025: val_loss improved from 0.11718 to 0.08617, saving model to lstm.h5
Epoch 26/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1475 - accuracy: 0.9361 - val_loss: 0.0794 - val_accuracy: 0.9733
Epoch 00026: val_loss improved from 0.08617 to 0.07940, saving model to lstm.h5
Epoch 27/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1507 - accuracy: 0.9392 - val_loss: 0.0673 - val_accuracy: 0.9774
Epoch 00027: val_loss improved from 0.07940 to 0.06732, saving model to lstm.h5
Epoch 28/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1498 - accuracy: 0.9444 - val_loss: 0.0764 - val_accuracy: 0.9733
Epoch 00028: val_loss did not improve from 0.06732
Epoch 29/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1513 - accuracy: 0.9423 - val_loss: 0.0733 - val_accuracy: 0.9774
Epoch 00029: val_loss did not improve from 0.06732
Epoch 30/30
1941/1941 [==============================] - 4s 2ms/step - loss: 0.1338 - accuracy: 0.9418 - val_loss: 0.0815 - val_accuracy: 0.9753
Epoch 00030: val_loss did not improve from 0.06732
Epoch 00030: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
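Because training ran all 30 epochs, EarlyStopping never fired, so the in-memory model ends with the epoch-30 weights (val_loss 0.0815) rather than the checkpointed best (0.0673). A minimal sketch for evaluating the saved best model instead:
# Reload the best-val_loss weights saved by ModelCheckpoint
from keras.models import load_model
best_model = load_model("lstm.h5")
print(best_model.evaluate(X_test, Y_test, verbose=0))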
# Plot training curves
def plot_performance(history=None, figure_directory=None, ylim_pad=[0, 0]):
    xlabel = "Epoch"
    legends = ["Training", "Validation"]
    plt.figure(figsize=(20, 5))

    # Accuracy curves
    y1 = history.history["accuracy"]
    y2 = history.history["val_accuracy"]
    min_y = min(min(y1), min(y2)) - ylim_pad[0]
    max_y = max(max(y1), max(y2)) + ylim_pad[0]
    plt.subplot(121)
    plt.plot(y1)
    plt.plot(y2)
    plt.title("Model Accuracy\n", fontsize=17)
    plt.xlabel(xlabel, fontsize=15)
    plt.ylabel("Accuracy", fontsize=15)
    plt.ylim(min_y, max_y)
    plt.legend(legends, loc="upper left")
    plt.grid()

    # Loss curves
    y1 = history.history["loss"]
    y2 = history.history["val_loss"]
    min_y = min(min(y1), min(y2)) - ylim_pad[1]
    max_y = max(max(y1), max(y2)) + ylim_pad[1]
    plt.subplot(122)
    plt.plot(y1)
    plt.plot(y2)
    plt.title("Model Loss\n", fontsize=17)
    plt.xlabel(xlabel, fontsize=15)
    plt.ylabel("Loss", fontsize=15)
    plt.ylim(min_y, max_y)
    plt.legend(legends, loc="upper left")
    plt.grid()
    plt.show()
# Visualize the training history
plot_performance(history=history_fit)
![Training and validation accuracy/loss curves](https://i-blog.csdnimg.cn/blog_migrate/ef82bd866b463b2e68dcd5fd6e2e0f45.png)
# Predict class labels on the test set
predict_y = model.predict_classes(X_test)
predict_y
array([1, 6, 4, 3, 4, 1, 4, 6, 6, 1, 1, 1, 1, 1, 1, 4, 4, 4, 4, 5, 4, 5,
7, 1, 4, 5, 3, 4, 1, 6, 4, 4, 5, 4, 1, 1, 7, 4, 1, 4, 6, 4, 4, 5,
4, 7, 7, 6, 1, 1, 5, 6, 2, 1, 4, 4, 1, 4, 4, 4, 6, 5, 2, 6, 3, 1,
2, 4, 2, 4, 1, 1, 1, 1, 1, 1, 6, 4, 1, 3, 5, 2, 4, 6, 3, 4, 4, 3,
6, 5, 7, 1, 1, 2, 7, 4, 1, 6, 6, 2, 6, 1, 3, 4, 1, 1, 1, 4, 2, 1,
3, 6, 2, 4, 4, 4, 3, 4, 1, 1, 6, 7, 6, 7, 2, 5, 1, 3, 4, 1, 3, 3,
5, 4, 4, 7, 6, 2, 6, 4, 6, 6, 3, 5, 3, 5, 6, 3, 4, 1, 3, 6, 1, 4,
6, 4, 6, 2, 2, 1, 7, 4, 6, 3, 6, 6, 5, 4, 4, 4, 4, 2, 4, 6, 1, 3,
1, 6, 6, 4, 1, 1, 4, 1, 4, 4, 2, 3, 1, 6, 4, 4, 3, 6, 5, 3, 4, 6,
1, 1, 3, 5, 4, 1, 6, 3, 4, 3, 1, 2, 1, 4, 6, 5, 3, 5, 4, 4, 4, 4,
7, 3, 1, 4, 2, 4, 6, 7, 4, 1, 4, 3, 1, 4, 1, 5, 2, 5, 3, 4, 1, 2,
4, 5, 1, 4, 4, 6, 3, 1, 4, 4, 5, 5, 6, 4, 3, 3, 1, 4, 5, 1, 1, 2,
3, 1, 1, 6, 7, 6, 4, 6, 1, 3, 4, 1, 4, 2, 7, 4, 5, 1, 4, 2, 1, 7,
3, 6, 4, 4, 1, 7, 1, 5, 4, 4, 1, 4, 4, 1, 1, 4, 1, 1, 3, 6, 3, 3,
6, 5, 4, 3, 1, 2, 6, 6, 6, 4, 2, 2, 3, 1, 5, 1, 4, 1, 7, 3, 1, 1,
3, 5, 6, 2, 4, 1, 1, 6, 1, 6, 6, 6, 7, 1, 5, 4, 2, 7, 1, 6, 3, 1,
4, 5, 2, 1, 4, 5, 6, 3, 1, 5, 1, 6, 3, 1, 3, 6, 6, 5, 1, 6, 4, 1,
7, 3, 4, 3, 7, 3, 6, 1, 5, 3, 4, 2, 4, 5, 4, 1, 1, 4, 6, 3, 6, 5,
4, 6, 1, 6, 3, 1, 4, 4, 3, 1, 5, 6, 6, 3, 5, 3, 5, 2, 1, 3, 2, 4,
1, 4, 1, 3, 7, 6, 3, 4, 4, 1, 4, 2, 1, 4, 4, 2, 1, 3, 1, 3, 4, 7,
4, 4, 1, 1, 1, 1, 4, 4, 1, 4, 5, 6, 5, 3, 3, 1, 4, 3, 2, 2, 6, 4,
4, 3, 2, 2, 1, 6, 3, 1, 3, 1, 6, 7, 4, 4, 4, 1, 1, 4, 3, 1, 4, 5,
4, 3], dtype=int64)
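`predict_classes` only exists on older Sequential models and was removed in TensorFlow 2.6+; a minimal equivalent using the already-imported numpy:
# Equivalent class prediction without predict_classes
predict_y = np.argmax(model.predict(X_test), axis=-1)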
from sklearn.metrics import accuracy_score,f1_score,confusion_matrix,classification_report
print(classification_report(y_test,predict_y))
precision recall f1-score support
1 1.00 1.00 1.00 117
2 0.97 1.00 0.99 36
3 0.99 0.88 0.93 75
4 0.92 0.99 0.95 117
5 1.00 1.00 1.00 43
6 1.00 0.99 0.99 73
7 1.00 0.96 0.98 25
accuracy 0.98 486
macro avg 0.98 0.97 0.98 486
weighted avg 0.98 0.98 0.98 486
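`confusion_matrix` is imported above but never used; a small optional follow-up (not part of the original run) that visualizes per-class errors with the already-imported seaborn:
# Optional: per-class error breakdown as a heatmap
cm = confusion_matrix(y_test, predict_y)
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
plt.xlabel("Predicted label")
plt.ylabel("True label")
plt.show()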
Summary

This article presented a tree-species classification experiment using an LSTM model. By preprocessing the collected data, extracting features, and training the model, the different tree species were distinguished effectively (about 98% test accuracy). The dataset covers several species, and the whole workflow from data preparation to model training is recorded in detail.