由于spark中默认支持原生矩阵格式的输入,但实际中我们经常碰到的是稀疏的数据集,因此这里我实现了一个与libsvm输入格式相同的logistic回归,刚接触scala和spark,代码写的还不够简洁,还请各位指点。
代码如下:
package spark.ml.classification import java.util.Random import scala.collection.mutable.HashMap import scala.io.Source import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD; import org.apache.spark.util.Vector import java.lang.Math import org.apache.spark.broadcast.Broadcast import spark.ml.utils.SparserVector object SparseLR { val labelNum = 2; // 类别数 val dimNum = 124; // 维度 val iteration = 10; // 迭代次数 val alpha = 0.1 // 迭代步长 val lambda = 0.1 val rand = new Random(42) var w = Vector(dimNum, _ => rand.nextDouble) //用随机数初始化参数 /** * 定义一个数据点 */ case class DataPoint(x: SparserVector, y: Int) /** * 解析一个训练样本,构造DataPoint结构 * @param 训练样本 */ def parsePoint(line: String): DataPoint = { var features = new SparserVector(dimNum) val field