CH4_朴素贝叶斯及其Spark实现

本文介绍了朴素贝叶斯算法的基本原理及其实现过程,并通过一个具体的案例进行演示。包括了算法背后的数学推导,以及如何使用Spark进行大规模数据处理。

CH4_朴素贝叶斯及其Spark实现

关于贝叶斯的介绍在之前的文章中也有说明,网上也有许多资料,在这里就不在做过多赘述。

朴素贝叶斯模型

假设我们有数据样本如下:
( X 1 , X 2 , . . X n , Y ) (X_1,X_2,..X_n,Y) (X1,X2,..Xn,Y)
有m个样本,每个样本有n个特征,特征输出有K个类别
我们可以通过以上样本学习得出先验概率
P ( Y = c k ) , k = 1 , 2... k P(Y=c_k),k=1,2...k P(Y=ck),k=1,2...k
然后学习得出条件概率
P ( X = x ∣ Y = c k ) = P ( X ( 1 ) = x ( 1 ) , X ( 2 ) = x ( 2 ) , X ( n ) = x ( n ) ∣ Y = c k ) ) P(X=x|Y=c_k) =P(X^{(1)}=x^{(1)}, X^{(2)}=x^{(2)},X^{(n)}=x^{(n)}|Y=c_k)) P(X=xY=ck)=P(X(1)=x(1),X(2)=x(2),X(n)=x(n)Y=ck))
朴素贝叶斯之所以朴素是因为对条件概率分布做了条件独立性假设(即用于分类的特征在类确定的条件下都是独立的),具体表现为:
P ( X = x ∣ Y = c k ) = P ( X ( 1 ) = x ( 1 ) , X ( 2 ) = x ( 2 ) , X ( n ) = x ( n ) ∣ Y = c k ) ) = ∏ i = 1 N P ( X ( i ) = x ( i ) ∣ Y = c k ) P(X=x|Y=c_k) =P(X^{(1)}=x^{(1)}, X^{(2)}=x^{(2)},X^{(n)}=x^{(n)}|Y=c_k))= \prod_{i=1}^N P(X^{(i)}=x^{(i)}|Y=c_k) P(X=xY=ck)=P(X(1)=x(1),X(2)=x(2),X(n)=x(n)Y=ck))=i=1NP(X(i)=x(i)Y=ck)
朴素贝叶斯的这一假设大大简化了条件分布的计算,但是也牺牲了分类的准确性,因此,在特征之间非常不独立的情况下,可以优先考虑其他分类算法。
在给定的特征向量X=x的情况下,通过学习得到后验概率分布:
P ( Y = c k ∣ X = x ) P(Y=c_k|X=x) P(Y=ckX=x)
根据贝叶斯定量整理得出:
P ( Y = c k ∣ X = x ) = ∏ i = 1 N P ( X ( i ) = x ( i ) ∣ Y = c k ) P ( Y = c k ) ∑ k = 1 N P ( Y = c k ) ∏ i = 1 N P ( X ( i ) = x ( i ) ∣ Y = c k ) P(Y=c_k|X=x) =\frac{ \prod_{i=1}^N P(X^{(i)}=x^{(i)}|Y=c_k)P(Y=c_k)}{\sum_{k=1}^N P(Y=c_k) \prod_{i=1}^N P(X^{(i)}=x^{(i)}|Y=c_k)} P(Y=ckX=x)=k=1NP(Y=ck)i=1NP(X(i)=x(i)Y=ck)i=1NP(X(i)=x(i)Y=ck)P(Y=ck)
将后验概率最大的类最X=x的输出。

