Decision tree algorithm example in Scala

The training data below is the classic contact-lenses dataset: each row lists the patient's age group, prescription type, whether they are astigmatic, their tear production rate, and the recommended lens type.

young   myope   no  reduced no lenses
young   myope   no  normal  soft
young   myope   yes reduced no lenses
young   myope   yes normal  hard
young   hyper   no  reduced no lenses
young   hyper   no  normal  soft
young   hyper   yes reduced no lenses
young   hyper   yes normal  hard
pre myope   no  reduced no lenses
pre myope   no  normal  soft
pre myope   yes reduced no lenses
pre myope   yes normal  hard
pre hyper   no  reduced no lenses
pre hyper   no  normal  soft
pre hyper   yes reduced no lenses
pre hyper   yes normal  no lenses
presbyopic  myope   no  reduced no lenses
presbyopic  myope   no  normal  no lenses
presbyopic  myope   yes reduced no lenses
presbyopic  myope   yes normal  hard
presbyopic  hyper   no  reduced no lenses
presbyopic  hyper   no  normal  soft
presbyopic  hyper   yes reduced no lenses
presbyopic  hyper   yes normal  no lenses
package mlia.trees

import breeze.numerics._
import scala.annotation.tailrec

case class Tree(nodes: Array[Node] = Array.empty) {

  override def toString = s"Tree[${nodes.map(_.toString).mkString(",")}]"

  def <<-(node: Node): Tree = Tree(nodes :+ node)

  // Classify a feature vector by walking the tree: at each level pick the node whose stored
  // feature value matches the corresponding entry of testVec (a leaf yields the class label).
  def classify(testVec: Vector[Int], featLabels: Array[String], cur: Array[Node] = nodes): String = search(testVec, featLabels, cur)

  @tailrec
  private def search(testVec: Vector[Int], featLabels: Array[String], cur: Array[Node]): String = {
    cur.find { node =>
      node.isLeaf || testVec(featLabels.indexOf(node.key)).toString == node.value.toString
    } match {
      case None => "Fail to classify."
      case Some(node) if node.isLeaf => node.value.toString
      case Some(node) => search(testVec, featLabels, node.children)
    }
  }
}

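// For an internal node, `key` is a feature name and `value` is the feature value that leads to
// `children`; for a leaf (no children), `key` holds the incoming feature value and `value` the class label.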
case class Node(key: String, value: Any, children: Array[Node] = Array.empty) {

  val isLeaf = children.isEmpty

  override def toString =
    if (children.isEmpty) s" -> $value[Leaf]" else s"{$key : $value ${children.map(_.toString).mkString(",")}}"
}

object Tree {

  case class Row(data: Array[Int], label: String)

  case class InformationGain(featureIdx: Int, infoGain: Double)

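  // Shannon entropy of the label distribution: H(D) = -sum_k p_k * log2(p_k),
  // where p_k is the fraction of rows carrying label k.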
  def calcShannonEnt(dataSet: Array[Row]) = {

    val labelCounts = dataSet.foldLeft(Map.empty[String, Int]) { (map, row) =>
      map + (row.label -> (map.getOrElse(row.label, 0) + 1))
    }
    val numEntries = dataSet.size
    labelCounts.foldLeft(0.0) { case (ent, (_, count)) =>
      val prob = count.toDouble / numEntries
      ent - prob * (log(prob) / log(2)) // accumulate -p * log2(p)
    }
  }

  // Keep the rows whose feature `axis` equals `value`, and drop that column so the remaining
  // columns stay aligned with the label array reduced by `remove` below.
  def splitDataSet(dataSet: Array[Row], axis: Int, value: Int): Array[Row] =
    dataSet.filter(_.data(axis) == value).map(row => row.copy(data = row.data.patch(axis, Nil, 1)))

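  // Choose the split feature with the largest information gain:
  // gain(D, f) = H(D) - sum_v (|D_v| / |D|) * H(D_v), where D_v is the subset of rows with f = v.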
  def chooseBestFeatureToSplit(dataSet: Array[Row]) = {

    val numEntries = dataSet.size
    val numFeatures = dataSet.head.data.size
    val baseEntropy = calcShannonEnt(dataSet)

    (0 until numFeatures).foldLeft(InformationGain(-1, 0.0)) { (curBest, cur) =>
      val uniqueVals = dataSet.map(_.data(cur)).distinct
      val newEntropy = uniqueVals.foldLeft(0.0) { (ent, value) =>
        val subDataSet = splitDataSet(dataSet, cur, value)
        val prob = subDataSet.size / numEntries.toDouble
        ent + prob * calcShannonEnt(subDataSet)
      }
      val infoGain = baseEntropy - newEntropy
      if (infoGain > curBest.infoGain) InformationGain(cur, infoGain) else curBest
    }
  }

  def majorityCnt(classList: Array[String]): String =
    classList.foldLeft(Map.empty[String, Int]) { (state, x) =>
      state + (x -> (state.getOrElse(x, 0) + 1))
    }.maxBy(_._2)._1

  // Drop the label at index `num`, keeping the labels aligned with the feature columns
  // removed by splitDataSet.
  private def remove(num: Int, list: Array[String]) = list.patch(num, Nil, 1)

  def apply(dataSet: Array[Row], labels: Array[String]): Tree = createTree(dataSet, labels)

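  // ID3-style recursion: stop when every row has the same label or no features remain
  // (majority vote); otherwise split on the best feature and recurse once per observed value.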
  private def createTree(dataSet: Array[Row], labels: Array[String], cur: Tree = Tree(), value: Int = -1): Tree = {
    val classList = dataSet.map(_.label)
    if (classList.distinct.size == 1) cur <<- Node(value.toString, classList(0)) // all labels are equal
    else if (dataSet.head.data.isEmpty) cur <<- Node(value.toString, majorityCnt(classList)) // no features left: fall back to the majority label
    else {
      val bestFeat = chooseBestFeatureToSplit(dataSet).featureIdx
      val subLabels = remove(bestFeat, labels)
      val uniqueFeatValues = dataSet.map(_.data(bestFeat)).distinct
      uniqueFeatValues.foldLeft(cur) { (state, value) =>
        val subTree = createTree(splitDataSet(dataSet, bestFeat, value), subLabels, cur, value)
        state <<- Node(labels(bestFeat), value.toString, subTree.nodes)
      }
    }
  }
}
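
A minimal usage sketch, assuming the 24 rows above are saved as lenses.txt: it encodes the categorical values to Ints column by column, builds the tree and classifies one feature vector. The object name LensesExample, the feature label strings and the per-column integer encodings are illustrative choices only.

import mlia.trees.Tree
import mlia.trees.Tree.Row

object LensesExample extends App {

  // The same label array must be used for building and for classification,
  // because classify resolves feature columns via featLabels.indexOf.
  val featLabels = Array("age", "prescript", "astigmatic", "tearRate")

  // Per-column encodings of the categorical values (arbitrary but consistent).
  val encodings: Array[Map[String, Int]] = Array(
    Map("young" -> 0, "pre" -> 1, "presbyopic" -> 2),
    Map("myope" -> 0, "hyper" -> 1),
    Map("no" -> 0, "yes" -> 1),
    Map("reduced" -> 0, "normal" -> 1))

  // Each line: age prescript astigmatic tearRate class; the class label may contain a space ("no lenses").
  val dataSet: Array[Row] = scala.io.Source.fromFile("lenses.txt").getLines().map { line =>
    val cols = line.trim.split("\\s+")
    Row(cols.take(4).zipWithIndex.map { case (v, i) => encodings(i)(v) }, cols.drop(4).mkString(" "))
  }.toArray

  val tree = Tree(dataSet, featLabels)
  println(tree)

  // The second data row above (young, myope, no astigmatism, normal tear rate) is labelled "soft".
  println(tree.classify(Vector(0, 0, 0, 1), featLabels))
}

Any other consistent encoding works equally well, since the tree only compares encoded values for equality.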

