# 一个最简单的机器学习例子
import numpy as np
import matplotlib.pyplot as plt
# 生成数据
np.random.seed(100) # 设置随机数种子,保证每次随机生成同一份数据
x = np.linspace(-1, 1, 100).reshape(100, 1)
y = 3 * np.power(x, 2) + 2 + 0.2 * np.random.rand(x.size).reshape(100, 1)
# 查看数据的分布情况
plt.scatter(x, y)
plt.show()
# 随机初始化参数
w1 = np.random.rand(1, 1)
b1 = np.random.rand(1, 1)
# 训练模型
lr = 0.001 # 学习率
for i in range(800): # 设置迭代次数,没有设置预测精确度阈值
# 前向传播
y_pred = np.power(x, 2) * w1 + b1
# 定义损失函数
loss = 0.5 * (y_pred - y) ** 2
loss = loss.sum() # 累和
#计算梯度
grad_w = np.sum((y_pred - y) * np.power(x, 2))
grad_b = np.sum((y_pred - y))
# 使用梯度下降法,是loss最小
w1 -= lr * grad_w
b1 -= lr * grad_b
# 可视化结果
plt.plot(x, y_pred, 'r-', label='predict')
plt.scatter(x, y, color='blue', marker='o', label='true') # true data
plt.xlim(-1, 1)
plt.ylim(2, 6)
plt.legend()
plt.show()
print(w1, b1)
用Numpy实现机器学习任务(一个很简单的小例子)
最新推荐文章于 2025-12-19 14:21:59 发布
该文展示了使用Python的numpy和matplotlib库进行简单的机器学习实践,通过梯度下降法训练模型来拟合一组由二次函数生成的数据点,最终可视化预测结果并与真实数据对比。
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
Python3.9
Conda
Python
Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本
1264