以上过程是朴素贝叶斯的极大似然估计方法,在实际计算中,为了防止计算出现概率值为0的情况,我们在计算先验概率和条件概率时加入一个拉普拉斯平滑指数 λ(λ>=0)。(贝叶斯估计)
先验概率的贝叶斯估计:
P λ ( Y = c k ) = ∑ k = 1 N I ( y i = c k ) + λ N + k λ P_λ(Y=c_k) = \frac{\sum_{k=1}^N I(y_i=c_k)+λ}{N+kλ} Pλ(Y=ck)=N+kλk=1NI(yi=ck)+λ
条件概率的贝叶斯估计:
P λ ( X ( j ) = a j l ∣ Y = c k ) = ∑ k = 1 N I ( ( X ( j ) = a j l , y i = c k ) + λ ∑ k = 1 N I ( y i = c k ) + S j λ P_λ(X^{(j)}=a_{jl} |Y=c_k) =\frac{\sum_{k=1}^N I((X^{(j)}=a_{jl} ,y_i=c_k)+λ}{\sum_{k=1}^N I(y_i=c_k)+S_jλ} Pλ(X(j)=ajlY=ck)=k=1NI(yi=ck)+Sjλk=1NI((X(j)=ajl,yi=ck)+λ

跟多解释和证明请参考李航老师的《统计学习方法》

数据说明

本案例数据来自于MNIST数据库,是一份手写数字数据库。

实现代码

1. 朴素贝叶斯算法模型

import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import scala.beans.BeanProperty
import scala.collection.mutable

/**
  * Created by WZZC on 2019/12/10
  **/
case class MultinomilNaiveBayesModel(data: DataFrame, labelColName: String) {

  private val spark: SparkSession = data.sparkSession
  @BeanProperty var fts: Array[String] =
    data.columns.filterNot(_ == labelColName)
  private val ftsName: String = Identifiable.randomUID("NaiveBayesModel")

  // 拉普拉斯平滑指数
  @BeanProperty var lamada: Double = 1.0

  /**
    * 数据特征转换
    *
    * @param dataFrame
    * @return
    */
  def dataTransForm(dataFrame: DataFrame) = {
    new VectorAssembler()
      .setInputCols(fts)
      .setOutputCol(ftsName)
      .transform(dataFrame)
  }

  var contingentProb: Array[(Double, (Int, Seq[mutable.Map[Double, Double]]))] =
    _

  var priorProb: Map[Double, Double] = _

  def fit = {

    def seqOp(c: (Int, Seq[mutable.Map[Double, Int]]),
              v: (Int, Seq[Double])) = {

      val w: Int = c._1 + v._1

      val tuples = c._2
        .zip(v._2)
        .map(tp => {
          val ftsvalueNum: Int = tp._1.getOrElse(tp._2, 0)

          tp._1 += tp._2 -> (ftsvalueNum + 1)
          tp._1
        })

      (w, tuples)

    }

    def combOp(c1: (Int, Seq[mutable.Map[Double, Int]]),
               c2: (Int, Seq[mutable.Map[Double, Int]])) = {

      val w = c2._1 + c1._1

      val resMap = c1._2
        .zip(c2._2)
        .map(tp => {
          val m1: mutable.Map[Double, Int] = tp._1
          val m2: mutable.Map[Double, Int] = tp._2

          m1.foreach(kv => {
            val i = m2.getOrElse(kv._1, 0)
            m2 += kv._1 -> (kv._2 + i)
          })

          m2
        })

      (w, resMap)
    }

    val nilSeq = new Array[Any](fts.length)
      .map(x => mutable.Map[Double, Int]())
      .toSeq

    val agged = dataTransForm(data).rdd
      .map(row => {

        val lable: Double = row.getAs[Double](labelColName)
        val fts: Seq[Double] = row.getAs[Vector](ftsName).toArray.toSeq

        (lable, (1, fts))
      })
      .aggregateByKey[(Int, Seq[mutable.Map[Double, Int]])]((0, nilSeq))(
        seqOp,
        combOp
      )

    val numLabels: Long = agged.count()
    val numDocuments: Double = agged.map(_._2._1).sum

    // 拉普拉斯变换
    val lpsamples: Double = numDocuments + lamada * numLabels

    //  条件概率
    contingentProb = agged
      .mapValues(tp => {
        val freq = tp._1

        val lprob = tp._2.map(
          m =>
            m.map {
              case (k, v) => {
                val logprob = (v / freq.toDouble).formatted("%.4f")
//                  val logprob = math.log(v / freq.toDouble).formatted("%.4f")
                (k, logprob.toDouble)
              }
          }
        )
        (freq, lprob)
      })
      .collect()

    //  先验概率
    priorProb = agged
      .map(tp => (tp._1, tp._2._1))
      .mapValues(v => ((v + lamada) / lpsamples))
//      .mapValues(v => math.log((v + lamada) / lpsamples))
      .collect()
      .toMap

  }

  def predict(predictData: Seq[Double]): Double = {

    val posteriorProb: Array[(Double, Double)] = contingentProb
      .map(tp3 => {
        val label: Double = tp3._1

        val tp: (Int, Seq[mutable.Map[Double, Double]]) = tp3._2

        val missProb: Double = lamada / (tp._2.length * lamada)

        val sum: Double = tp._2
          .zip(predictData)
          .map {
            case (pmap, ftValue) => {

              val d: Double = pmap.getOrElse(ftValue, missProb)

              math.log(d)

            }
          }
          .sum

        (label, sum)

      })
      .map(tp => (tp._1, priorProb.getOrElse(tp._1, 0.0) + math.log(tp._2)))

    posteriorProb.maxBy(_._2)._1

  }

  def predict(predictData: DataFrame): DataFrame = {

    contingentProb.foreach(println)

    priorProb.foreach(println)

    val predictudf = udf((vec: Vector) => predict(vec.toArray.toSeq))

    dataTransForm(predictData)
      .withColumn(labelColName, predictudf(col(ftsName)))
      .drop(ftsName)

  }

}

2. 算法测试
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DoubleType

/**
  * Created by WZZC on 2019/4/27
  **/
object NaiveBayesRunner {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName(s"${this.getClass.getSimpleName}")
      .master("local[*]")
      .getOrCreate()

    import spark.implicits._
    // 数据加载
    val data = spark.read
      .option("header", true)
      .option("inferSchema", true)
      .csv("data/naviebayes.csv")

    val model = new StringIndexer()
      .setInputCol("x2")
      .setOutputCol("indexX2")
      .fit(data)

    val dataFrame = model
      .transform(data)
      .withColumn("x1", $"x1".cast(DoubleType))
      .withColumn("y", $"y".cast(DoubleType))

    val bayes = MultinomilNaiveBayesModel(dataFrame, "y")

    bayes.setFts(Array("x1", "indexX2"))

    bayes.fts.foreach(println)
    bayes.fit

    bayes.predict(dataFrame).show()

    spark.stop()
  }

}

参考资料:

《统计学习方法》-- 李航

https://www.cnblogs.com/pinard/p/6069267.html

<think>我们遇到的是一个典型的字符编码转换问题。在Windows编程中,当项目设置为Unicode字符集时,字符串需要使用宽字符(wchar_t)表示,而直接使用字符串字面量(如"...")会被解释为窄字符串(const char[]),导致类型不匹配。 错误信息: 错误 C2440 “=”: 无法从“const char []”转换为“LPCWSTR” 原因: - LPCWSTR 是 const wchar_t* 类型,用于Unicode(宽字符)字符串。 - 而代码中使用了类似 "字符串" 的字面量,其类型是 const char[],无法自动转换为 const wchar_t*。 解决方案: 1. 使用宽字符字符串字面量:在字符串前加 L 前缀,例如 L"宽字符串"。 2. 使用TEXT宏(或_T宏):它会根据项目字符集设置自动转换为合适的字符串类型。 3. 更改项目字符集设置:将项目属性中的“字符集”设置为“使用多字节字符集”(但这不是推荐做法,因为Unicode是更现代和推荐的方式)。 示例代码: 假设原代码是: LPCWSTR str = "Hello"; 这会导致错误,因为"Hello"是窄字符串。 修正方法1(使用L前缀): LPCWSTR str = L"Hello"; 修正方法2(使用TEXT宏): LPCWSTR str = TEXT("Hello"); // 需要包含<windows.h>,TEXT宏在windows.h中定义 另外,如果是在MFC中,还可以使用 _T 宏(与TEXT相同): LPCWSTR str = _T("Hello"); 但是注意:如果项目是Unicode字符集,则TEXT("Hello")会被解释为L"Hello",如果是多字节字符集,则解释为"Hello"。 另外,如果函数参数要求是LPCWSTR,而传递的是字符串字面量,也需要同样处理。 例如,在创建窗口时: HWND hWnd = CreateWindowW(className, L"窗口标题", ...); // 使用宽字符版本 或者 HWND hWnd = CreateWindow(className, TEXT("窗口标题"), ...); // 使用通用版本 注意:CreateWindow是一个宏,根据字符集设置会映射到CreateWindowW(宽字符)或CreateWindowA(窄字符)。 因此,在代码中,我们应该使用TEXT宏来保证可移植性。 示例修正: 原错误代码位置:ch4_1_7.cpp 第11行 假设第11行代码是: SetWindowText("我的窗口"); // 错误 修正为: SetWindowText(_T("我的窗口")); 或者如果明确知道使用宽字符函数: SetWindowTextW(L"我的窗口"); 但是,为了通用性,建议使用第一种。 另外,如果项目是Unicode字符集,那么SetWindowText实际上就是SetWindowTextW,所以要求宽字符串。 总结:在Windows编程中,处理字符串时要注意字符集设置,使用TEXT宏或显式宽字符串。 §§相关问题§§ 1. 在Windows编程中,TEXT宏和_T宏有什么区别? 2. 如何将现有的窄字符串(const char*)动态转换为宽字符串(wchar_t*)以兼容Unicode API? 3. 为什么在Windows编程中推荐使用Unicode字符集而不是多字节字符集? 4. 在跨平台项目中如何统一处理字符串编码以避免类似转换问题? 5. 除了字符串类型不匹配,还有哪些常见的Windows编程中因字符集设置导致的错误?</think>这个编译错误 `C2440: “=”: 无法从“const char []”转换为“LPCWSTR”` 是 Windows 编程中常见的字符编码问题,通常在使用 MFC 或 Win32 API 时出现。以下是详细分析和解决方案: ### 错误原因分析 1. **字符集不匹配**: - `const char[]`:ANSI 窄字符(8位) - `LPCWSTR`:Unicode 宽字符(16位,`wchar_t*`) - 项目设置为 Unicode 字符集(默认),但代码使用了 ANSI 字符串 2. **项目配置**: ```mermaid graph TD A[项目属性] --> B[配置属性] B --> C[高级] C --> D[字符集] D --> E[使用Unicode字符集] --> 错误 D --> F[使用多字节字符集] --> 可能正常 ``` 3. **API 版本差异**: - `CreateWindow` 实际是宏: ```cpp #ifdef UNICODE #define CreateWindow CreateWindowW #else #define CreateWindow CreateWindowA #endif ``` ### 解决方案 #### 方案1:使用宽字符字面量(推荐) 在字符串前加 `L` 前缀: ```cpp // 错误:LPCWSTR title = "我的窗口"; LPCWSTR title = L"我的窗口"; // 正确 ``` #### 方案2:使用通用文本宏 ```cpp #include <tchar.h> // 使用 _T() 或 TEXT() 宏 LPCTSTR title = _T("我的窗口"); SetWindowText(_T("新标题")); ``` #### 方案3:显式转换函数 ```cpp // 窄转宽 const char* narrowStr = "ANSI文本"; int len = MultiByteToWideChar(CP_ACP, 0, narrowStr, -1, NULL, 0); wchar_t* wideStr = new wchar_t[len]; MultiByteToWideChar(CP_ACP, 0, narrowStr, -1, wideStr, len); // 使用后释放 delete[] wideStr; ``` #### 方案4:修改项目字符集设置 1. 项目 → 属性 → 配置属性 → 高级 2. 字符集 → 改为"使用多字节字符集" 3. **注意**:此方案不推荐,会降低国际化支持 ### 完整修复示例 ```cpp #include <Windows.h> #include <tchar.h> int WINAPI WinMain(HINSTANCE hInst, HINSTANCE, LPSTR, int) { // 正确使用宽字符 LPCWSTR className = L"MyWindowClass"; LPCWSTR title = L"Unicode窗口"; // 或使用通用宏 LPCTSTR className2 = _T("MyWindowClass"); LPCTSTR title2 = _T("通用文本窗口"); // 注册窗口类 WNDCLASS wc = {}; wc.lpfnWndProc = DefWindowProc; wc.hInstance = hInst; wc.lpszClassName = className; // 正确使用 // 创建窗口 HWND hWnd = CreateWindow(className, title, ...); return 0; } ``` ### 预防措施 1. **统一编码规范**: ```cpp // 始终使用 #include <tchar.h> LPCTSTR str = _T("文本"); TCHAR buffer[256]; ``` 2. **API 调用规范**: ```cpp // 错误:MessageBox(NULL, "内容", "标题", MB_OK); MessageBox(NULL, _T("内容"), _T("标题"), MB_OK); // 正确 ``` 3. **字符串处理函数**: | ANSI 函数 | Unicode 函数 | 通用函数 | |-----------|-------------|----------| | `strlen` | `wcslen` | `_tcslen` | | `strcpy` | `wcscpy` | `_tcscpy` | ### 调试技巧 在 Visual Studio 中: 1. 查看预处理结果:项目 → 属性 → C/C++ → 预处理器 → 预处理到文件 → 是 2. 编译后查看 .i 文件,检查宏展开结果
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值