# -*- coding: utf-8 -*-
"""
DKT-DSC for Assistment2012 (complete runnable version)
Last updated: 2024-07-01
"""
import os
import sys
import numpy as np
import tensorflow.compat.v1 as tf
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tf.disable_v2_behavior()
try:
import psutil
HAS_PSUTIL = True
except ImportError:
HAS_PSUTIL = False
print("警告: psutil模块未安装,内存监控功能受限")
from scipy.sparse import coo_matrix
# tensorflow.contrib was removed in TF 2.x; with tf.compat.v1 the equivalent RNN cells
# live under tf.nn.rnn_cell (assumed environment: TF 2.x running v1 graph mode).
rnn = tf.nn.rnn_cell
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, r2_score, roc_auc_score, roc_curve, auc
import math
import random
from datetime import datetime
import warnings
# Ignore warnings
warnings.filterwarnings('ignore')
# ==================== Configuration ====================
DATA_BASE_PATH = './data/'
data_name = 'Assist_2012'
# Simulated knowledge-graph paths (replace with real paths in actual use)
KNOWLEDGE_GRAPH_PATHS = {
'graphml': './output_assist2012_gat_improved/knowledge_graph.graphml',
'nodes': './output_assist2012_gat_improved/graph_nodes.csv',
'edges': './output_assist2012_gat_improved/graph_edges.csv'
}
# Create the mock data directories
os.makedirs(DATA_BASE_PATH, exist_ok=True)
os.makedirs(os.path.dirname(KNOWLEDGE_GRAPH_PATHS['nodes']), exist_ok=True)
# ==================== Mock Data Generation ====================
def generate_mock_data():
"""生成模拟数据用于测试"""
# 生成模拟训练数据 (300条记录)
train_data = pd.DataFrame({
'user_id': np.repeat(range(10), 30),
'problem_id': np.random.randint(1, 100, 300),
'correct': np.random.randint(0, 2, 300),
        'start_time': np.arange(300)  # simple increasing integers stand in for timestamps
})
train_data.to_csv(os.path.join(DATA_BASE_PATH, f'{data_name}_train.csv'), index=False)
    # Generate mock test data (100 records)
test_data = pd.DataFrame({
'user_id': np.repeat(range(5), 20),
'problem_id': np.random.randint(1, 100, 100),
'correct': np.random.randint(0, 2, 100),
        'start_time': np.arange(100) + 300  # timestamps continue after the training data
})
test_data.to_csv(os.path.join(DATA_BASE_PATH, f'{data_name}_test.csv'), index=False)
    # Generate mock knowledge-graph node data
node_ids = [f'problem_{i}' for i in range(1, 101)] + \
[f'concept_{i}' for i in range(1, 21)]
node_types = ['problem'] * 100 + ['concept'] * 20
mock_node_data = pd.DataFrame({
'node_id': node_ids,
'type': node_types,
'difficulty': np.random.rand(120),
'avg_accuracy': np.random.rand(120),
'total_attempts': np.random.randint(100, 1000, 120),
'avg_confidence': np.random.rand(120)
})
mock_node_data.to_csv(KNOWLEDGE_GRAPH_PATHS['nodes'], index=False)
    # Generate mock edge data
sources = np.random.choice(node_ids, 500)
targets = np.random.choice(node_ids, 500)
weights = np.random.rand(500)
mock_edge_data = pd.DataFrame({
'source': sources,
'target': targets,
'weight': weights
})
mock_edge_data.to_csv(KNOWLEDGE_GRAPH_PATHS['edges'], index=False)
# Check for data files and generate mock data if missing
if not os.path.exists(os.path.join(DATA_BASE_PATH, f'{data_name}_train.csv')):
print("[系统] 检测到缺少数据文件,正在生成模拟数据...")
generate_mock_data()
# ==================== Flags ====================
tf.flags.DEFINE_float("epsilon", 1e-8, "Adam优化器的epsilon值")
tf.flags.DEFINE_float("l2_lambda", 0.003, "L2正则化系数")
tf.flags.DEFINE_float("learning_rate", 2e-4, "学习率")
tf.flags.DEFINE_float("max_grad_norm", 5.0, "梯度裁剪阈值")
tf.flags.DEFINE_float("keep_prob", 0.7, "Dropout保留概率")
tf.flags.DEFINE_integer("hidden_layer_num", 2, "隐藏层数量")
tf.flags.DEFINE_integer("hidden_size", 64, "隐藏层大小")
tf.flags.DEFINE_integer("evaluation_interval", 1, "评估间隔周期数")
tf.flags.DEFINE_integer("batch_size", 32, "批次大小") # 减小批次大小以便在模拟数据上运行
tf.flags.DEFINE_integer("problem_len", 20, "问题序列长度")
tf.flags.DEFINE_integer("epochs", 5, "训练周期数") # 减少epoch以便快速测试
tf.flags.DEFINE_boolean("allow_soft_placement", True, "允许软设备放置")
tf.flags.DEFINE_boolean("log_device_placement", False, "记录设备放置信息")
tf.flags.DEFINE_string("train_data_path", os.path.join(DATA_BASE_PATH, f'{data_name}_train.csv'), "训练数据路径")
tf.flags.DEFINE_string("test_data_path", os.path.join(DATA_BASE_PATH, f'{data_name}_test.csv'), "测试数据路径")
FLAGS = tf.flags.FLAGS
# Focal loss parameters
FOCAL_LOSS_GAMMA = 2.0
FOCAL_LOSS_ALPHA = 0.25
# Learning-rate decay parameters
DECAY_STEPS = 100
DECAY_RATE = 0.97
# Early-stopping parameters
EARLY_STOP_PATIENCE = 3
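# Note: the focal-loss, learning-rate-decay, and early-stopping constants above are declared
# for experimentation; the loss and training loop in this script do not reference them yet.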
def memory_usage():
if HAS_PSUTIL:
try:
process = psutil.Process(os.getpid())
return process.memory_info().rss / (1024 ** 2)
        except Exception:
return 0.0
return 0.0
# ==================== Timestamp Parsing Utilities ====================
def parse_timestamp(timestamp_str):
"""尝试多种格式解析时间戳"""
if isinstance(timestamp_str, (int, float, np.number)):
return float(timestamp_str)
if isinstance(timestamp_str, str):
timestamp_str = timestamp_str.strip('"\' ')
        # Try common time formats
for fmt in ('%Y-%m-%d %H:%M:%S', '%m/%d/%Y %H:%M', '%Y-%m-%d', '%s'):
try:
                if fmt == '%s':  # plain Unix timestamp (numeric seconds)
return float(timestamp_str)
dt = datetime.strptime(timestamp_str, fmt)
return dt.timestamp()
except ValueError:
continue
return np.nan
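# Illustrative usage (hypothetical values):
#   parse_timestamp("2012-09-05 14:03:21")  -> Unix seconds as a float
#   parse_timestamp(1346853801)             -> 1346853801.0
#   parse_timestamp("not a timestamp")      -> np.nan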
# ==================== Knowledge Graph Loader ====================
class KnowledgeGraphLoader:
def __init__(self):
self.node_features = None
self.adj_matrix = None
self.problem_to_node = {}
self.node_id_map = {}
self.static_node_count = 0
self._rows = None
self._cols = None
def load(self):
print("\n[KG] 加载知识图谱...")
try:
            if not os.path.exists(KNOWLEDGE_GRAPH_PATHS['nodes']):
                raise FileNotFoundError(f"Node file not found: {KNOWLEDGE_GRAPH_PATHS['nodes']}")
            if not os.path.exists(KNOWLEDGE_GRAPH_PATHS['edges']):
                raise FileNotFoundError(f"Edge file not found: {KNOWLEDGE_GRAPH_PATHS['edges']}")
node_df = pd.read_csv(KNOWLEDGE_GRAPH_PATHS['nodes'])
self.static_node_count = len(node_df)
print(f"[KG] 总节点数: {self.static_node_count}")
            # Impute missing feature values
            print("[KG] Imputing missing feature values...")
feature_cols = [col for col in node_df.columns if col not in ['node_id', 'type']]
for col in feature_cols:
if node_df[col].isna().any():
if 'accuracy' in col or 'confidence' in col:
median_val = node_df[col].median()
node_df[col] = node_df[col].fillna(median_val)
else:
for node_type in ['problem', 'concept']:
mask = node_df['type'] == node_type
type_median = node_df.loc[mask, col].median()
node_df.loc[mask, col] = node_df.loc[mask, col].fillna(type_median)
            # Feature standardization (z-score)
raw_features = node_df[feature_cols].values
raw_features = np.nan_to_num(raw_features)
feature_mean = np.mean(raw_features, axis=0)
feature_std = np.std(raw_features, axis=0) + 1e-8
self.node_features = np.array(
(raw_features - feature_mean) / feature_std,
dtype=np.float32
)
            # Map node_id -> row index
self.node_id_map = {row['node_id']: idx for idx, row in node_df.iterrows()}
            # Map problem_id -> node row index
self.problem_to_node = {}
problem_count = 0
for idx, row in node_df.iterrows():
if row['type'] == 'problem':
try:
problem_id = int(row['node_id'].split('_')[1])
self.problem_to_node[problem_id] = idx
problem_count += 1
except (IndexError, ValueError):
continue
print(f"[KG] 已加载 {problem_count} 个问题节点映射")
            # Load the edge data
edge_df = pd.read_csv(KNOWLEDGE_GRAPH_PATHS['edges'])
rows, cols, data = [], [], []
grouped = edge_df.groupby('source')
            for src, group in tqdm(grouped, total=len(grouped), desc="Processing edges"):
src_idx = self.node_id_map.get(src, -1)
if src_idx == -1:
continue
neighbors = []
for _, row in group.iterrows():
tgt_idx = self.node_id_map.get(row['target'], -1)
if tgt_idx != -1:
neighbors.append((tgt_idx, row['weight']))
neighbors.sort(key=lambda x: x[1], reverse=True)
top_k = min(100, len(neighbors))
for i in range(top_k):
rows.append(src_idx)
cols.append(neighbors[i][0])
data.append(neighbors[i][1])
            # Add self-loops
for i in range(self.static_node_count):
rows.append(i)
cols.append(i)
data.append(1.0)
            # Build the sparse adjacency matrix
adj_coo = coo_matrix(
(data, (rows, cols)),
shape=(self.static_node_count, self.static_node_count),
dtype=np.float32
)
self.adj_matrix = adj_coo.tocsc()
self._rows = np.array(rows)
self._cols = np.array(cols)
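            # Keep the edge index arrays (including self-loops); GraphAttentionLayer gathers
            # per-edge source/target embeddings from them when computing attention scores.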
except Exception as e:
print(f"知识图谱加载失败: {str(e)}")
raise
# ==================== Graph Attention Layer ====================
class GraphAttentionLayer:
def __init__(self, input_dim, output_dim, kg_loader, scope=None):
self.kg_loader = kg_loader
self.node_count = kg_loader.static_node_count
self._rows = kg_loader._rows
self._cols = kg_loader._cols
with tf.variable_scope(scope or "GAT"):
self.W = tf.get_variable(
"W", [input_dim, output_dim],
initializer=tf.initializers.variance_scaling(
scale=0.1, mode='fan_avg', distribution='uniform')
)
self.attn_kernel = tf.get_variable(
"attn_kernel", [output_dim * 2, 1],
initializer=tf.initializers.variance_scaling(
scale=0.1, mode='fan_avg', distribution='uniform')
)
self.bias = tf.get_variable(
"bias", [output_dim],
initializer=tf.zeros_initializer()
)
def __call__(self, inputs):
inputs = tf.clip_by_value(inputs, -5, 5)
h = tf.matmul(inputs, self.W)
h = tf.clip_by_value(h, -5, 5)
h_src = tf.gather(h, self._rows)
h_dst = tf.gather(h, self._cols)
h_concat = tf.concat([h_src, h_dst], axis=1)
edge_logits = tf.squeeze(tf.matmul(h_concat, self.attn_kernel), axis=1)
edge_logits = tf.clip_by_value(edge_logits, -10, 10)
edge_attn = tf.nn.leaky_relu(edge_logits, alpha=0.2)
edge_indices = tf.constant(np.column_stack((self._rows, self._cols)), dtype=tf.int64)
sparse_attn = tf.SparseTensor(
indices=edge_indices,
values=edge_attn,
dense_shape=[self.node_count, self.node_count]
)
        # tf.sparse_softmax expects indices in canonical row-major order; the edge list built
        # above (grouped sources plus appended self-loops) is not guaranteed to be sorted,
        # so reorder first (assumed fix).
        sparse_attn = tf.sparse_reorder(sparse_attn)
        sparse_attn_weights = tf.sparse_softmax(sparse_attn)
output = tf.sparse_tensor_dense_matmul(sparse_attn_weights, h)
output = tf.clip_by_value(output, -5, 5)
output += self.bias
output = tf.nn.elu(output)
return output
# ==================== Student Knowledge Tracing Model ====================
class StudentModel:
def __init__(self, is_training, config):
        self.batch_size = config.batch_size  # static batch size taken from the config
self.batch_size_tensor = tf.placeholder(tf.int32, [], name='batch_size_placeholder')
self.num_skills = config.num_skills
self.num_steps = config.num_steps
self.current = tf.placeholder(tf.int32, [None, self.num_steps], name='current')
self.next = tf.placeholder(tf.int32, [None, self.num_steps], name='next')
self.target_id = tf.placeholder(tf.int32, [None], name='target_ids')
self.target_correctness = tf.placeholder(tf.float32, [None], name='target_correctness')
with tf.device('/gpu:0'), tf.variable_scope("KnowledgeGraph", reuse=tf.AUTO_REUSE):
kg_loader = KnowledgeGraphLoader()
kg_loader.load()
kg_node_features = tf.constant(kg_loader.node_features, dtype=tf.float32)
            # Two stacked GAT layers over the knowledge-graph node features
gat_output = kg_node_features
for i in range(2):
with tf.variable_scope(f"GAT_Layer_{i + 1}"):
dim = 64 if i == 0 else 32
gat_layer = GraphAttentionLayer(
input_dim=gat_output.shape[1] if i > 0 else kg_node_features.shape[1],
output_dim=dim,
kg_loader=kg_loader
)
gat_output = gat_layer(gat_output)
gat_output = tf.nn.leaky_relu(gat_output, alpha=0.1)
self.skill_embeddings = gat_output
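            # skill_embeddings: [node_count, 32] GAT-refined node representations used for
            # the per-step problem embedding lookups below.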
with tf.variable_scope("FeatureProcessing"):
            # Use the runtime batch size derived from the placeholder, not the static config value
batch_size = tf.shape(self.next)[0]
            # Initialization option 1: tile a zero vector to the runtime batch size
dummy_vector = tf.zeros([1, 1], dtype=tf.float32)
history_init = tf.tile(dummy_vector, [batch_size, 1])
elapsed_init = tf.tile(dummy_vector, [batch_size, 1])
            # Initialization option 2 (equivalent): fill directly
            # history_init = tf.fill([batch_size, 1], 0.0)
            # elapsed_init = tf.fill([batch_size, 1], 0.0)
current_indices = tf.minimum(self.current, kg_loader.static_node_count - 1)
current_embed = tf.nn.embedding_lookup(self.skill_embeddings, current_indices)
inputs = []
valid_mask = tf.cast(tf.not_equal(self.current, 0), tf.float32)
answers_float = tf.cast(self.next, tf.float32)
            # Initialize the answer-history and elapsed-time features
history = history_init
elapsed_time = elapsed_init
for t in range(self.num_steps):
if t > 0:
past_answers = answers_float[:, :t]
past_valid_mask = valid_mask[:, :t]
correct_count = tf.reduce_sum(past_answers * past_valid_mask, axis=1, keepdims=True)
total_valid = tf.reduce_sum(past_valid_mask, axis=1, keepdims=True)
history = correct_count / (total_valid + 1e-8)
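                    # history: fraction of correct answers over the valid (non-padded) past
                    # steps; elapsed_time below simply encodes the current step index.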
elapsed_time = tf.fill([batch_size, 1], tf.cast(t, tf.float32))
with tf.variable_scope(f"feature_extraction_t{t}"):
                    # Base feature: the GAT embedding of the current problem
current_feat = current_embed[:, t, :]
                    # Knowledge-graph feature: problem difficulty (feature column 0)
difficulty_feature = tf.gather(
kg_loader.node_features[:, 0],
tf.minimum(self.current[:, t], kg_loader.static_node_count - 1)
)
difficulty_feature = tf.reshape(difficulty_feature, [-1, 1])
                    # Affect-related features (knowledge-graph feature columns 1 and 2)
affect_features = []
for i in range(1, 3):
try:
affect_feature = tf.gather(
kg_loader.node_features[:, i],
tf.minimum(self.current[:, t], kg_loader.static_node_count - 1)
)
affect_feature = tf.reshape(affect_feature, [-1, 1])
affect_features.append(affect_feature)
except Exception as e:
tf.logging.warning(f"情感特征{i}提取失败: {str(e)}")
affect_features.append(tf.zeros_like(difficulty_feature))
                    # Make sure every feature tensor is 2-D before concatenation
features_to_concat = [current_feat, history, elapsed_time, difficulty_feature] + affect_features
features_to_concat = [
f if len(f.shape) == 2 else tf.reshape(f, [-1, 1])
for f in features_to_concat
]
                    # Optional shape-debugging prints during training
if is_training:
features_to_concat = [
tf.Print(f, [tf.shape(f)],
message=f"Feature {i} shape at step {t}: ")
for i, f in enumerate(features_to_concat)
]
combined = tf.concat(features_to_concat, axis=1)
inputs.append(combined)
        # Two-layer LSTM over the per-step feature sequence
with tf.variable_scope("RNN"):
cells = []
for i in range(2):
cell = rnn.LSTMCell(
FLAGS.hidden_size,
initializer=tf.orthogonal_initializer(),
forget_bias=1.0
)
if is_training and FLAGS.keep_prob < 1.0:
cell = rnn.DropoutWrapper(cell, output_keep_prob=FLAGS.keep_prob)
cells.append(cell)
stacked_cell = rnn.MultiRNNCell(cells)
outputs, _ = tf.nn.dynamic_rnn(
stacked_cell,
tf.stack(inputs, axis=1),
dtype=tf.float32
)
output = tf.reshape(outputs, [-1, FLAGS.hidden_size])
with tf.variable_scope("Output"):
hidden = tf.layers.dense(
output,
units=32,
activation=tf.nn.relu,
kernel_initializer=tf.initializers.glorot_uniform()
)
logits = tf.layers.dense(
hidden,
units=1,
kernel_initializer=tf.initializers.glorot_uniform()
)
self._all_logits = tf.clip_by_value(logits, -20, 20)
selected_logits = tf.gather(tf.reshape(self._all_logits, [-1]), self.target_id)
self.pred = tf.clip_by_value(tf.sigmoid(selected_logits), 1e-8, 1 - 1e-8)
with tf.variable_scope("Loss"):
labels = tf.clip_by_value(self.target_correctness, 0.05, 0.95)
pos_weight = tf.reduce_sum(1.0 - labels) / (tf.reduce_sum(labels) + 1e-8)
bce_loss = tf.nn.weighted_cross_entropy_with_logits(
targets=labels,
logits=selected_logits,
pos_weight=pos_weight
)
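            # Confidence penalty: mean squared distance of the sigmoid predictions from 0.5,
            # lightly (0.1x) discouraging over-confident outputs.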
confidence_penalty = tf.reduce_mean(
tf.square(tf.sigmoid(selected_logits) - 0.5)
)
loss = tf.reduce_mean(bce_loss) + 0.1 * confidence_penalty
l2_loss = tf.add_n([
tf.nn.l2_loss(v)
for v in tf.trainable_variables()
if 'bias' not in v.name
]) * FLAGS.l2_lambda
self.cost = loss + l2_loss
# ==================== Data Loading ====================
def read_data_from_csv_file(path, kg_loader, is_training=False):
students = []
student_ids = []
max_skill = 0
missing_problems = set()
if not os.path.exists(path):
print(f"❌ 文件不存在: {path}")
return [], [], [], 0, 0, 0
try:
print(f"[数据] 加载数据文件: {path}")
try:
data_df = pd.read_csv(path)
except Exception as e:
print(f"CSV读取失败: {str(e)}")
encodings = ['utf-8', 'latin1', 'iso-8859-1', 'cp1252']
for encoding in encodings:
try:
data_df = pd.read_csv(path, encoding=encoding)
break
except:
continue
if 'data_df' not in locals():
return [], [], [], 0, 0, 0
        # Normalize column names
possible_columns = {
'user_id': ['user_id', 'userid', 'student_id', 'studentid'],
'problem_id': ['problem_id', 'problemid', 'skill_id', 'skillid'],
'correct': ['correct', 'correctness', 'answer', 'accuracy'],
'start_time': ['start_time', 'timestamp', 'time', 'date']
}
actual_columns = {}
for col_type, possible_names in possible_columns.items():
found = False
for name in possible_names:
if name in data_df.columns:
actual_columns[col_type] = name
found = True
break
if not found:
print(f"❌ 错误: 找不到 {col_type} 列")
return [], [], [], 0, 0, 0
data_df = data_df.rename(columns={
actual_columns['user_id']: 'user_id',
actual_columns['problem_id']: 'problem_id',
actual_columns['correct']: 'correct',
actual_columns['start_time']: 'start_time'
})
        # Timestamp conversion
        print("[Data] Converting timestamps...")
timestamp_col = data_df['start_time']
if isinstance(timestamp_col.iloc[0], str):
try:
data_df['start_time'] = timestamp_col.astype(float)
except ValueError:
parsed_times = timestamp_col.apply(parse_timestamp)
nan_count = parsed_times.isna().sum()
if nan_count > 0:
print(f"⚠️ 警告: {nan_count}个时间戳无法解析,将设为0")
parsed_times = parsed_times.fillna(0)
data_df['start_time'] = parsed_times
else:
data_df['start_time'] = timestamp_col.astype(float)
        # Group by student
grouped = data_df.groupby('user_id')
        for user_id, group in tqdm(grouped, total=len(grouped), desc="Processing students"):
try:
group = enhanced_data_validation(group, kg_loader)
if group is None:
continue
problems = group['problem_id'].values
answers = group['correct'].values.astype(int)
timestamps = group['start_time'].values.astype(float)
valid_data = []
invalid_count = 0
for i, (p, a) in enumerate(zip(problems, answers)):
if p in kg_loader.problem_to_node and a in (0, 1):
valid_data.append((p, a))
else:
invalid_count += 1
if p != 0 and p not in missing_problems:
missing_problems.add(p)
if len(valid_data) < 2:
continue
problems, answers = zip(*valid_data)
n_split = (len(problems) + FLAGS.problem_len - 1) // FLAGS.problem_len
for k in range(n_split):
start = k * FLAGS.problem_len
end = (k + 1) * FLAGS.problem_len
seg_problems = list(problems[start:end])
seg_answers = list(answers[start:end])
if len(seg_problems) < FLAGS.problem_len:
pad_len = FLAGS.problem_len - len(seg_problems)
seg_problems += [0] * pad_len
seg_answers += [0] * pad_len
mapped_problems = [kg_loader.problem_to_node.get(p, 0) for p in seg_problems]
students.append(([user_id, k], mapped_problems, seg_answers))
max_skill = max(max_skill, max(mapped_problems))
student_ids.append(user_id)
except Exception as e:
print(f"处理学生 {user_id} 时出错: {str(e)}")
continue
except Exception as e:
print(f"数据加载失败: {str(e)}")
return [], [], [], 0, 0, 0
return students, [], student_ids, max_skill, 0, 0
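# Each entry of the returned `students` list is ([user_id, segment_index],
# mapped_problem_node_indices, answers), zero-padded to FLAGS.problem_len; the empty list and
# zero return values are placeholders kept so the caller's unpacking signature stays unchanged.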
def enhanced_data_validation(group, kg_loader):
"""增强数据验证"""
problems = group['problem_id'].values
timestamps = group['start_time'].values.astype(float)
valid_indices = np.where(~np.isnan(timestamps))[0]
if len(valid_indices) > 1:
time_diffs = np.diff(timestamps[valid_indices])
if np.any(time_diffs < 0):
sort_idx = np.argsort(timestamps)
group = group.iloc[sort_idx].reset_index(drop=True)
valid_mask = [p in kg_loader.problem_to_node for p in problems]
if not any(valid_mask):
return None
return group[valid_mask]
# ==================== Training Loop ====================
def run_epoch(session, model, data, run_type, eval_op, verbose=False):
"""执行一个epoch的训练或评估
Args:
session: TF会话
model: 模型对象
data: 输入数据
run_type: '训练'或'测试'
eval_op: 训练op或tf.no_op()
verbose: 是否显示详细进度
Returns:
dict: 包含loss, auc, rmse, r2的字典
"""
preds = []
labels = []
total_loss = 0.0
processed_count = 0
    # Silence TF logging noise during the epoch loop
tf.logging.set_verbosity(tf.logging.ERROR)
index = 0
batch_size = model.batch_size
    # Optionally wrap the batch starts in a tqdm progress bar (verbose mode)
    batch_starts = range(0, len(data), batch_size)
    iterator = tqdm(batch_starts, desc=f"{run_type} batches") if verbose else batch_starts
for start in iterator:
end = min(start + batch_size, len(data))
batch_data = data[start:end]
        # Assemble the batch
current_batch, next_batch, target_ids, target_correctness = [], [], [], []
        for stu_id, problems, answers in batch_data:
            valid_length = sum(1 for p in problems if p != 0)
            if valid_length < 1:
                continue
            current_batch.append(problems)
            next_batch.append(answers)
            last_step = valid_length - 1
            # Index into the flattened [batch * num_steps] logits. Use the row's position in
            # the batch actually fed, so that skipped sequences do not shift the target index.
            row = len(current_batch) - 1
            target_ids.append(row * model.num_steps + last_step)
            target_correctness.append(answers[last_step])
if not current_batch:
continue
actual_batch_size = len(current_batch)
feed_dict = {
model.current: np.array(current_batch, dtype=np.int32),
model.next: np.array(next_batch, dtype=np.int32),
model.target_id: np.array(target_ids, dtype=np.int32),
model.target_correctness: np.array(target_correctness, dtype=np.float32)
}
        try:
            # eval_op is the training op during training and tf.no_op() during evaluation, so
            # it can always be run together with the predictions and loss (comparing against a
            # freshly created tf.no_op() would never match).
            _, pred, loss = session.run(
                [eval_op, model.pred, model.cost],
                feed_dict=feed_dict
            )
preds.extend(pred.flatten().tolist())
labels.extend(target_correctness)
total_loss += loss * actual_batch_size
processed_count += actual_batch_size
except Exception as e:
print(f"\n{run_type}错误 (批次 {start}-{end}): {str(e)}", file=sys.stderr)
continue
    # Compute metrics
if processed_count == 0:
return None
avg_loss = total_loss / processed_count
    # Keep the labels binary for AUC; only clip the predictions away from exactly 0 and 1
    labels = np.array(labels)
    preds = np.clip(np.array(preds), 1e-7, 1 - 1e-7)
metrics = {
'loss': avg_loss,
'auc': roc_auc_score(labels, preds) if len(set(labels)) > 1 else 0.5,
'rmse': np.sqrt(mean_squared_error(labels, preds)),
'r2': r2_score(labels, preds)
}
return metrics
def main(_):
"""主训练流程"""
# 1. 加载配置和数据
config = ModelConfig() # 假设已定义
train_data, test_data = load_data() # 假设已定义
# 2. 构建模型
with tf.variable_scope("Model", reuse=False):
train_model = StudentModel(is_training=True, config=config)
with tf.variable_scope("Model", reuse=True):
test_model = StudentModel(is_training=False, config=config)
    # 3. Create the session
sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
with tf.Session(config=sess_config) as sess:
        # 4. Initialize variables
sess.run(tf.global_variables_initializer())
        # 5. Training loop
best_auc = 0.0
        for epoch in range(1, FLAGS.epochs + 1):
            # Training phase
            train_metrics = run_epoch(
                sess, train_model, train_data,
                'Train', train_op,  # train_op is assumed to be defined elsewhere
                verbose=(epoch % FLAGS.evaluation_interval == 0)
            )
            # Evaluation phase
            test_metrics = run_epoch(
                sess, test_model, test_data,
                'Test', tf.no_op(),
                verbose=False
            )
            # 6. Report key metrics
            print(f"Epoch {epoch}")
            print(
                f"Train - loss: {train_metrics['loss']:.4f}, RMSE: {train_metrics['rmse']:.4f}, AUC: {train_metrics['auc']:.4f}, R²: {train_metrics['r2']:.4f}")
            print(
                f"Test  - loss: {test_metrics['loss']:.4f}, RMSE: {test_metrics['rmse']:.4f}, AUC: {test_metrics['auc']:.4f}, R²: {test_metrics['r2']:.4f}")
            sys.stdout.flush()
            # 7. Save the best model so far
            if test_metrics['auc'] > best_auc:
                best_auc = test_metrics['auc']
                saver.save(sess, FLAGS.model_path)  # saver is assumed to be defined elsewhere
        print("Training complete!")
        print(f"Best test AUC: {best_auc:.4f}")
if __name__ == "__main__":
    # Generate mock data only when the real data files are missing
if not os.path.exists(FLAGS.train_data_path) or not os.path.exists(FLAGS.test_data_path):
generate_mock_data()
tf.app.run()