昨天我们已经介绍了如何在不同的文件中,导入其他目录的文件,核心在于了解导入方式和python解释器检索目录的方式。
搞清楚了这些,那我们就可以来看看,如何把一个文件,拆分成多个具有着独立功能的文件,然后通过import的方式,来调用这些文件。这样具有几个好处:
- 可以让项目文件变得更加规范和清晰
- 可以让项目文件更加容易维护,修改某一个功能的时候,只需要修改一个文件,而不需要修改多个文件。
- 文件变得更容易复用,部分通用的文件可以单独拿出来,进行其他项目的复用。
机器学习项目的流程
一个典型的机器学习项目通常包含以下阶段:
- 数据加载:从文件、数据库、API 等获取原始数据。
- 命名参考:
load_data.py、data_loader.py
- 命名参考:
- 数据探索与可视化:了解数据特性,初期可用 Jupyter Notebook,成熟后固化绘图函数。
- 命名参考:
eda.py、visualization_utils.py
- 命名参考:
- 数据预处理:处理缺失值、异常值,进行标准化、归一化、编码等操作。
- 命名参考:
preprocess.py、data_cleaning.py、data_transformation.py
- 命名参考:
- 特征工程:创建新特征,选择、优化现有特征。
- 命名参考:
feature_engineering.py
- 命名参考:
- 模型训练:构建模型架构,设置超参数并训练,保存模型。
- 命名参考:
model.py、train.py
- 命名参考:
- 模型评估:用合适指标评估模型在测试集上的性能,生成报告。
- 命名参考:
evaluate.py
- 命名参考:
- 模型预测:用训练好的模型对新数据预测。
- 命名参考:
predict.py、inference.py
- 命名参考:
文件的组织
1. 项目核心代码组织
- src/(source的缩写):存放项目的核心源代码。按照机器学习项目阶段进一步细分:
- src/data/:放置与数据相关的代码。
src/data/load_data.py:负责从各类数据源(如文件系统、数据库、API 等)读取原始数据。src/data/preprocess.py:进行数据清洗(处理缺失值、异常值)、数据转换(标准化、归一化、编码等)操作。src/data/feature_engineering.py:根据业务和数据特点,创建新特征或对现有特征进行选择、优化。
- src/models/:关于模型的代码。
src/models/model.py:定义模型架构,比如神经网络结构、机器学习算法模型设定等。src/models/train.py:设置模型超参数,并执行训练过程,保存训练好的模型。src/models/evaluate.py:使用合适的评估指标(如准确率、召回率、均方误差等),在测试集上评估模型性能,生成评估报告。src/models/predict.py或src/models/inference.py:利用训练好的模型对新数据进行预测。
- src/utils/:存放通用辅助函数代码,可进一步细分:
src/utils/io_utils.py:包含文件读写相关帮助函数,比如读取特定格式文件、保存数据到文件等。src/utils/logging_utils.py:实现日志记录功能,方便记录项目运行过程中的信息,便于调试和监控。src/utils/math_utils.py:特定的数值计算函数,像自定义的矩阵运算、统计计算等。src/utils/plotting_utils.py:绘图工具函数,用于生成数据可视化图表(如绘制损失函数变化曲线、特征分布直方图等 )。
- src/data/:放置与数据相关的代码。
2. 配置文件管理
- config/ 目录:集中存放项目的配置文件,方便管理和切换不同环境(开发、测试、生产)的配置。
config/config.py或config/settings.py:以 Python 代码形式定义配置参数。config/config.yaml或config/config.json:采用 YAML 或 JSON 格式,清晰列出文件路径、模型超参数、随机种子、API 密钥等可配置参数。.env文件:通常放在项目根目录,用于存储敏感信息(如数据库密码、API 密钥等),在代码中通过环境变量的方式读取,一般会被.gitignore忽略,防止敏感信息泄露。
3. 实验与探索代码
-
notebooks/ 或 experiments/ 目录:用于初期的数据探索、快速实验、模型原型验证。
notebooks/initial_eda.ipynb:在项目初期,使用 Jupyter Notebook 进行数据探索与可视化,了解数据特性,分析数据分布、相关性等。experiments/model_experimentation.py:编写脚本对不同模型架构、超参数组合进行快速实验,对比实验结果,寻找最优模型设置。
这部分往往是最开始的探索阶段,后面跑通了后拆分成了完整的项目,留作纪念用。
4. 项目产出物管理
- data/ 目录:存放项目相关数据。
data/raw/:放置从外部获取的未经处理的原始数据,保持数据原始状态。data/processed/:存放经过预处理(清洗、转换、特征工程等操作)后的数据,供模型训练和评估使用。data/interim/:(可选)保存中间处理结果,比如数据清洗过程中生成的临时文件、特征工程中间步骤产生的数据等。
- models/ 目录:专门存放训练好的模型文件,根据模型保存格式不同,可能是
.pkl(Python pickle 格式,常用于保存 sklearn 模型 )、.h5(常用于保存 Keras 模型 )、.joblib等。 - reports/ 或 output/ 目录:存储项目运行产生的各类报告和输出文件。
reports/evaluation_report.txt:记录模型评估的详细结果,包括各项评估指标数值、模型性能分析等。reports/visualizations/:存放数据可视化图片,如损失函数收敛图、预测结果对比图等。output/logs/:保存项目运行日志文件,记录项目从开始到结束过程中的关键信息,如训练开始时间、训练过程中的损失值变化、预测时间等。
总结一下通用的拆分起步思路:
- 首先,按照机器学习的主要工作流程(数据处理、训练、评估等)将代码分离到不同的
.py文件中。 这是最基本也是最有价值的一步。 - 然后,创建一个
utils.py来存放通用的辅助函数。 - 考虑将所有配置参数集中到一个
config.py文件中。 - 为你的数据和模型产出物创建专门的顶层目录,如
data/和models/,将它们与你的源代码(通常放在src/目录)分开。
当遵循这些通用的拆分思路和原则时,项目结构自然会变得清晰。
注意事项
if name == “main”
常常会看到if name == "main"这个写法,实际上,每个文件都是一个对象,对象就会有属性和方法。
如果直接运行这个文件,则__name__等于__main__,若这个文件被其他模块导入,则__name__不等于__main__。
这个写法有如下好处:
-
明确程序起点:一个 Python 项目往往由多个模块组成。if name == “main” 可清晰界定程序执行的起始位置。比如一个包含数据处理模块 data_processing.py、模型训练模块 model_training.py 的机器学习项目,在 model_training.py 中用 if name == “main” 包裹训练相关的主逻辑代码,运行该文件时就知道需要从这里开始执行(其他文件都是附属文件),让项目结构和执行流程更清晰。(大多时候如此)
-
避免执行:python遵从模块导入即执行机制,当你使用 import xxx 导入一个模块时,Python 会执行该模块中的所有顶层代码(即不在任何函数或类内部的代码)。如果顶层代码中定义了全局变量或执行了某些操作(如读取文件、初始化数据库连接),这些操作会在导入时立即生效,并可能影响整个程序的状态。为了避免执行不必要的代码,我们可以使用 if name == “main” 来避免在导入时执行不必要的代码。这样,只有当模块被直接运行时(即被执行 python xxx.py),才会执行顶层代码,而导入时则不会执行。这样,我们就可以确保在导入模块时,不会执行不必要的代码,从而提高程序的性能和可维护性。
-
合理的资源管理:if name == “main” 与定义 main 函数结合使用,函数内变量在函数执行完这些变量被释放,能及时回收内存资源,避免内存泄漏,保证程序高效运行。
编码格式
规范的py文件,首行会有:# -- coding: utf-8 --
主要目的是 显式声明文件的编码格式,确保 Python 解释器能正确读取和解析文件中的非 ASCII 字符(如中文、日文、特殊符号等)。也就是说这个是写给解释器看的。
因为,在 Python 2.x 时代,默认编码是 ASCII,不支持直接在代码中写入非 ASCII 字符(如中文注释、字符串中的中文),否则会报错(SyntaxError: Non-UTF-8 code starting with…)。但是Python 3.x 默认为 UTF-8 编码,理论上可以省略编码声明。但实际开发中,为了兼容旧代码、明确文件编码规则,或在团队协作中避免因编辑器 / 环境配置不同导致的乱码问题,许多开发者仍会保留这一行声明。
ps:
- 编码声明必须出现在文件的前两行(通常是首行),否则会被忽略。
- 如果编码格式没问题,可能是vscode的编码格式不是utf-8,可以尝试修改编码格式。
- 常见的编码报错是因为字符串编码问题,可以尝试显式转化,即读取的时候转化为utf-8编码。
非 ASCII 字符的代码如下所示:
# -*- coding: utf-8 -*-
msg = "你好,世界!" # 中文字符串
print(msg)
输出:
你好,世界!
很多时候,项目中会包含gitattribute文件,来确保在不同操作系统和编辑器中,文件的编码格式一致。这里我们后面说到git工具再介绍
类型注解
Python 的类型注解是在 Python 3.5+ 引入的特性,用于为变量、函数参数、返回值和类属性等添加类型信息。虽然 Python 仍是动态类型语言,但类型注解可以提高代码可读性、可维护性,并支持静态类型检查工具(如 mypy)。
其次你在安装python插件的时候,附带安装了2个插件
- 一个是python debugger用于断点调试,我们已经介绍了
- 另一个是pylance,用于代码提示和类型检查,这个插件会根据你的代码中的类型注解,给出相应的提示和检查,比如你定义了一个函数,参数类型是int,那么当你传入一个字符串时,它会提示你传入的参数类型不正确。
变量类型注解语法为 变量名: 类型
# 变量的类型注解
name: str = "Alice"
age: int = 30
height: float = 1.75
is_student: bool = False
函数类型注解为函数参数和返回值指定类型,语法为 def 函数名(参数: 类型) -> 返回类型。
def add(a: int, b: int) -> int:
return a + b
def greet(name: str) -> None:
print(f"Hello, {name}")
类属性与方法的类型注解:为类的属性和方法添加类型信息。
# 定义一个矩形类
class Rectangle:
width: float # 矩形宽度(浮点数),类属性的类型注解(不初始化值)
height: float # 矩形高度(浮点数)
def __init__(self, width: float, height: float):
self.width = width
self.height = height
def area(self) -> float:
# 计算面积(宽度 × 高度)
return self.width * self.height
上述的width: float # 矩形宽度(浮点数)这个写法由于没有对变量赋值,所以是一种类型注解写法
将之前练习的信贷风险预测模型代码进行拆分
src\data\preprocessing.py
import pandas as pd
import numpy as np
from typing import Tuple, Dict
def load_data(file_path: str) -> pd.DataFrame:
"""加载数据文件
Args:
file_path:数据文件路径
Returns:
加载的数据框
"""
return pd.read_csv(file_path)
def encode_categorical_features(data: pd.DataFrame) -> Tuple[pd.DataFrame, Dict]:
"""对分类特征进行编码
Args:
data:原始数据集
Returns:
编码后的数据框和编码映射字典
"""
# Home Ownership 标签编码
home_ownership_mapping = {
'Own Home': 1,
'Rent': 2,
'Have Mortgage': 3,
'Home Mortgage': 4
}
# Years in current job 标签编码
years_in_job_mapping = {
'< 1 year': 1,
'1 year': 2,
'2 years': 3,
'3 years': 4,
'4 years': 5,
'5 years': 6,
'6 years': 7,
'7 years': 8,
'8 years': 9,
'9 years': 10,
'10+ years': 11
}
# Term 映射
term_mapping = {
'Short Term': 0,
'Long Term': 1
}
# 应用映射
data_encoded = data.copy()
data_encoded['Home Ownership'] = data['Home Ownership'].map(home_ownership_mapping)
data_encoded['Years in current job'] = data['Years in current job'].map(years_in_job_mapping)
data_encoded['Term'] = data['Term'].map(term_mapping)
data_encoded.rename(columns={'Term': 'Long Term'}, inplace=True)
# Purpose 独热编码
data_encoded = pd.get_dummies(data_encoded, columns=['Purpose'])
# 将独热编码列转换为整数类型
purpose_columns = [col for col in data_encoded.columns if col not in data.columns]
for col in purpose_columns:
data_encoded[col] = data_encoded[col].astype(int)
mappings = {
'home_ownership': home_ownership_mapping,
'years_in_job': years_in_job_mapping,
'term': term_mapping
}
return data_encoded, mappings
def handle_missing_values(data: pd.DataFrame) -> pd.DataFrame:
"""处理缺失值
Args:
data:包含缺失值的数据框
Returns:
补全缺失值后的数据框
"""
data_clean = data.copy()
continuous_features = data.select_dtypes(include=['int64', 'float64']).columns.tolist()
for feature in continuous_features:
mode_value = data[feature].mode()[0]
data_clean[feature].fillna(mode_value, inplace=True)
return data_clean
if __name__ == "__main__":
data = load_data("py60-stud\Ipynb\data.csv")
data_encoded, mapping = encode_categorical_features(data)
data_clean = handle_missing_values(data_encoded)
print("数据已处理完成")
src\models\train.py
# -*- coding: utf-8 -*-
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time
import joblib # 用于保存模型
from typing import Tuple # 用于类型注解
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from data.preprocessing import load_data, encode_categorical_features, handle_missing_values # 数据预处理
def prepare_data() -> Tuple:
"""准备训练数据
Returns:
训练集和测试集的特征和标签
"""
# 加载和预处理数据
data = load_data("py60-stud\Ipynb\data.csv")
data_encoded, mapping = encode_categorical_features(data)
data_clean = handle_missing_values(data_encoded)
# 分离特征和标签
X = data_clean.drop(['Credit Default'], axis=1)
y = data_clean['Credit Default']
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
return X_train, X_test, y_train, y_test
def train_model(X_train, y_train, model_params=None) -> RandomForestClassifier:
"""训练随机森林模型
Args:
X_train: 训练集特征
y_train: 训练集标签
model_params: 模型参数字典
Returns:
训练好的模型
"""
if model_params is None:
model_params = {'random_state': 42}
model = RandomForestClassifier(**model_params)
model.fit(X_train, y_train)
return model
# 模型评估
def evaluate_model(model, X_test, y_test) -> None:
"""评估模型性能
Args:
model: 训练好的模型
X_test: 测试集特征
y_test: 测试集标签
"""
# 模型预测
y_pred = model.predict(X_test)
# 分类报告
print(f"\n分类报告:\n{classification_report(y_test, y_pred)}")
# 混淆矩阵
print(f"\n混淆矩阵:\n{confusion_matrix(y_test, y_pred)}")
def save_model(model, model_path: str) -> None:
"""保存模型
Args:
model: 训练好的模型
model_path: 模型保存路径
"""
os.makedirs(os.path.dirname(model_path), exist_ok=True) # 创建保存模型文件所需的目录结构
joblib.dump(model, model_path)
print(f"模型已保存在:{model_path}")
if __name__ == "__main__":
# 准备数据
X_train, X_test, y_train, y_test = prepare_data()
# 记录开始时间
start_time = time.time()
# 模型训练
rf_model = train_model(X_train, y_train)
# 记录结束时间
end_time = time.time()
print(f"模型训练时间为:{end_time - start_time:.4f}秒。")
# 模型评估
evaluate_model(rf_model, X_test, y_test)
save_model(rf_model, "models/random_forest_model.joblib")
src\visualization\plots.py
import matplotlib.pyplot as plt
import seaborn as sns
import shap
import numpy as np
from typing import Any
def plot_feature_importance_shap(model: Any, X_test, save_path: str = None) -> None:
"""绘制SHAP特征重要性图
Args:
model: 训练好的模型
X_test: 测试数据
save_path: 图片保存路径
"""
# 初始化SHAP解释器
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
# 绘制特征重要性条形图
plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values[:, :, 0], X_test, plot_type="bar", show=False)
plt.title("SHAP特征重要性")
if save_path:
plt.savefig(save_path)
print(f"特征重要性图已保存至: {save_path}")
plt.show()
def plot_confusion_matrix(y_true, y_pred, save_path: str = None) -> None:
"""绘制混淆矩阵热力图
Args:
y_true: 真实标签
y_pred: 预测标签
save_path: 图片保存路径
"""
plt.figure(figsize=(8, 6))
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('混淆矩阵')
plt.ylabel('真实标签')
plt.xlabel('预测标签')
if save_path:
plt.savefig(save_path)
print(f"混淆矩阵图已保存至: {save_path}")
plt.show()
def set_plot_style():
"""设置绘图样式"""
plt.style.use('seaborn')
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
if __name__ == "__main__":
# 设置绘图样式
set_plot_style()
# 这里可以添加测试代码
print("可视化模块加载成功!")
92

被折叠的 条评论
为什么被折叠?



