java dnn_Tensorflow DNN多重分类

作者在使用Tensorflow构建DNN进行多元分类时遇到问题,模型预测倾向于0和2类别,而忽略了1类别。探讨了softmax可能的解决方案,并寻求解决为何类别不平衡的建议。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

我正在尝试使用Tensorflow在Python 3.5中创建一个DNN,用于将元组分类为 3 类之一 .

# define initial hyperparameters

batch_size = 100

train_steps = 5000

hidden_units=[10,20,10]

# build model

dnn = tf.contrib.learn.DNNClassifier(hidden_units=hidden_units, feature_columns=feature_cols, n_classes=3)

input_fn = tf.estimator.inputs.pandas_input_fn(x=X_train, y=y_train,

batch_size=batch_size,

num_epochs=None,

shuffle=True)

# fit model to the data

dnn.fit(input_fn = input_fn, steps=train_steps)

# predict new data

predict_input_func = tf.estimator.inputs.pandas_input_fn(x=X_test,

batch_size=len(X_test),

shuffle=False)

preds = dnn.predict_classes(input_fn=predict_input_func)

X_train(和X_test)由7个数字列组成 . y_train(和y_test)由1个数字列组成,作为响应变量,[0或1或2] .

当我用上述模型预测时,我的准确度非常差(准确度为50-70%) .

似乎我已经弄清楚了为什么 - 我的模型预测新输入的类为0或2 ......所以它实际上丢失了所有类1的记录 .

有人能给我一个提示,为什么会这样?我已经读过softmax可能是解决方案......如果是这样的话,我很困惑为什么在Tensorflow文档(章节入门/ Iris分类)中描述了3类类似的DNN .

编辑:我当然尝试过不同的超参数 .

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值