tensorflow笔记2-----实现鸢尾花分类

# -*- coding: utf-8 -*-
"""
使用tensorflow框架搭建网络实现鸢尾花分类
步骤:
1.准备数据
  ·加载数据集
  ·随机打乱数据集
  ·划分数据集:分为测试集与训练集
  ·将特征与标签匹配
2.定义网络中的每次迭代更新的参数:权值和偏置
3.使用梯度下降法更新参数,并在每一次记录测试集上的准确率
4.作出准确率的图像
"""
import tensorflow as tf
import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
##加载数据集
data_x=load_iris().data
data_y=load_iris().target
##随机打乱数据集
np.random.seed(100)
np.random.shuffle(data_x)
np.random.seed(100)#每一次的随机种子相同,使标签和特征匹配
np.random.shuffle(data_y)
np.random.seed(100)
##划分数据集
train_x=data_x[:-30]
train_y=data_y[:-30]
test_x=data_x[-30:]
test_y=data_y[-30:]
##转换特征的数据类型,不然运行时会因为前后数据类型不一致而出错
train_x=tf.cast(train_x,dtype=tf.float32)
test_x=tf.cast(test_x,dtype=tf.float32)
##将特征与标签匹配,每一次喂入30个训练模型
train=tf.data.Dataset.from_tensor_slices((train_x,train_y)).batch(32)
test=tf.data.Dataset.from_tensor_slices((test_x,test_y)).batch(32)
##定义网络中的参数
w=tf.Variable(tf.random.truncated_normal([4,3],stddev=0.1,seed=1))##输入特征数量为4,输出特征为3个
b=tf.Variable(tf.random.truncated_normal([3],stddev=0.1,seed=1))##偏置项,输出特征为3个
##定义存储结果的相关变量
total_epoch=500##训练次数
test_acc=[]##存储测试集上的准确率
train_loss=[]##测试集上的损失
loss_b=0##存储喂入每一batch时的损失
lr=0.1##学习率
for epoch in range(total_epoch):
    for i,(x_train,y_train) in enumerate(train):
        with tf.GradientTape() as tape:##定义计算梯度的结构
             y=tf.matmul(x_train,w)+b##输入层的输出
             y=tf.nn.softmax(y)##softmax归一化
             y_=tf.one_hot(y_train,depth=3)##将对应训练集上的标签化为one-hot形式,方便计算损失
             loss=tf.reduce_mean(tf.square(y_-y))
             loss_b+=loss.numpy()
        grad=tape.gradient(loss,[w,b])##梯度
        ##更新参数
        w.assign_sub(lr*grad[0])
        b.assign_sub(lr*grad[1])
    print("Epoch {},loss {}".format(epoch,loss_b/4))
    train_loss.append(loss_b/4)
    loss_b=0
    ##在该参数的前提下,对测试集进行预测
    correct_num,total_num=0,0
    for i,(x_test,y_test) in enumerate(test):
        y=tf.matmul(x_test,w)+b##使用更新后的参数进行预测
        y=tf.nn.softmax(y)
        pred=tf.argmax(y,axis=1)##返回分类结果对应的最大值的标签
        pred=tf.cast(pred,dtype=y_test.dtype)
        correct_num+=tf.reduce_sum(tf.cast(tf.equal(pred,y_test),dtype=tf.int32))
        total_num+=x_test.shape[0]
    ##输出准确率
    acc=correct_num/total_num
    test_acc.append(acc)
    print("Test_acc:",acc.numpy())
    print('-------------------------')
##绘制最后的图像
plt.figure(figsize=(10,8))
plt.xlabel('epoch')
plt.ylabel('train_loss')
plt.plot(train_loss,marker='.',color='r',linestyle='--',label="loss")
plt.legend(loc="best")
plt.show()
plt.figure(figsize=(10,8))
plt.xlabel('epoch')
plt.ylabel('test_acc')
plt.plot(test_acc,marker='.',color='r',linestyle='--',label="accuracy")
plt.legend(loc="best")
plt.show()  

结果如下:
在这里插入图片描述
在这里插入图片描述

安装Docker安装插件,可以按照以下步骤进行操作: 1. 首先,安装Docker。可以按照官方文档提供的步骤进行安装,或者使用适合您操作系统的包管理器进行安装。 2. 安装Docker Compose插件。可以使用以下方法安装: 2.1 下载指定版本的docker-compose文件: curl -L https://github.com/docker/compose/releases/download/1.21.2/docker-compose-`uname -s`-`uname -m` -o /usr/local/bin/docker-compose 2.2 赋予docker-compose文件执行权限: chmod +x /usr/local/bin/docker-compose 2.3 验证安装是否成功: docker-compose --version 3. 在安装插件之前,可以测试端口是否已被占用,以避免编排过程中出错。可以使用以下命令安装netstat并查看端口号是否被占用: yum -y install net-tools netstat -npl | grep 3306 现在,您已经安装Docker安装Docker Compose插件,可以继续进行其他操作,例如上传docker-compose.yml文件到服务器,并在服务器上安装MySQL容器。可以参考Docker的官方文档或其他资源来了解如何使用DockerDocker Compose进行容器的安装和配置。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [Docker安装docker-compose插件](https://blog.youkuaiyun.com/qq_50661854/article/details/124453329)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *3* [Docker安装MySQL docker安装mysql 完整详细教程](https://blog.youkuaiyun.com/qq_40739917/article/details/130891879)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值