TensorFlow API: tf.zeros(), tf.ones(), tf.fill(), tf.constant()

This article walks through the various ways of creating constant tensors in TensorFlow, covering the concrete usage of, and caveats around, tf.zeros(), tf.ones(), tf.fill(), and tf.constant().

One thing in the code below I don't understand: `import basic.util.prints` — I cannot find this `basic` package anywhere. If you know where it comes from, please leave a comment, thanks.

In the examples below, you can use print(data.eval()) to print the result of each tensor.
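For example, a minimal sketch (assuming the TensorFlow 1.x API used throughout this post):

#coding=utf8
import tensorflow as tf

# With an InteractiveSession active, Tensor.eval() runs in that session
sess = tf.InteractiveSession()
data = tf.zeros([2, 3], dtype=tf.int32)
print(data.eval())
# [[0 0 0]
#  [0 0 0]]
sess.close()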


Constant-value tensor functions
  • tf.zeros(shape, dtype=tf.float32, name=None)
  • tf.zeros_like(tensor, dtype=None, name=None)
  • tf.ones(shape, dtype=tf.float32, name=None)
  • tf.ones_like(tensor, dtype=None, name=None)
  • tf.fill(dims, value, name=None)
  • tf.constant(value, dtype=None, shape=None, name='Const')

In TensorFlow, any value that takes part in an operation must first be converted into a corresponding TensorFlow data type. A plain Python string, for example, cannot be used in TensorFlow computations directly; it has to be converted to a TensorFlow type first.

TensorFlow therefore provides a set of functions for creating constant tensors; the list above covers the basic ones.
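As a quick illustration of such a conversion, here is a sketch using tf.string_to_number from the TF 1.x API (the input string is arbitrary):

#coding=utf8
import tensorflow as tf

sess = tf.InteractiveSession()
# A Python string cannot be used in arithmetic directly; convert it to a
# numeric tensor first
number = tf.string_to_number("3.14", out_type=tf.float32)
print(number.eval())  # 3.14
sess.close()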

tf.zeros(shape, dtype=tf.float32, name=None)

Creates a tensor with every element set to 0.

This operation returns a tensor of type dtype with shape shape and all elements set to zero.

Parameters:
  • shape: the shape of the tensor, usually an int32 list, or a 1-D tf.int32 tensor. Note that a bare number is not accepted
  • dtype: the data type of the tensor to create
  • name: an optional name for the operation
Returns:

A tensor with all elements set to 0

Example and results:

#coding=utf8
import tensorflow as tf
import basic.util.prints as p

sess = tf.InteractiveSession()

# Create a rank-1 int32 tensor
data = tf.zeros([1], dtype=tf.int32)
p.printValue("tf.zeros([1], dtype=tf.int32)", data)
# Create a rank-3 int32 tensor
data = tf.zeros([1,2,1], dtype=tf.int32)
p.printValue("tf.zeros([1,2,1], dtype=tf.int32)", data)
# double
data = tf.zeros([8], dtype=tf.double)
p.printValue("tf.zeros([8], dtype=tf.double)", data)
# float
data = tf.zeros([8], dtype=tf.float16)
p.printValue("tf.zeros([8], dtype=tf.float16)", data)


sess.close()
# tf.zeros([1], dtype=tf.int32) : Tensor("zeros:0", shape=(1,), dtype=int32) - 
[0]
# tf.zeros([1,2,1], dtype=tf.int32) : Tensor("zeros_1:0", shape=(1, 2, 1), dtype=int32) - 
[[[0]
  [0]]]
# tf.zeros([8], dtype=tf.double) : Tensor("zeros_2:0", shape=(8,), dtype=float64) - 
[ 0.  0.  0.  0.  0.  0.  0.  0.]
# tf.zeros([8], dtype=tf.float16) : Tensor("zeros_3:0", shape=(8,), dtype=float16) - 
[ 0.  0.  0.  0.  0.  0.  0.  0.]
Also note that you cannot pass a bare number as the shape argument; for example, the following raises an error:
...
# shape cannot be a bare int; it must be a list or a 1-D tensor
data = tf.zeros(1, dtype=tf.int32)
p.printValue("tf.zeros(1, dtype=tf.int32)", data)
...
Traceback (most recent call last):
  File "/services/git/GIthub/ml-example/tensorflow/basic/casting/string_to_number.py", line 20, in <module>
    data = tf.zeros(1, dtype=tf.int32)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 623, in zeros
    output = fill(shape, constant(0, dtype=dtype), name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 531, in fill
    return _op_def_lib.apply_op("Fill", dims=dims, value=value, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 655, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2156, in create_op
    set_shapes_for_outputs(ret)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1612, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 1165, in _FillShape
    dimensions_shape = op.inputs[0].get_shape().with_rank(1)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/tensor_shape.py", line 625, in with_rank
    raise ValueError("Shape %s must have rank %d" % (self, rank))
ValueError: Shape () must have rank 1
This can be solved in either of two ways (the second is sketched after this block). The first is to wrap the scalar in a list:
# wrap the bare number in a list
data = tf.zeros([1], dtype=tf.int32)
p.printValue("tf.zeros([1], dtype=tf.int32)", data)

tf.zeros_like(tensor, dtype=None, name=None)

This method creates a tensor with every element set to 0.

Given a tensor, it returns a tensor of the same type and shape as that argument, but with every element set to 0. If dtype is specified, the elements of the returned tensor take that type instead.

The method is effectively a copy function: by default it copies the type and shape of the argument tensor and sets every value to 0; when dtype is set, the copied tensor uses that dtype instead.

Parameters:

  • tensor: a tensor object
  • dtype: the type of the returned tensor; when unset, the returned type matches that of tensor. If set, it must be one of the following TensorFlow types: float32, float64, int8, int16, int32, int64, uint8, or complex64.

  • name: an optional name for the operation.

Returns:

A tensor with all elements set to 0

Test case and results:

#coding=utf8
import tensorflow as tf
import basic.util.prints as p

sess = tf.InteractiveSession()

# create a tensor object
original = [[1,2,3],[4,5,6]]
p.printRowValue("Original value", original)

# call zeros_like; the dtype defaults to that of the input (int32 here)
data = tf.zeros_like(original)
p.printValue("tf.zeros_like(original)", data)

# call zeros_like with an explicit dtype of double
data = tf.zeros_like(original, dtype=tf.double)
p.printValue("tf.zeros_like(original, dtype=tf.double)", data)

# supported dtypes: float, int, double, uint, and complex; all of the following work
data = tf.zeros_like(original, dtype=tf.float16)
data = tf.zeros_like(original, dtype=tf.float16_ref)
data = tf.zeros_like(original, dtype=tf.float32)
data = tf.zeros_like(original, dtype=tf.float32_ref)
data = tf.zeros_like(original, dtype=tf.float64)
data = tf.zeros_like(original, dtype=tf.float64_ref)
data = tf.zeros_like(original, dtype=tf.int8)
data = tf.zeros_like(original, dtype=tf.int8_ref)
data = tf.zeros_like(original, dtype=tf.int16)
data = tf.zeros_like(original, dtype=tf.int16_ref)
data = tf.zeros_like(original, dtype=tf.int32)
data = tf.zeros_like(original, dtype=tf.int32_ref)
data = tf.zeros_like(original, dtype=tf.int64)
data = tf.zeros_like(original, dtype=tf.int64_ref)
data = tf.zeros_like(original, dtype=tf.uint8)
data = tf.zeros_like(original, dtype=tf.uint8_ref)
data = tf.zeros_like(original, dtype=tf.uint16)
data = tf.zeros_like(original, dtype=tf.uint16_ref)
data = tf.zeros_like(original, dtype=tf.double)
data = tf.zeros_like(original, dtype=tf.double_ref)
data = tf.zeros_like(original, dtype=tf.complex64)
data = tf.zeros_like(original, dtype=tf.complex64_ref)
data = tf.zeros_like(original, dtype=tf.complex128)
data = tf.zeros_like(original, dtype=tf.complex128_ref)

# [ERROR] unsupported dtypes include: bfloat, qint, quint
# data = tf.zeros_like(original, dtype=tf.bfloat16)
# data = tf.zeros_like(original, dtype=tf.quint8)
# data = tf.zeros_like(original, dtype=tf.qint16)
# data = tf.zeros_like(original, dtype=tf.qint32)


sess.close()
# Original value : [[1, 2, 3], [4, 5, 6]]
# tf.zeros_like(original) : Tensor("zeros_like:0", shape=(2, 3), dtype=int32) - 
[[0 0 0]
 [0 0 0]]
# tf.zeros_like(original, dtype=tf.double) : Tensor("zeros_like_1:0", shape=(2, 3), dtype=float64) - 
[[ 0.  0.  0.]
 [ 0.  0.  0.]]
This method is very useful when you need a zeroed-out copy of an existing tensor.
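For instance, a minimal sketch (the weights tensor here is an illustrative assumption) of creating a zero tensor that matches an existing one:

#coding=utf8
import tensorflow as tf

sess = tf.InteractiveSession()
# A zero accumulator with the same shape and dtype as an existing tensor,
# without having to spell the shape out again
weights = tf.constant([[0.5, -1.2], [3.0, 0.7]])
accumulator = tf.zeros_like(weights)
print(accumulator.eval())
# [[ 0.  0.]
#  [ 0.  0.]]
sess.close()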

tf.ones(shape, dtype=tf.float32, name=None)

Creates a tensor with every element set to 1.

This operation returns a tensor of type dtype with the given shape and all elements set to 1.

Parameters:

  • shape: the shape of the tensor, usually an int32 list, or a 1-D tf.int32 tensor. Note that a bare number is not accepted
  • dtype: the data type of the tensor to create
  • name: an optional name for the operation

Returns:

A tensor with all elements set to 1

Example and results:

#coding=utf8
import tensorflow as tf
import basic.util.prints as p

sess = tf.InteractiveSession()

# Create a rank-1 int32 tensor
data = tf.ones([1], dtype=tf.int32)
p.printValue("tf.ones([1], dtype=tf.int32)", data)
# Create a rank-3 int32 tensor
data = tf.ones([1,2,1], dtype=tf.int32)
p.printValue("tf.ones([1,2,1], dtype=tf.int32)", data)
# double
data = tf.ones([8], dtype=tf.double)
p.printValue("tf.ones([8], dtype=tf.double)", data)
# float
data = tf.ones([8], dtype=tf.float16)
p.printValue("tf.ones([8], dtype=tf.float16)", data)

# [ERROR] shape cannot be a bare int; it must be a list or a 1-D tensor
# data = tf.ones(1, dtype=tf.int32)
# p.printValue("tf.ones(1, dtype=tf.int32)", data)

# wrap the bare number in a list instead
data = tf.ones([1], dtype=tf.int32)
p.printValue("tf.ones([1], dtype=tf.int32)", data)

sess.close()
Output:
# tf.ones([1], dtype=tf.int32) : Tensor("ones:0", shape=(1,), dtype=int32) - 
[1]
# tf.ones([1,2,1], dtype=tf.int32) : Tensor("ones_1:0", shape=(1, 2, 1), dtype=int32) - 
[[[1]
  [1]]]
# tf.ones([8], dtype=tf.double) : Tensor("ones_2:0", shape=(8,), dtype=float64) - 
[ 1.  1.  1.  1.  1.  1.  1.  1.]
# tf.ones([8], dtype=tf.float16) : Tensor("ones_3:0", shape=(8,), dtype=float16) - 
[ 1.  1.  1.  1.  1.  1.  1.  1.]
# tf.ones([1], dtype=tf.int32) : Tensor("ones_4:0", shape=(1,), dtype=int32) - [1]

tf.ones_like(tensor, dtype=None, name=None)

This method creates a tensor with every element set to 1.

Given a tensor, it returns a tensor of the same type and shape as that argument, but with every element set to 1. If dtype is specified, the elements of the returned tensor take that type instead.

The method is effectively a copy function: by default it copies the type and shape of the argument tensor and sets every value to 1; when dtype is set, the copied tensor uses that dtype instead.

Parameters:

  • tensor: a tensor object
  • dtype: the type of the returned tensor; when unset, the returned type matches that of tensor. If set, it must be one of the following TensorFlow types: float32, float64, int8, int16, int32, int64, uint8, or complex64.

  • name: an optional name for the operation.

Returns:

A tensor with all elements set to 1

Test case and results:

#coding=utf8
import tensorflow as tf
import basic.util.prints as p

sess = tf.InteractiveSession()

# create a tensor object
original = [1,2,3,4,5]
# original = tf.zeros([5], dtype=tf.float16)
p.printRowValue("Original value", original)

# call ones_like; the dtype defaults to that of the input (int32 here)
data = tf.ones_like(original)
p.printValue("tf.ones_like(original)", data)

# call ones_like with an explicit dtype of double
data = tf.ones_like(original, dtype=tf.double)
p.printValue("tf.ones_like(original, dtype=tf.double)", data)

# supported dtypes: float, int, double, uint, and complex; all of the following work

data = tf.ones_like(original, dtype=tf.float32)
p.printValue("tf.ones_like(original, dtype=tf.float32)", data)
data = tf.ones_like(original, dtype=tf.float32_ref)
p.printValue("tf.ones_like(original, dtype=tf.float32_ref)", data)
data = tf.ones_like(original, dtype=tf.float64)
p.printValue("tf.ones_like(original, dtype=tf.float64)", data)
data = tf.ones_like(original, dtype=tf.float64_ref)
p.printValue("tf.ones_like(original, dtype=tf.float64_ref)", data)
data = tf.ones_like(original, dtype=tf.int8)
p.printValue("tf.ones_like(original, dtype=tf.int8)", data)
data = tf.ones_like(original, dtype=tf.int8_ref)
p.printValue("tf.ones_like(original, dtype=tf.int8_ref)", data)
data = tf.ones_like(original, dtype=tf.int16)
p.printValue("tf.ones_like(original, dtype=tf.int16)", data)
data = tf.ones_like(original, dtype=tf.int16_ref)
p.printValue("tf.ones_like(original, dtype=tf.int16_ref)", data)
data = tf.ones_like(original, dtype=tf.int32)
p.printValue("tf.ones_like(original, dtype=tf.int32)", data)
data = tf.ones_like(original, dtype=tf.int32_ref)
p.printValue("tf.ones_like(original, dtype=tf.int32_ref)", data)
data = tf.ones_like(original, dtype=tf.int64)
p.printValue("tf.ones_like(original, dtype=tf.int64)", data)
data = tf.ones_like(original, dtype=tf.int64_ref)
p.printValue("tf.ones_like(original, dtype=tf.int64_ref)", data)
data = tf.ones_like(original, dtype=tf.uint8)
p.printValue("tf.ones_like(original, dtype=tf.uint8)", data)
data = tf.ones_like(original, dtype=tf.uint8_ref)
p.printValue("tf.ones_like(original, dtype=tf.uint8_ref)", data)

data = tf.ones_like(original, dtype=tf.double)
p.printValue("tf.ones_like(original, dtype=tf.double)", data)
data = tf.ones_like(original, dtype=tf.double_ref)
p.printValue("tf.ones_like(original, dtype=tf.double_ref)", data)
data = tf.ones_like(original, dtype=tf.complex64)
p.printValue("tf.ones_like(original, dtype=tf.complex64)", data)
data = tf.ones_like(original, dtype=tf.complex64_ref)
p.printValue("tf.ones_like(original, dtype=tf.complex64_ref)", data)
data = tf.ones_like(original, dtype=tf.complex128)
p.printValue("tf.ones_like(original, dtype=tf.complex128)", data)
data = tf.ones_like(original, dtype=tf.complex128_ref)
p.printValue("tf.ones_like(original, dtype=tf.complex128_ref)", data)

# special cases (the printValue calls are left commented out)
data = tf.ones_like(original, dtype=tf.float16)
# p.printValue("tf.ones_like(original, dtype=tf.float16)", data)
data = tf.ones_like(original, dtype=tf.float16_ref)
# p.printValue("tf.ones_like(original, dtype=tf.float16_ref)", data)
data = tf.ones_like(original, dtype=tf.uint16)
# p.printValue("tf.ones_like(original, dtype=tf.uint16)", data)
data = tf.ones_like(original, dtype=tf.uint16_ref)
# p.printValue("tf.ones_like(original, dtype=tf.uint16_ref)", data)

# [ERROR] unsupported dtypes include: bfloat, qint, quint
# data = tf.ones_like(original, dtype=tf.bfloat16)
# data = tf.ones_like(original, dtype=tf.quint8)
# data = tf.ones_like(original, dtype=tf.qint16)
# data = tf.ones_like(original, dtype=tf.qint32)


sess.close()

Output:

# Original value : [1, 2, 3, 4, 5]
# tf.ones_like(original) : Tensor("ones_like:0", shape=(5,), dtype=int32) - 
[1 1 1 1 1]
# tf.ones_like(original, dtype=tf.double) : Tensor("ones_like_1:0", shape=(5,), dtype=float64) - 
[ 1.  1.  1.  1.  1.]
# tf.ones_like(original, dtype=tf.float32) : Tensor("ones_like_2:0", shape=(5,), dtype=float32) - 
[ 1.  1.  1.  1.  1.]
# tf.ones_like(original, dtype=tf.float32_ref) : Tensor("ones_like_3:0", shape=(5,), dtype=float32) - 
[ 1.  1.  1.  1.  1.]
# tf.ones_like(original, dtype=tf.float64) : Tensor("ones_like_4:0", shape=(5,), dtype=float64) - 
[ 1.  1.  1.  1.  1.]
# tf.ones_like(original, dtype=tf.float64_ref) : Tensor("ones_like_5:0", shape=(5,), dtype=float64) - 
[ 1.  1.  1.  1.  1.]
# tf.ones_like(original, dtype=tf.int8) : Tensor("ones_like_6:0", shape=(5,), dtype=int8) - 
[1 1 1 1 1]
# tf.ones_like(original, dtype=tf.int8_ref) : Tensor("ones_like_7:0", shape=(5,), dtype=int8) - 
[1 1 1 1 1]
# tf.ones_like(original, dtype=tf.int16) : Tensor("ones_like_8:0", shape=(5,), dtype=int16) - 
[1 1 1 1 1]
# tf.ones_like(original, dtype=tf.int16_ref) : Tensor("ones_like_9:0", shape=(5,), dtype=int16) - 
[1 1 1 1 1]
# tf.ones_like(original, dtype=tf.int32) : Tensor("ones_like_10:0", shape=(5,), dtype=int32) - 
[1 1 1 1 1]
# tf.ones_like(original, dtype=tf.int32_ref) : Tensor("ones_like_11:0", shape=(5,), dtype=int32) - 
[1 1 1 1 1]
# tf.ones_like(original, dtype=tf.int64) : Tensor("ones_like_12:0", shape=(5,), dtype=int64) - 
[1 1 1 1 1]
# tf.ones_like(original, dtype=tf.int64_ref) : Tensor("ones_like_13:0", shape=(5,), dtype=int64) - 
[1 1 1 1 1]
# tf.ones_like(original, dtype=tf.uint8) : Tensor("ones_like_14:0", shape=(5,), dtype=uint8) - 
[1 1 1 1 1]
# tf.ones_like(original, dtype=tf.uint8_ref) : Tensor("ones_like_15:0", shape=(5,), dtype=uint8) - 
[1 1 1 1 1]
# tf.ones_like(original, dtype=tf.double) : Tensor("ones_like_16:0", shape=(5,), dtype=float64) - 
[ 1.  1.  1.  1.  1.]
# tf.ones_like(original, dtype=tf.double_ref) : Tensor("ones_like_17:0", shape=(5,), dtype=float64) - 
[ 1.  1.  1.  1.  1.]
# tf.ones_like(original, dtype=tf.complex64) : Tensor("ones_like_18:0", shape=(5,), dtype=complex64) - 
[ 1.+0.j  1.+0.j  1.+0.j  1.+0.j  1.+0.j]
# tf.ones_like(original, dtype=tf.complex64_ref) : Tensor("ones_like_19:0", shape=(5,), dtype=complex64) - 
[ 1.+0.j  1.+0.j  1.+0.j  1.+0.j  1.+0.j]
# tf.ones_like(original, dtype=tf.complex128) : Tensor("ones_like_20:0", shape=(5,), dtype=complex128) - 
[ 1.+0.j  1.+0.j  1.+0.j  1.+0.j  1.+0.j]
# tf.ones_like(original, dtype=tf.complex128_ref) : Tensor("ones_like_21:0", shape=(5,), dtype=complex128) - 
[ 1.+0.j  1.+0.j  1.+0.j  1.+0.j  1.+0.j]


tf.fill(dims, value, name=None)

Creates a tensor of shape dims filled with value. The returned tensor has the same dtype as value.

    • With value 0 this method is equivalent to tf.zeros()
    • With value 1 this method is equivalent to tf.ones()
Parameters:
  • dims: a 1-D int32 tensor describing the shape of the output, usually an int32 list such as [1] or [2,3]
  • value: a constant value (string, number, etc.) used to fill the returned tensor
  • name: an optional name for the operation
Returns:

A tensor with the same dtype as value

Test case:

#coding=utf8
import tensorflow as tf
import basic.util.prints as p

sess = tf.InteractiveSession()

dim = [2,3]
# int fill
data = tf.fill(dim, 5)
p.printValue("tf.fill(dim, value)", data)
# float fill
data = tf.fill(dim, 5.0)
p.printValue("tf.fill(dim, value)", data)
# string fill
data = tf.fill(dim, "5.0")
p.printValue("tf.fill(dim, value)", data)

sess.close()

Output:

# tf.fill(dim, value) : Tensor("Fill:0", shape=(2, 3), dtype=int32) - 
[[5 5 5]
 [5 5 5]]
# tf.fill(dim, value) : Tensor("Fill_1:0", shape=(2, 3), dtype=float32) - 
[[ 5.  5.  5.]
 [ 5.  5.  5.]]
# tf.fill(dim, value) : Tensor("Fill_2:0", shape=(2, 3), dtype=string) - 
[['5.0' '5.0' '5.0']
 ['5.0' '5.0' '5.0']]
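
One property worth noting, sketched below under the same TF 1.x assumptions (the placeholder and feed values are illustrative): unlike tf.constant, whose shape is fixed at graph-construction time, tf.fill accepts dims that are only known at run time, e.g. taken from tf.shape:

#coding=utf8
import tensorflow as tf

sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, shape=[None, 3])
# Fill a tensor whose shape matches x, including the dynamic batch dimension
dynamic = tf.fill(tf.shape(x), 7.0)
print(sess.run(dynamic, feed_dict={x: [[1, 2, 3], [4, 5, 6]]}))
# [[ 7.  7.  7.]
#  [ 7.  7.  7.]]
sess.close()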

tf.constant(value, dtype=None, shape=None, name='Const')

Creates a constant tensor populated with value; shape can optionally fix its dimensions. value may be a single number or a list. If it is a single number, every element of the constant takes that value. If it is a list, len(value) must be less than or equal to the number of elements implied by shape. The values are stored in order; if the list is shorter than the shape, the remaining elements are all filled with the last value of the list.

a = tf.constant(2,shape=[2])
b = tf.constant(2,shape=[2,2])
c = tf.constant([1,2,3],shape=[6])
d = tf.constant([1,2,3],shape=[3,2])

sess = tf.InteractiveSession()
print(sess.run(a))
#[2 2]
print(sess.run(b))
#[[2 2]
# [2 2]]
print(sess.run(c))
#[1 2 3 3 3 3]
print(sess.run(d))
#[[1 2]
# [3 3]
# [3 3]]
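
Conversely, if the list holds more elements than shape allows, construction fails at graph-build time. A sketch (the exact error message may vary across versions):

# 3 values cannot fit in shape (2,); this raises an error such as
# ValueError: Too many elements provided
try:
    e = tf.constant([1, 2, 3], shape=[2])
except (ValueError, TypeError) as err:
    print(err)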



