TensorFlow basics: static shape vs. dynamic shape, get_shape() vs. tf.shape(), set_shape() vs. tf.reshape()

This article explains the static shape and dynamic shape of tensors in TensorFlow, shows how to obtain a shape with get_shape() and tf.shape(), and how to set or change a shape with set_shape() and tf.reshape(). It also discusses practical caveats when using these methods.

#########################################################################################

1) Concepts: static shape vs. dynamic shape

In a TensorFlow graph, the tensor at each node has two shapes: a static shape and a dynamic shape.

static shape: the shape that can be obtained without running the graph. It is an inherent property of the tensor and may be (partially) unknown. It can be completed manually with additional static shape information. Whatever input the graph receives, the static shape does not change.

dynamic shape: once the graph is running, as tensors flow through it, the concrete shape of the tensor at every node can be inferred from the graph structure and the actual input; this shape is called the dynamic shape. It is always known at run time, and it changes as the shape of the graph's input changes.

If the static shape of a tensor at some node is known, then once the graph runs, the dynamic shape obtained for that tensor must be consistent with the static shape; otherwise an error is raised.
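A minimal sketch of the distinction, assuming TensorFlow 1.x (in TF 2.x the same calls exist under tf.compat.v1 with eager execution disabled):

```python
import numpy as np
import tensorflow as tf  # TF 1.x assumed; for TF 2.x use tf.compat.v1 and disable eager execution

# Static shape: fixed at graph-construction time; the first dimension is left unknown.
x = tf.placeholder(tf.float32, shape=[None, 3])
print(x.get_shape())          # (?, 3) -- known without running the graph, never changes

# Dynamic shape: a tensor that is only resolved when the graph runs.
dyn_shape = tf.shape(x)
with tf.Session() as sess:
    print(sess.run(dyn_shape, feed_dict={x: np.zeros((5, 3))}))  # [5 3]
    print(sess.run(dyn_shape, feed_dict={x: np.zeros((8, 3))}))  # [8 3] -- varies with the input
```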

#############################################################################################

2) How to get a tensor's shape: x1.get_shape() vs. tf.shape(x1)

Let x1 be a tensor at some node of the graph. Then:

shape1 = x1.get_shape(): shape1 is a TensorShape object (tuple-like; use shape1.as_list() for a Python list). It is a Python-side value that TensorFlow ops cannot consume directly. It is typically used to extract the static shape.

shape2 = tf.shape(x1): shape2 is a tf.Tensor, so you need sess.run(shape2) to obtain its concrete value. It is typically used to extract the dynamic shape.
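A minimal sketch contrasting the two accessors (TF 1.x style; x1 here is just an illustrative placeholder):

```python
import numpy as np
import tensorflow as tf  # TF 1.x assumed

x1 = tf.placeholder(tf.float32, shape=[None, 4])

shape1 = x1.get_shape()           # TensorShape([Dimension(None), Dimension(4)])
print(shape1.as_list())           # [None, 4] -- a Python-side value, available without a session

shape2 = tf.shape(x1)             # an int32 tf.Tensor; only meaningful once the graph runs
with tf.Session() as sess:
    print(sess.run(shape2, feed_dict={x1: np.ones((7, 4))}))  # [7 4]
```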

#################################################################################################

3) How to set a tensor's shape: x1.set_shape() vs. tf.reshape(x1)

Let x1 be a tensor at some node of the graph. Then:

x1.set_shape(shape1): when x1's shape is (partially) unknown, set_shape fills in x1's shape. It only concerns the static shape, because only the static shape can be unknown. If x1's static shape is already known and shape1 is inconsistent with it, an error is raised. => set_shape only completes the tensor's static shape information so that it can be used later; it does not actually change the shape. => Note: if the dynamic shape obtained at run time is inconsistent with the static shape set via set_shape, an error is raised.
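A minimal sketch of set_shape() completing an unknown static shape (TF 1.x style; the shapes are chosen purely for illustration):

```python
import tensorflow as tf  # TF 1.x assumed

x1 = tf.placeholder(tf.float32)        # shape argument omitted: static shape is <unknown>
print(x1.get_shape())                  # <unknown>

x1.set_shape([None, 28, 28, 1])        # fill in the static shape; no new tensor is created
print(x1.get_shape())                  # (?, 28, 28, 1)

# Setting an incompatible shape afterwards raises an error, e.g.:
# x1.set_shape([10, 10])               # ValueError: shapes are not compatible
```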

x2 = tf.reshape(x1, shape2): x2 is a new tensor built from the elements of x1, with shape = shape2. It succeeds only if the number of elements in x1 equals the number of elements shape2 can hold. If x2 is then used in place of x1, the effect is as if you had both set x1's shape information and actually changed x1's shape.
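A minimal sketch of tf.reshape() actually rearranging elements into a new tensor (TF 1.x style; the 784 -> 28x28x1 shapes are illustrative):

```python
import numpy as np
import tensorflow as tf  # TF 1.x assumed

x1 = tf.placeholder(tf.float32, shape=[None, 784])
x2 = tf.reshape(x1, [-1, 28, 28, 1])   # valid because 784 == 28 * 28 * 1
print(x2.get_shape())                  # (?, 28, 28, 1): static shape of the NEW tensor x2

with tf.Session() as sess:
    out = sess.run(x2, feed_dict={x1: np.zeros((3, 784))})
    print(out.shape)                   # (3, 28, 28, 1)
# tf.reshape(x1, [28, 29]) would fail: no batch size makes the element counts equal.
```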

##################################################################################################

Summary:

a) Operations under the "tf." namespace usually return a tensor. (A personal rule of thumb, not extensively verified.)

b) set_shape is often combined with tf.placeholder: a placeholder created without a shape does not constrain the shape of the fed tensor, so set_shape can be used to constrain it. This makes the input shape available to later ops and also checks, when feed_dict is executed, whether the data passed in matches the required shape. A sketch of this pattern follows below.
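A minimal sketch of the placeholder + set_shape pattern from b) (TF 1.x style; the name "images" and the 32x32x3 shape are illustrative assumptions):

```python
import numpy as np
import tensorflow as tf  # TF 1.x assumed

images = tf.placeholder(tf.float32)      # no shape given: would accept feeds of any shape
images.set_shape([None, 32, 32, 3])      # declare what every feed must look like

batch = tf.shape(images)[0]              # downstream ops can now rely on the declared rank
with tf.Session() as sess:
    print(sess.run(batch, feed_dict={images: np.zeros((4, 32, 32, 3))}))  # 4
    # Feeding data of the wrong shape now raises ValueError at sess.run time:
    # sess.run(batch, feed_dict={images: np.zeros((4, 8, 8, 3))})
```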

 

Reference: https://blog.youkuaiyun.com/qq_21949357/article/details/77987928

# dyn_obs_states = torch.cat([closest_dyn_obs_rpos_g, closest_dyn_obs_vel_g, \ # closest_dyn_obs_width.unsqueeze(1), closest_dyn_obs_height.unsqueeze(1)], dim=-1).unsqueeze(0).unsqueeze(0) dyn_obs_states = torch.cat([closest_dyn_obs_rpos_gn, closest_dyn_obs_distance_2d, closest_dyn_obs_distance_z, closest_dyn_obs_vel_g, \ closest_dyn_obs_width.unsqueeze(1), closest_dyn_obs_height.unsqueeze(1)], dim=-1).unsqueeze(0).unsqueeze(0) # states obs = TensorDict({ "agents": TensorDict({ "observation": TensorDict({ "state": drone_state, "lidar": lidar_scan, "direction": target_dir_2d, "dynamic_obstacle": dyn_obs_states }) }) }) has_obstacle_in_range = self.check_obstacle(lidar_scan, dyn_obs_states) # if (False): if (has_obstacle_in_range): if (not self.use_policy_server): with set_exploration_type(ExplorationType.MEAN): output = self.policy(obs) vel_world = output["agents", "action"] else: try: get_policy_inference = rospy.ServiceProxy("rl_navigation/GetPolicyInference", GetPolicyInference) response = get_policy_inference(obs["agents"]["observation"]["state"].cpu().numpy().flatten().tolist(), obs["agents"]["observation"]["state"].size(), obs["agents"]["observation"]["lidar"].cpu().numpy().flatten().tolist(), obs["agents"]["observation"]["lidar"].size(), obs["agents"]["observation"]["direction"].cpu().numpy().flatten().tolist(), obs["agents"]["observation"]["direction"].size(), obs["agents"]["observation"]["dynamic_obstacle"].cpu().numpy().flatten().tolist(), obs["agents"]["observation"]["dynamic_obstacle"].size()) vel_world = torch.tensor(response.action, device=self.cfg.device, dtype=torch.float).unsqueeze(0).unsqueeze(0) except rospy.service.ServiceException as e: print("[nav-ros]: Policy server err!") vel_world = torch.tensor([0., 0., 0.], device=self.cfg.device).unsqueeze(0).unsqueeze(0) else: vel_world = (goal - pos)/torch.norm(goal - pos) * self.cfg.algo.actor.action_limit return vel_world def get_rollout_traj(self, pos: torch.Tensor, vel: torch.Tensor, goal: torch.Tensor, dt=0.1, horizon=3.0): traj = [pos.cpu().detach().numpy()] t = 0. 
while (t < horizon): vel_curr_world = self.get_action(pos, vel, goal) t += dt pos = (pos + dt * vel_curr_world).squeeze(0).squeeze(0) vel = vel_curr_world.squeeze(0).squeeze(0) traj.append(pos.cpu().detach().numpy()) return np.array(traj) def control_callback(self, event): if (not self.odom_received): return # 先刷新一次感知数据(用你已有的两个回调) self.raycast_callback(None) self.dynamic_obstacle_callback(None) if (not self.goal_received or len(self.raypoints) == 0 or len(self.dynamic_obstacles) == 0): stop = Twist() self.cmd_pub.publish(stop) return if (self.safety_stop): stop = Twist() self.cmd_pub.publish(stop) return start_time = time.time() # check for angle goal_angle = np.arctan2(self.target_dir[1].cpu().numpy(), self.target_dir[0].cpu().numpy()) _, _, curr_angle = tf.transformations.euler_from_quaternion([self.odom.pose.pose.orientation.x, self.odom.pose.pose.orientation.y, self.odom.pose.pose.orientation.z, self.odom.pose.pose.orientation.w]) angle_diff = np.abs(goal_angle - curr_angle) if (angle_diff > math.pi): angle_diff = np.abs(angle_diff - math.pi * 2) if (angle_diff >= 0.1): twist_align = Twist() ang_err = goal_angle - curr_angle ang_err = (ang_err + math.pi) % (2 * math.pi) - math.pi k_yaw = 1.0 twist_align.angular.z = float(np.clip(k_yaw * ang_err, -0.8, 0.8)) self.cmd_pub.publish(twist_align) return else: self.stable_times += 1 if (self.stable_times <= 10): # 让姿态稳定几拍:发零速 self.cmd_pub.publish(Twist()) return pos = torch.tensor([self.odom.pose.pose.position.x, self.odom.pose.pose.position.y, self.odom.pose.pose.position.z], device=self.cfg.device) goal = torch.tensor([self.goal.pose.position.x, self.goal.pose.position.y, self.goal.pose.position.z], device=self.cfg.device) orientation = torch.tensor([self.odom.pose.pose.orientation.w, self.odom.pose.pose.orientation.x, self.odom.pose.pose.orientation.y, self.odom.pose.pose.orientation.z], device=self.cfg.device) rot = self.quaternion_to_rotation_matrix(self.odom.pose.pose.orientation) vel_body = np.array([self.odom.twist.twist.linear.x, self.odom.twist.twist.linear.y, self.odom.twist.twist.linear.z]) vel_world = torch.tensor(rot @ vel_body, device=self.cfg.device, dtype=torch.float) # world vel # get RL action from model cmd_vel_world = self.get_action(pos, vel_world, goal).squeeze(0).squeeze(0).detach().cpu().numpy() self.cmd_vel_world = cmd_vel_world.copy() # 平面约束 self.cmd_vel_world[2] = 0.0 # get safe action safe_cmd_vel_world = self.get_safe_action(vel_world, self.cmd_vel_world) self.safe_cmd_vel_world = safe_cmd_vel_world.copy() self.safe_cmd_vel_world[2] = 0.0 quat_no_tilt = tf.transformations.quaternion_from_euler(0, 0, curr_angle) quat_msg = Quaternion() quat_msg.w = quat_no_tilt[3] quat_msg.x = quat_no_tilt[0] quat_msg.y = quat_no_tilt[1] quat_msg.z = quat_no_tilt[2] rot_no_tilt = self.quaternion_to_rotation_matrix(quat_msg) safe_cmd_vel_local = np.linalg.inv(rot_no_tilt) @ safe_cmd_vel_world # Goal condition distance = (pos - goal).norm() if (distance <= 3. and distance > 0.3): if (np.linalg.norm(safe_cmd_vel_local) != 0): safe_cmd_vel_local = 0.5 * safe_cmd_vel_local/np.linalg.norm(safe_cmd_vel_local) safe_cmd_vel_world = 0.5 * safe_cmd_vel_world/np.linalg.norm(safe_cmd_vel_world) elif (distance <= 1.0): safe_cmd_vel_local *= 0. safe_cmd_vel_world *= 0. 
# final action final_cmd_vel = Twist() # 线速度(平面) final_cmd_vel.linear.x = float(safe_cmd_vel_local[0]) final_cmd_vel.linear.y = 0.0 # 如果是差速底盘,这行改成 0.0 final_cmd_vel.linear.z = 0.0 # 简单航向控制,给角速度 ang_err = goal_angle - curr_angle ang_err = (ang_err + math.pi) % (2 * math.pi) - math.pi k_yaw = 1.0 final_cmd_vel.angular.z = float(np.clip(k_yaw * ang_err, -0.8, 0.8)) # 先赋值,后发布 self.action_pub.publish(final_cmd_vel) # rollout_traj = self.get_rollout_traj(pos, vel_world, goal, dt=0.1, horizon=3.0) # traj_msg = Path() # traj_msg.header.frame_id = "map" # for i in range(len(rollout_traj)): # p = PoseStamped() # p.pose.position.x = rollout_traj[i][0] # p.pose.position.y = rollout_traj[i][1] # p.pose.position.z = rollout_traj[i][2] # traj_msg.poses.append(p) # self.rollout_traj_pub.publish(traj_msg) end_time = time.time() # print("[nav-ros]: control time ", end_time - start_time) def pause_sim(): rospy.wait_for_service('/gazebo/pause_physics') pause = rospy.ServiceProxy('/gazebo/pause_physics', Empty) pause() def unpause_sim(): rospy.wait_for_service('/gazebo/unpause_physics') unpause = rospy.ServiceProxy('/gazebo/unpause_physics', Empty) unpause() def run(self): raycast_timer = rospy.Timer(rospy.Duration(0.05), self.raycast_callback) raycast_vis_timer = rospy.Timer(rospy.Duration(0.05), self.raycast_vis_callback) control_timer = rospy.Timer(rospy.Duration(0.05), self.control_callback) goal_vis_timer = rospy.Timer(rospy.Duration(0.05), self.goal_vis_callback) dynamic_obstacle_timer = rospy.Timer(rospy.Duration(0.05), self.dynamic_obstacle_callback) dynamic_obstacle_vis_timer = rospy.Timer(rospy.Duration(0.05), self.dynamic_obstacle_vis_callback) cmd_vis_timer = rospy.Timer(rospy.Duration(0.05), self.cmd_vis_callback) def raycast_vis_callback(self, event): if not self.odom_received and not self.goal_received: return msg = MarkerArray() pos = self.odom.pose.pose.position direction_init = None for i in range(len(self.raypoints)): point = Marker() point.header.frame_id = "map" point.header.stamp = rospy.get_rostime() point.ns = "raycast_points" point.id = i point.type = point.SPHERE point.action = point.ADD point.pose.position.x = self.raypoints[i][0] point.pose.position.y = self.raypoints[i][1] point.pose.position.z = self.raypoints[i][2] point.lifetime = rospy.Time(0.5) point.scale.x = 0.1 point.scale.y = 0.1 point.scale.z = 0.1 point.color.a = 1.0 point.color.r = 1.0 msg.markers.append(point) line = Marker() line.header.frame_id = "map" line.header.stamp = rospy.get_rostime() line.ns = "raycast_lines" line.id = i line.type = line.LINE_LIST p = Point() p.x = self.raypoints[i][0] p.y = self.raypoints[i][1] p.z = self.raypoints[i][2] line.points.append(p) line.points.append(pos) line.scale.x = 0.03 line.scale.y = 0.03 line.scale.z = 0.03 x_diff = (p.x - self.odom.pose.pose.position.x) y_diff = (p.y - self.odom.pose.pose.position.y) direction = np.array([x_diff, y_diff]) direction = direction/np.linalg.norm(direction) if (i == 0 or (np.linalg.norm(direction - direction_init) <= 0.1)): line.color.b = 1.0 line.color.a = 1.0 if (i == 0): direction_init = direction else: line.color.g = 1.0 line.color.a = 0.5 line.lifetime = rospy.Time(0.5) msg.markers.append(line) self.raycast_vis_pub.publish(msg) def goal_vis_callback(self, event): if not self.goal_received: return msg = MarkerArray() goal_point = Marker() goal_point.header.frame_id = "map" goal_point.header.stamp = rospy.get_rostime() goal_point.ns = "goal_point" goal_point.id = 1 goal_point.type = goal_point.SPHERE goal_point.action = goal_point.ADD 
goal_point.pose.position.x = self.goal.pose.position.x goal_point.pose.position.y = self.goal.pose.position.y goal_point.pose.position.z = self.goal.pose.position.z goal_point.lifetime = rospy.Time(0.1) goal_point.scale.x = 0.3 goal_point.scale.y = 0.3 goal_point.scale.z = 0.3 goal_point.color.r = 1.0 goal_point.color.b = 1.0 goal_point.color.a = 1.0 msg.markers.append(goal_point) self.goal_vis_pub.publish(msg) def dynamic_obstacle_vis_callback(self, event): if (len(self.dynamic_obstacles) == 0): return dynamic_obstacle_pos = self.dynamic_obstacles[0] dynamic_obstacle_size = self.dynamic_obstacles[2] msg = MarkerArray() for i in range(dynamic_obstacle_pos.size(0)): pos = dynamic_obstacle_pos[i] size = dynamic_obstacle_size[i] # Increase the width width = torch.max(size[0], size[1]) height = size[2] # Create the marker marker = Marker() marker.header.frame_id = "map" marker.header.stamp = rospy.Time.now() marker.ns = "dynamic_obstacles" marker.id = i marker.type = Marker.CUBE marker.action = Marker.ADD marker.pose.position.x = pos[0] marker.pose.position.y = pos[1] marker.pose.position.z = pos[2] marker.pose.orientation.x = 0.0 marker.pose.orientation.y = 0.0 marker.pose.orientation.z = 0.0 marker.pose.orientation.w = 1.0 marker.scale.x = width marker.scale.y = width marker.scale.z = height marker.color.a = 0.5 # Alpha value marker.color.r = 1.0 # Red color marker.color.g = 0.0 marker.color.b = 0.0 msg.markers.append(marker) # Publish the marker array self.dynamic_obstacle_vis_pub.publish(msg) def cmd_vis_callback(self, event): if (not self.has_action): return msg = MarkerArray() # rl action vis rl_action_arrow = Marker() rl_action_arrow.header.frame_id = "map" rl_action_arrow.header.stamp = rospy.get_rostime() rl_action_arrow.ns = "rl_action" rl_action_arrow.id = 0 rl_action_arrow.type = rl_action_arrow.ARROW rl_action_arrow.action = rl_action_arrow.ADD # start agent_pos = Point() agent_pos.x = self.odom.pose.pose.position.x agent_pos.y = self.odom.pose.pose.position.y agent_pos.z = self.odom.pose.pose.position.z # end vel_end = Point() vel_end.x = self.cmd_vel_world[0] + agent_pos.x vel_end.y = self.cmd_vel_world[1] + agent_pos.y vel_end.z = self.cmd_vel_world[2] + agent_pos.z rl_action_arrow.points.append(agent_pos) rl_action_arrow.points.append(vel_end) rl_action_arrow.lifetime = rospy.Duration(0.1) rl_action_arrow.scale.x = 0.06 rl_action_arrow.scale.y = 0.06 rl_action_arrow.scale.z = 0.06 rl_action_arrow.color.a = 1.0 rl_action_arrow.color.r = 1.0 rl_action_arrow.color.g = 0.0 rl_action_arrow.color.b = 0.0 msg.markers.append(rl_action_arrow) # safe action vis safe_action_arrow = Marker() safe_action_arrow.header.frame_id = "map" safe_action_arrow.header.stamp = rospy.get_rostime() safe_action_arrow.ns = "safe_action" safe_action_arrow.id = 1 safe_action_arrow.type = safe_action_arrow.ARROW safe_action_arrow.action = safe_action_arrow.ADD # start agent_pos = Point() agent_pos.x = self.odom.pose.pose.position.x agent_pos.y = self.odom.pose.pose.position.y agent_pos.z = self.odom.pose.pose.position.z # end vel_end = Point() vel_end.x = self.safe_cmd_vel_world[0] + agent_pos.x vel_end.y = self.safe_cmd_vel_world[1] + agent_pos.y vel_end.z = self.safe_cmd_vel_world[2] + agent_pos.z safe_action_arrow.points.append(agent_pos) safe_action_arrow.points.append(vel_end) safe_action_arrow.lifetime = rospy.Duration(0.1) safe_action_arrow.scale.x = 0.06 safe_action_arrow.scale.y = 0.06 safe_action_arrow.scale.z = 0.06 safe_action_arrow.color.a = 1.0 safe_action_arrow.color.r = 0.0 
safe_action_arrow.color.g = 1.0 safe_action_arrow.color.b = 0.0 msg.markers.append(safe_action_arrow) self.cmd_vis_pub.publish(msg)
In PyTorch, `prompt_embed.reshape()` and `prompt_embed.shape` are the usual ways to change and inspect a tensor's shape (TensorFlow tensors expose `.shape` the same way, but reshaping there is done with the `tf.reshape()` function rather than a tensor method). Details and examples below:

---

### 1. **`prompt_embed.reshape()`**

Changes the tensor's shape (dimension layout) **without changing the underlying data**.

**Syntax**:
```python
reshaped_tensor = prompt_embed.reshape(new_shape)
```
- **`new_shape`**: an integer or a tuple such as `(a, b, c)`; `-1` may be used to infer one dimension automatically.

**Example**:
```python
import torch

# Suppose prompt_embed is a 2D tensor
prompt_embed = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("Original shape:", prompt_embed.shape)  # Output: torch.Size([2, 3])

# Reshape to 3x2
reshaped = prompt_embed.reshape(3, 2)
print("Reshaped:\n", reshaped)
# Output:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

# Use -1 to infer the dimension
reshaped_auto = prompt_embed.reshape(-1, 6)  # becomes 1x6
print("Auto-reshaped:\n", reshaped_auto)
```

---

### 2. **`prompt_embed.shape`**

Returns the tensor's current shape (an attribute, not a method, so no parentheses).

**Syntax**:
```python
current_shape = prompt_embed.shape  # note: no parentheses!
```
- Returns a tuple-like `torch.Size` object (e.g. `torch.Size([2, 3])`) giving the size of each dimension.

**Example**:
```python
print("Shape:", prompt_embed.shape)  # Output: torch.Size([2, 3])
```

---

### Common pitfalls

1. **Wrong usage**:
    ```python
    # Wrong! shape is an attribute, not a method
    wrong = prompt_embed.shape()  # raises TypeError
    ```
2. **Difference from `view()`**:
    - `reshape()` is more flexible and may return a copy (when the memory is not contiguous).
    - `view()` requires contiguous memory; otherwise call `.contiguous()` first.
3. **Dimension compatibility**:
    The total number of elements in the new shape must equal that of the original shape (unless `-1` is used to infer one dimension).

---

### Complete example

```python
import torch

# Create a tensor
prompt_embed = torch.arange(6).reshape(2, 3)  # 2x3
print("Original:\n", prompt_embed)
print("Shape:", prompt_embed.shape)  # torch.Size([2, 3])

# Reshape
reshaped = prompt_embed.reshape(3, 2)
print("Reshaped:\n", reshaped)

# Error example (element counts don't match)
try:
    invalid = prompt_embed.reshape(4, 4)  # raises RuntimeError
except RuntimeError as e:
    print("Error:", e)
```

---
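The `reshape()` vs `view()` point above is easiest to see with a non-contiguous tensor. The following is a minimal sketch, assuming a recent PyTorch version; the variable names (`x`, `t`, `flat_r`, `flat_v`) are made up for illustration and are not from the original text.

```python
import torch

x = torch.arange(6).reshape(2, 3)  # contiguous 2x3 tensor: [[0,1,2],[3,4,5]]
t = x.t()                          # transpose -> non-contiguous view, shape (3, 2)
print(t.is_contiguous())           # False

# reshape() still works: it silently falls back to copying the data
flat_r = t.reshape(6)
print(flat_r)                      # tensor([0, 3, 1, 4, 2, 5])

# view() on the same tensor fails because the memory is not contiguous
try:
    flat_v = t.view(6)
except RuntimeError as e:
    print("view() failed:", e)

# calling .contiguous() first makes view() legal again
flat_v = t.contiguous().view(6)
print(flat_v)                      # same values as flat_r
```

A practical rule of thumb: use `reshape()` when you simply need the new shape, and reserve `view()` for cases where you want a guarantee that no copy is made.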