FM(Factorization Machine)
模型描述
在点击率预估等任务中,10维的类别型特征做onehot编码后变成1000维特征,绝大多数特征取值为0,即特征稀疏。然后,某些稀疏特征经过关联得到的关联特征,例如“化妆品”类商品和“女”性,与label之间的相关性会提高。因此对于一个具有n个特征的样本,模型表示
y^=w0+∑i=0nwixi+∑i=1n−1∑j=i+1nvijxixj
但是xi和xj本来就很多为0,xixj就更多为0,而有n(n−1)2个vij要训练,因此训练样本的不足很容易导致vij不准确。在W=VVT,其中W∈Rn×n,V∈Rn×K的启发下,模型变成
y^=w0+∑i=0nwixi+∑i=1n−1∑j=i+1n<vi,vj>xixj
其中,xi∈R,代表样本x第
模型的参数共有1+n+n×K个,而所有第{j|xixj≠0}维特征都可以用来训练vi,很大程度上避免了数据稀疏性的影响。
模型求解
对于回归问题,优化目标是MSE(Mean Square Error)时,对N个训练样本,优化问题描述为
min J(θ)=1N∑i=1N(yi−y^i)2
由于点击率预估的问题中,样本数量很大,因此采用随机梯度下降方法。优化目标变成在第i次迭代中,让当前样本的
min J(θ)=(yi−y^i)2
迭代方程
θi=θi−1−α∂J(θ)∂θ
其中
∂J(θ)∂θ=−2(yi−y^i)∂y^i∂θ
而
∂y^∂θ=⎧⎩⎨⎪⎪1xixi∑nj=1vjkxj−vikx2iθ=w0θ=wiθ=vik
其中i代表第
虽然直观求解
复杂度为O(Kn2),但是
∑i=1n−1∑j=i+1n(vixi)(vjxj)T=12(∑i=1nvixi)(∑i=1nvixi)T−12∑i=1nx2ivivTi
因此
∂y^∂vi=xi(∑j=1nvjxj)T−vTix2i
其中计算∑nj=1vjxj的复杂度为O(Kn);之后计算每个∂y^∂vi,i∈{1,2,…,n}的复杂度是O(1),wi和w0也是,即计算每个参数的梯度的复杂度是O(1);得到梯度后更新每个参数的复杂度是O(1);模型参数一共Kn+n+1个,因此FM在训练时的复杂度为O(Kn)。综上,FM可以在线性时间内训练和预测,是非常高效的模型。
FFM(Field aware factorization machine)
对于一个具有m个特征的样本,经过one-hot编码后,特征维数变为n,对xi,i∈{1,2,…,n},隐向量vi∈Rm×K,最初设想中
∑i=1n−1∑j=i+1nvijxixj⇒∑i=1n−1∑j=i+1n<vi,mj,vj,mi>xixj
其中mj代表编码后的第j维特征在编码前属于的特征维数(所属field)。FFM的二次参数有nmK个,远多于FM模型的nK个。由于隐向量与field相关,FFM二次项不能化简,预测复杂度是