Java实现的朴素贝叶斯分类器

本文介绍了一种简单的朴素贝叶斯分类器(Naive Bayes Classifier, NBC)的实现方法,该分类器适用于结果只有两种情况的问题,并提供了具体的Java代码实现。文章详细解释了如何通过分离训练数据集来提高分类准确率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目前的算法只能处理结果只有两种的情况,即true or false. 多分枝或者是数字类型的还无法处理。

用到的一些基础数据结构可以参考上一篇关于ID3的代码。 

 

这里只贴出来实现贝叶斯分类预测的部分:

package classifier;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import util.ArffUtil;


/**
 * NBC means Naive Bayes Classifier
 * @author wenjun_yang
 *
 */
public class NBCUtil {
	
	ArffUtil util = new ArffUtil();
	private List<String> attributeList = null;
	private List<String[]> dataList = null;
	private String decAttributeName = null;
	private int decAttributeIndex = -1;
	
	private Map<String, List<String[]>> seperatedDataTable = null;
	public NBCUtil(String decAttributeName, List<String> attributeList, List<String[]> dataList) {
		this.attributeList = attributeList;
		this.dataList = dataList;
		this.decAttributeName = decAttributeName;
		
		this.decAttributeIndex = util.getValueIndex(decAttributeName, this.attributeList);
		this.seperatedDataTable = seperateDataList(dataList);
	}
	
	private Map<String, List<String[]>> seperateDataList(List<String[]> dataList) {
		Map<String, List<String[]>> map = new HashMap<String, List<String[]>>();
		
		for(String[] arr : dataList) {
			if(decAttributeIndex >= 0 && decAttributeIndex < arr.length) {
				String currentKey = arr[decAttributeIndex]; 
				if(map.containsKey(currentKey)) {
					List<String[]> tempList = map.get(currentKey);
					tempList.add(arr);
					map.put(currentKey, tempList);
				} else {
					List<String[]> tempList = new ArrayList<String[]>();
					tempList.add(arr);
					map.put(currentKey , tempList);
				}
			}
		}
		
		return map;
	}
	
	public Boolean predict(Map<String, String> predictData, String targetDecAttributeValue) {
		if(predictData.containsKey(decAttributeName)) predictData.remove(decAttributeName);
		
		List<String[]> positiveDataTable = new ArrayList<String[]>();
		if(seperatedDataTable.containsKey(targetDecAttributeValue)) {
			positiveDataTable = seperatedDataTable.get(targetDecAttributeValue);
		}
		
		double resultP = 1.;
		
		// Step 1: 逐个属性的比率进行计算
		// 即: 计算 P(Attr=Value|Y=true) / P(Attr=Value|Y=false) 的值
		for(String attrName : predictData.keySet()) {
			String attrValue = predictData.get(attrName);
			int attrIndex = util.getValueIndex(attrName, attributeList);
			int attrPositiveCount = 0;
			int attrNegativeCount = 0;
			
			for(String[] arr : dataList) {
				if(arr[attrIndex].equals(attrValue)) {
					if(arr[decAttributeIndex].equals(targetDecAttributeValue)) {
						attrPositiveCount++;
					} else {
						attrNegativeCount++;
					}
				}
			}
			double temp =  (attrPositiveCount / (double)positiveDataTable.size() ) /
							(attrNegativeCount / (double)(dataList.size() - positiveDataTable.size()));
			resultP *= temp;
		}
		// 最后计算 P(Y=true) / P(Y=false)
		resultP *= positiveDataTable.size() / (double)(dataList.size() - positiveDataTable.size());
		System.out.println(resultP);
		if(resultP > 1) {
			return true;
		} else {
			return false;
		}
	}
}

 

 

完整的项目也上传了,可以直接使用。

数据源来自weka

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值