Spark MLlib source code analysis: L-BFGS (Part 2)

This article dives into the implementation of the L-BFGS optimizer in Spark MLlib, analyzing the data structures used in training, the cost function, State, the approximate inverse Hessian, and the training process, including descent direction computation, step size determination, weight updates, and updates to the Hessian approximation.


Related articles
Spark source code analysis: L-BFGS (Part 1)
Line search
Spark regularization
Spark MLlib source code analysis: OWLQN
Other source code analysis articles
Spark source code analysis: DecisionTree and GBDT
Spark source code analysis: Random Forest

4.4. optimize

Our optimizer is LBFGS; its optimize function:

  override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
    val (weights, _) = LBFGS.runLBFGS(
      data,
      gradient,   //LogisticGradient
      updater,    //SquaredL2Updater
      numCorrections,  //default 10
      convergenceTol,   //default 1E-6
      maxNumIterations,  //default 100
      regParam,          //0.0
      initialWeights)
    weights
  }

The default parameters are all encapsulated in MLlib's LBFGS class; the actual training happens in the runLBFGS function of object LBFGS.
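
For orientation, runLBFGS wires these pieces together roughly as follows (an abridged sketch of the surrounding Spark source, not the complete function):

  import breeze.linalg.{DenseVector => BDV}
  import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS}

  // build the cost function over the RDD, wrap it in a cache, and hand it
  // to Breeze's LBFGS, which produces the iterator of States consumed below
  val numExamples = data.count()
  val costFun = new CostFun(data, gradient, updater, regParam, numExamples)
  val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
  val states = lbfgs.iterations(
    new CachedDiffFunction(costFun), initialWeights.asBreeze.toDenseVector)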

4.4.1. Data structures used in training

4.4.1.1. Loss function

First, the loss and gradient computation is wrapped in the CostFun class, so it can be evaluated conveniently during the LBFGS iterations:

  /**
   * CostFun implements Breeze's DiffFunction[T], which returns the loss and gradient
   * at a particular point (weights). It's used in Breeze's convex optimization routines.
   */
  private class CostFun(
    data: RDD[(Double, Vector)],
    gradient: Gradient,
    updater: Updater,
    regParam: Double,
    numExamples: Long) extends DiffFunction[BDV[Double]] {

    override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
      // Have a local copy to avoid the serialization of CostFun object which is not serializable.
      val w = Vectors.fromBreeze(weights)
      val n = w.size
      val bcW = data.context.broadcast(w)
      val localGradient = gradient

      val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))(
          // executor-side seqOp: accumulate grad and loss within each partition (see treeAggregate)
          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
            val l = localGradient.compute(
              features, label, bcW.value, grad)
            (grad, loss + l)
          },
          // combOp: merge the per-partition (gradient, loss) results
          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
            axpy(1.0, grad2, grad1)
            (grad1, loss1 + loss2)
          })

      /**
       * regVal is sum of weight squares if it's L2 updater;
       * for other updater, the same logic is followed.
       */
       // compute the regularization value (the full loss is assembled just below)
      val regVal = updater.compute(w, Vectors.zeros(n), 0, 1, regParam)._2

      val loss = lossSum / numExamples + regVal
      /**
       * It will return the gradient part of regularization using updater.
       *
       * Given the input parameters, the updater basically does the following,
       *
       * w' = w - thisIterStepSize * (gradient + regGradient(w))
       * Note that regGradient is function of w
       *
       * If we set gradient = 0, thisIterStepSize = 1, then
       *
       * regGradient(w) = w - w'
       *
       * TODO: We need to clean it up by separating the logic of regularization out
       *       from updater to regularizer.
       */
      // The following gradientTotal is actually the regularization part of gradient.
      // Will add the gradientSum computed from the data with weights in the next step.
      // compute the regularization part of the gradient
      val gradientTotal = w.copy
      axpy(-1.0, updater.compute(w, Vectors.zeros(n), 1, 1, regParam)._1, gradientTotal)

      // gradientTotal = gradientSum / numExamples + gradientTotal
      axpy(1.0 / numExamples, gradientSum, gradientTotal)

      (loss, gradientTotal.asBreeze.asInstanceOf[BDV[Double]])
    }
  }
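
To see the updater trick from the comments above in isolation, here is a small hypothetical demo of SquaredL2Updater (the class and compute signature are MLlib's; the numbers are made up):

  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.optimization.SquaredL2Updater

  object RegGradientDemo extends App {
    val regParam = 0.1
    val w = Vectors.dense(1.0, -2.0, 3.0)
    val updater = new SquaredL2Updater

    // stepSize = 0 leaves the weights untouched, so ._2 is the pure
    // regularization value: 0.5 * regParam * ||w||^2 = 0.5 * 0.1 * 14 = 0.7
    val regVal = updater.compute(w, Vectors.zeros(3), 0, 1, regParam)._2

    // stepSize = 1 with a zero gradient gives w' = (1 - regParam) * w,
    // so w - w' = regParam * w, exactly the L2 regularization gradient
    val wPrime = updater.compute(w, Vectors.zeros(3), 1, 1, regParam)._1

    println(regVal)  // 0.7
    println(wPrime)  // [0.9,-1.8,2.7]
  }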

4.4.1.2. State

The quantities tracked across iterations are wrapped in a simple State case class:

 /**
   * Tracks the information about the optimizer, including the current point, its value, gradient, and then any history.
   * Also includes information for checking convergence.
   * @param x the current point being considered
   * @param value f(x)
   * @param grad f.gradientAt(x)
   * @param adjustedValue  f(x) + r(x), where r is any regularization added to the objective. For LBFGS, this is f(x).
   * @param adjustedGradient f'(x) + r'(x), where r is any regularization added to the objective. For LBFGS, this is f'(x).
   * @param iter what iteration number we are on.
   * @param initialAdjVal f(x_0) + r(x_0), used for checking convergence
   * @param history any information needed by the optimizer to do updates.
   * @param fVals the sequence of the last minImprovementWindow values, used for checking if the "value" isn't improving
   * @param numImprovementFailures the number of times in a row the objective hasn't improved, mostly for SGD
   * @param searchFailed did the line search fail?
   */
  case class State(x: T,
                   value: Double, grad: T,
                   adjustedValue: Double, adjustedGradient: T,
                   iter: Int,
                   initialAdjVal: Double,
                   history: History,
                   fVals: IndexedSeq[Double] = Vector(Double.PositiveInfinity),
                   numImprovementFailures: Int = 0,
                   searchFailed: Boolean = false)

Here x is the weight vector, value is the loss, grad is the gradient, and history is the memory used to build the approximate inverse Hessian.

4.4.1.3. ApproximateInverseHessian

By default m = 10, i.e. the approximation uses the last 10 (step, gradient-delta) pairs; 3 to 7 is the commonly recommended range. memStep and memGradDelta both start out empty:

case class ApproximateInverseHessian[T](m: Int,
        private[LBFGS] val memStep: IndexedSeq[T] = IndexedSeq.empty,
        private[LBFGS] val memGradDelta: IndexedSeq[T] = IndexedSeq.empty)
        (implicit space: MutableInnerProductModule[T, Double])

The core of the L-BFGS direction computation is the * operator defined here. It was discussed before, but the earlier explanation got it wrong, so let's revisit it:

  def *(grad: T) = {
    // compute the initial diagonal scaling D0 = (s . y) / (y . y)
    val diag = if (historyLength > 0) {
      val prevStep = memStep.head
      val prevGradStep = memGradDelta.head
      val sy = prevStep dot prevGradStep
      val yy = prevGradStep dot prevGradStep
      if (sy < 0 || sy.isNaN) throw new NaNHistory
      sy / yy
    } else {
      1.0
    }

    val dir = space.copy(grad)
    val as = new Array[Double](m)
    val rho = new Array[Double](m)

    // backward loop, newest pair first (history is stored newest-first)
    for (i <- 0 until historyLength) {
      rho(i) = memStep(i) dot memGradDelta(i)
      as(i) = (memStep(i) dot dir) / rho(i)
      if (as(i).isNaN) {
        throw new NaNHistory
      }
      axpy(-as(i), memGradDelta(i), dir)
    }

    dir *= diag

    // forward loop, oldest pair first
    for (i <- (historyLength - 1) to 0 by (-1)) {
      val beta = (memGradDelta(i) dot dir) / rho(i)
      axpy(as(i) - beta, memStep(i), dir)
    }

    dir *= -1.0
    dir
  }

Here memStep corresponds to the s_i and memGradDelta to the y_i. diag is the initial scaling used in each round (the algorithm description gives its formula), as holds the alpha values, and rho here is the reciprocal of the algorithm's rho: the code stores s_i · y_i and divides by it, instead of multiplying by ρ_i = 1/(s_i · y_i). dir is the result variable; it plays the role of q in the backward loop and r in the forward loop.

Note that the first for loop runs from 0 up to historyLength, and the second from historyLength down to 0, the opposite of the order in the algorithm description. The reason is that memStep and memGradDelta keep the newest values at the front (prepend: s(k), s(k-1), …, s(0), as seen in the updated function below), whereas the algorithm description appends the newest values at the end (s(0), s(1), …, s(k)). The algorithm's first loop visits s_k down to s_0, which here means walking from index 0 up to historyLength, so the index order is reversed.
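
To make the reversed index order concrete, here is a minimal self-contained sketch of the same two-loop recursion on Array[Double], with the history stored newest-first exactly as in the Breeze code (twoLoopDirection is a hypothetical name):

  // s(i), y(i): step and gradient-delta pairs; index 0 is the most recent
  def twoLoopDirection(grad: Array[Double],
                       s: IndexedSeq[Array[Double]],
                       y: IndexedSeq[Array[Double]]): Array[Double] = {
    def dot(a: Array[Double], b: Array[Double]) =
      a.zip(b).map { case (ai, bi) => ai * bi }.sum
    def axpy(a: Double, x: Array[Double], dst: Array[Double]): Unit =
      for (i <- dst.indices) dst(i) += a * x(i)

    val k = s.length
    val dir = grad.clone()
    val alpha = new Array[Double](k)
    val sy = new Array[Double](k)

    // backward loop: newest pair first (the real code also rejects sy <= 0)
    for (i <- 0 until k) {
      sy(i) = dot(s(i), y(i))            // 1 / rho_i in textbook notation
      alpha(i) = dot(s(i), dir) / sy(i)
      axpy(-alpha(i), y(i), dir)
    }

    // initial scaling D0 = (s . y) / (y . y), from the newest pair
    val diag = if (k > 0) sy(0) / dot(y(0), y(0)) else 1.0
    for (i <- dir.indices) dir(i) *= diag

    // forward loop: oldest pair first
    for (i <- (k - 1) to 0 by -1) {
      val beta = dot(y(i), dir) / sy(i)
      axpy(alpha(i) - beta, s(i), dir)
    }

    dir.map(-_)  // the descent direction is -H * grad
  }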
Updating the history:

    def updated(step: T, gradDelta: T) = {
      val memStep = (step +: this.memStep) take m
      val memGradDelta = (gradDelta +: this.memGradDelta) take m

      new ApproximateInverseHessian(m, memStep, memGradDelta)
    }

As you can see, the new pair is prepended and then only the first m entries are kept, so the oldest pair falls off once the memory is full.
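
A quick illustration of the prepend-then-take semantics (hypothetical values):

  val mem = IndexedSeq("s2", "s1", "s0")  // newest first, m = 3
  ("s3" +: mem).take(3)                   // IndexedSeq("s3", "s2", "s1"): oldest dropped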

4.4.2. Training

4.4.2.1. adjustFunction

The argument is the CostFun; adjustFunction wraps it in a CachedDiffFunction:

  /** Calculates both the value and the gradient at a point */
  def calculate(x:T):(Double,T) = {
    var ld = lastData
    if (ld == null || x != ld._1) {
      val newData = obj.calculate(x)
      ld = (copy(x), newData._1, newData._2)
      lastData = ld
    }

    val (_, v, g) = ld
    v -> g
  }

This simply memoizes the most recent result: if the current x equals the last one, the cached value and gradient are returned directly. Since each calculate call here costs a full pass over the RDD (a treeAggregate job), this avoids recomputing when the line search's final evaluation point coincides with the next iterate.

4.4.2.2. initialState

Initialize the State:

protected def initialState(f: DF, init: T) = {
    // x is the initial weight vector
    val x = init
    // LBFGS.ApproximateInverseHessian, with empty memory at this point
    val history = initialHistory(f, init)
    // evaluate the cost function at the initial weights to get loss and gradient
    val (value, grad) = calculateObjective(f, x, history)
    // for LBFGS, adjust returns the loss and gradient unchanged
    val (adjValue, adjGrad) = adjust(x, grad, value)
    // the state after the first evaluation; note that initialAdjVal is set to adjValue
    State(x, value, grad, adjValue, adjGrad, 0, adjValue, history)
  }

4.4.2.3. iterations
4.4.2.3.1. chooseDescentDirection

Compute the descent direction; this simply applies the * operator of ApproximateInverseHessian described above:

protected def chooseDescentDirection(state: State, fn: DiffFunction[T]):T = {
    state.history * state.grad
  }
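
As a sanity check, reusing the hypothetical twoLoopDirection sketch from earlier: with an empty history the operator reduces to plain steepest descent, so the very first iteration moves along the negative gradient.

  val g = Array(0.5, -1.0)
  val d = twoLoopDirection(g, IndexedSeq.empty, IndexedSeq.empty)
  // d == Array(-0.5, 1.0), and (d dot g) < 0: a valid descent direction
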
4.4.2.3.2. determineStepSize

Determine the step size with a line search, which was covered in an earlier article:

protected def determineStepSize(state: State, f: DiffFunction[T], dir: T) = {
    val x = state.x
    val grad = state.grad
    // partially apply x and dir, so later calls only pass in alpha to evaluate f(x + alpha * dir)
    val ff = LineSearch.functionFromSearchDirection(f, x, dir)
    // strong Wolfe line search, covered in an earlier article
    val search = new StrongWolfeLineSearch(maxZoomIter = 10, maxLineSearchIter = 10) // TODO: Need good default values here.
    val alpha = search.minimize(ff, if(state.iter == 0.0) 1.0/norm(dir) else 1.0)

    if(alpha * norm(grad) < 1E-10)
      throw new StepSizeUnderflow
    alpha
  }
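
A sketch of what functionFromSearchDirection is assumed to do: restrict f to the ray x + alpha * dir, returning the value and the directional derivative (phi is a hypothetical name):

  def phi(f: Array[Double] => (Double, Array[Double]),
          x: Array[Double], dir: Array[Double])(alpha: Double): (Double, Double) = {
    // evaluate f at x + alpha * dir
    val xNew = x.zip(dir).map { case (xi, di) => xi + alpha * di }
    val (value, grad) = f(xNew)
    // the 1-D derivative is the gradient projected onto dir
    val slope = grad.zip(dir).map { case (gi, di) => gi * di }.sum
    (value, slope)
  }
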
4.4.2.3.3. Adjustment

Compute the new weights from the descent direction and the step size:

val x = takeStep(state,dir,stepSize)
protected def takeStep(state: State, dir: T, stepSize: Double) = state.x + dir * stepSize

Evaluate the cost function at the new weights to get the new loss and gradient:

val (value,grad) = calculateObjective(adjustedFun, x, state.history)
protected def calculateObjective(f: DF, x: T, history: History): (Double, T) = {
    f.calculate(x)
}

The adjust function returns the newly computed loss and gradient unchanged: adjValue equals the loss above and adjGrad equals the gradient:

val (adjValue,adjGrad) = adjust(x,grad,value)
def adjust(newX: T, newGrad: T, newVal: Double):(Double,T) = (newVal,newGrad)

The new loss is compared with the previous round to compute the relative improvement:

val oneOffImprovement = (state.adjustedValue - adjValue)/
  (state.adjustedValue.abs max adjValue.abs max 1E-6 * state.initialAdjVal.abs)
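
A worked example with hypothetical numbers:

  // previous adjusted loss 0.52, new adjusted loss 0.50, initial adjusted loss 0.69
  val oneOff = (0.52 - 0.50) / (0.52 max 0.50 max 1e-6 * 0.69)
  // = 0.02 / 0.52 ≈ 0.0385, roughly a 3.8% relative improvement this round
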
4.4.2.3.4. Updates

Update the approximate inverse Hessian:

 protected def updateHistory(newX: T, newGrad: T, newVal: Double,  f: DiffFunction[T], oldState: State): History = {
    // (s_i, y_i) = (x_{k+1} - x_k, g_{k+1} - g_k)
    oldState.history.updated(newX - oldState.x, newGrad :- oldState.grad)
}

Record the recent losses in the state; how many are kept is determined by minImprovementWindow:

val newAverage = updateFValWindow(state, adjValue)
protected def updateFValWindow(oldState: State, newAdjVal: Double):IndexedSeq[Double] = {
    val interm = oldState.fVals :+ newAdjVal
    if(interm.length > minImprovementWindow) interm.drop(1)
    else interm
}
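
A quick illustration of the sliding window (the window size and values are hypothetical):

  val minImprovementWindow = 3
  def slide(fVals: IndexedSeq[Double], v: Double) = {
    val interm = fVals :+ v
    if (interm.length > minImprovementWindow) interm.drop(1) else interm
  }
  slide(Vector(0.9, 0.8), 0.7)        // Vector(0.9, 0.8, 0.7)
  slide(Vector(0.9, 0.8, 0.7), 0.6)   // Vector(0.8, 0.7, 0.6): oldest dropped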

Construct the new state:

var s = State(x, value, grad, adjValue, adjGrad, state.iter + 1,
  state.initialAdjVal, history, newAverage, 0)

x is the weight vector from this round, value the new loss, grad the new gradient; adjValue always equals value here and adjGrad always equals grad. The iteration count goes up by one, corresponding to k in the algorithm. state.initialAdjVal is carried over unchanged from initialState (the adjusted value at the starting point); history is the freshly updated Hessian approximation; newAverage holds the most recent losses, up to minImprovementWindow of them; numImprovementFailures is set to 0 here.

Next, check whether this iteration actually improved:

val improvementFailure = (state.fVals.length >= minImprovementWindow &&
    state.fVals.nonEmpty && 
    state.fVals.last > state.fVals.head * (1-improvementTol))
if(improvementFailure)
    s = s.copy(fVals = IndexedSeq.empty, 
        numImprovementFailures = state.numImprovementFailures + 1)

An improvement failure requires that at least minImprovementWindow losses have been recorded, and that the newest loss improves on the oldest loss in the window by less than improvementTol. On failure, the recorded losses are cleared and state.numImprovementFailures is incremented. Since the new state above is constructed with numImprovementFailures = 0, this counts consecutive improvement failures: a single successful improvement resets the count to 0.
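
A worked example of the test with hypothetical numbers (assuming the window is full):

  // improvementTol = 1e-3
  val fVals = Vector(0.5000, 0.4998, 0.4999)
  val failed = fVals.last > fVals.head * (1 - 1e-3)
  // 0.4999 > 0.5000 * 0.999 = 0.4995  =>  true: the window improved by
  // less than 0.1%, so numImprovementFailures is incremented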

4.4.2.3.5. Termination

The iteration itself is an unbounded loop that runs until an exception is thrown: the first FirstOrderException resets the Hessian history, and only a second one gives up:

} catch {
    case x: FirstOrderException if !failedOnce =>
        failedOnce = true
        logger.error("Failure! Resetting history: " + x)
        state.copy(history = initialHistory(adjustedFun, state.x))
    case x: FirstOrderException =>
        logger.error("Failure again! Giving up and returning. Maybe the objective is just poorly behaved?")
        state.copy(searchFailed = true)
}

The final state determines the reported convergence reason:

def convergedReason:Option[ConvergenceReason] = {
    if (iter >= maxIter && maxIter >= 0)
        Some(FirstOrderMinimizer.MaxIterations)
    else if (!fVals.isEmpty && (adjustedValue - fVals.max).abs <= tolerance * initialAdjVal)
        Some(FirstOrderMinimizer.FunctionValuesConverged)
    else if (numImprovementFailures >= numberOfImprovementFailures)
        Some(FirstOrderMinimizer.ObjectiveNotImproving)
    else if (norm(adjustedGradient) <= math.max(tolerance * adjustedValue.abs, 1E-8))
        Some(FirstOrderMinimizer.GradientConverged)
    else if (searchFailed)
        Some(FirstOrderMinimizer.SearchFailed)
    else
        None
}

4.4.3. Returning the result

Check whether the optimizer actually converged, then return the weights and the loss history:

var state = states.next()
while (states.hasNext) {
    lossHistory += state.value
    state = states.next()
}
lossHistory += state.value
// actuallyConverged means the reason is FunctionValuesConverged or GradientConverged
if (!state.actuallyConverged) {
    logWarning("LBFGS training finished but the result " +
        s"is not converged because: ${state.convergedReason.get.reason}")
}

val weights = Vectors.fromBreeze(state.x)
val lossHistoryArray = lossHistory.result()
logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
      lossHistoryArray.takeRight(10).mkString(", ")))

(weights, lossHistoryArray)

4.5. Wrap-up

Depending on whether an intercept is used, the intercept and the actual weights are extracted from the solution. If the input features were scaled, training effectively computes w_i * (f_i / std_i), i.e. each feature was divided by its standard deviation. When the weights are returned (for prediction), the raw features should not need any transformation, so the scaling is folded into the weights: w_i * (f_i / std_i) = (w_i / std_i) * f_i, and the returned weight for feature i is w_i / std_i.
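
A hypothetical illustration of folding the scaling into the weights:

  val std     = Array(2.0, 4.0)   // per-feature standard deviations
  val wScaled = Array(0.6, 0.8)   // weights learned on scaled features f_i / std_i
  val wReturned = wScaled.zip(std).map { case (wi, s) => wi / s }  // (0.3, 0.2)

  // for raw features f, the two formulations agree:
  val f = Array(10.0, 20.0)
  val a = wScaled.zip(f.zip(std)).map { case (wi, (fi, s)) => wi * (fi / s) }.sum
  val b = wReturned.zip(f).map { case (wi, fi) => wi * fi }.sum
  // a == b == 7.0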
