使用深度学习进行生存分析

本文转自:使用深度学习进行生存分析

相关资源

原论文地址:here

论文中使用的深度生存分析库:DeepSurv,是基于Theano 和 Lasagne库实现的,支持训练网络模型,预测等功能。

考虑到DeepSurv库中存在着一些错误以及未实现的功能,博主使用目前主流的深度学习框架Tensorflow实现了深度生存分析库:TFDeepSurv。欢迎有兴趣的同学Star和Fork,指出错误,相互交流!

TFDeepSurv简介:基于tensorflow的深度生存分析框架,经过模拟数据和真实数据的测试。支持生存分析数据事件时间出现ties的建模,自定义神经网络结构及参数,可视化训练过程,输入训练数据特征重要性分析,病人生存函数的估计。还有支持使用科学的贝叶斯超参数优化方法来调整网络参数。

博主有空会给出TFDeepSurv各个功能实现参考的源论文!

前言

本文主要的目的为了介绍深度学习是如何运用到生存分析中的,包括其基本原理。然后介绍目前实现了利用深度学习进行生存分析的开源软件包 DeepSurv,它实现了生存分析模型,使用Deep Neural Networks来训练学习参数,并且还实现了风险人群的划分。

还是一样的,强烈建议你去读一下原论文DeepSurv: Personalized Treatment Recommender System Using A Cox Proportional Hazards Deep Neural Network.,相信你会收获很大,至少比看我的好一万倍。写这篇文章的动力有几点,一是不想自己学过的知识什么很快就忘了,感觉记录一下比较重要(博主比较蠢),当作是看论文的笔记吧;二是看完文章之后,觉得我们平时还是要多思考,论文里的思想其实也不是完全原创的,神经网络不是,生存分析cox比例风险模型1972年就有了,但是别人就能洞察到使用深度学习的思想去学习COX模型中需要估计的参数,个人觉得这是一个有科学素养的人才能做到的吧;三是博主在为了PR收集生存分析资料的时候,深感不易,这方面的中文的介绍很少,所以为了方便大家的交流讨论,还是写一下吧。

博主知识水平有限,不吝赐教!欢迎提出错误!

问题来源

假设你已经知道了生存分析主要是在做哪些工作。我们都知道在进行生存分析的时候,有这么几种方法:

  • 参数法:当生存时间符合某一个已知分布时,知道了分布函数,那么剩下的就是求解该分布的参数了。
  • 非参数法:用KM估计去求生存函数,作生存曲线,这里面不涉及任何参数,主要思想就是频率代替概率。
  • 半参数法:也就是使用COX比例风险模型来求生存函数,这个也是本文的重点。

关于COX比例风险模型是怎么提出的,这个是1972年前辈的智慧,本文不打算介绍,这里给出一个链接:hazard-curve,可以帮助你快速了解生存分析和生存函数以及风险曲线的数学定义,然后你就可以去看COX比例风险模型是怎么提出来的了。确保自己懂了COX比例风险模型的原理,可以问自己几个问题:比例两个字是体现在那个地方?为什么风险函数会是 h ( t ) = h 0 ( t ) ⋅ e θ ⋅ x h(t) = h_0(t)\cdot e^{\theta \cdot x} h(t)=h0(t)eθx 这种形式?

COX比例风险模型,直接给出了风险函数的数学表达式(假设你已经学会了懂了其背后的数学原理):
h ( t ) = h 0 ( t ) ⋅ e θ ⋅ x h(t) = h_0(t)\cdot e^{\theta \cdot x} h(t)=h0(t)eθx
其中, θ = ( θ 1 , … , θ m ) θ=(θ_1,…,θ_m) θ=(θ1,,θm)是线性模型的系数或未知参数, h 0 ( t ) h_0 (t) h0(t)是基准风险函数。 e θ x e^{θx} eθx描述了患者观察到回归变量 x x x时的死亡风险比例。对∀i∈N,θ_i>0,表示该协变量是危险因素,越大使生存时间越短。 ∀i∈N,θ_i < 0表示该协变量是保护因素,越大使得生存时间越长。

现在需要去求取参数 θ \theta θ,其思想就是偏似然估计法。假定在某死亡时间没有重复事件发生,设 t 1 < t 2 < ⋯ < t k t_1<t_2<⋯<t_k t1t2tk 表示在观察数据中有 k k k个不同的死亡事件。设 x i x_i xi的观察协变量。设 R ( t i ) R(t_i) R(ti)时间仍然处于观察研究的个体集合。则风险函数 h ( t ) h(t) h(t)的参数估计可以用以下偏似然概率估计方法:
p l ( θ ) = ∏ i = 1 k e θ x i ∑ j ∈ R ( t i ) e θ x j pl(\theta) = \prod_{i=1}^{k}\frac{e^{\theta x_i}}{\sum_{j \in R(t_i)}e^{\theta x_j}} pl(θ)=i=1kjR(ti)eθxjeθxi

其中 q i = e θ x i ∑ j ∈ R ( t i ) e θ x j q_i = \frac{e^{\theta x_i}}{\sum_{j \in R(t_i)}e^{\theta x_j}} qi=jR(ti)eθxjeθxi个死亡个体,其死亡条件概率。其实通俗一点的解释就是:我已经观察到时间 t i t_i ti了,现在有一群人,我可以利用风险公式 h ( t i ) h(t_i) h(ti) 求出这群人每一个个体的死亡风险,其中有一个人恰好在 t i t_i ti时刻发生了死亡事件,那么这个人的死亡条件概率就写为:
q i = h i ( t i ) ∑ j ∈ R ( t i ) h j ( t i ) = e θ x i ∑ j ∈ R ( t i ) e θ x j q_i = \frac{h_i(t_i)}{\sum_{j \in R(t_i)}h_j(t_i)} = \frac{e^{\theta x_i}}{\sum_{j \in R(t_i)}e^{\theta x_j}} qi=jR(ti)hj(ti)hi(ti)=jR(ti)eθxjeθxi

现在就是利用偏似然估计的思想,将所有死亡时刻 t 1 , t 2 , . . . , t k t_1,t_2,...,t_k t1,t2,...,tk的死亡条件概率相乘,求取是这个乘积最大的参数值 θ \theta θ,把它作为估计量。

注意COX模型给出的前提:假设协变量的总影响可以表示为它们的线性组合。例如,我评价一个人的颜值 v v v,你告诉我颜值可以这么计算 v = 2 x 1 + 9 x 2 + 1.3 x 3 v=2x_1 + 9x_2+1.3x_3 v=2x1+9x2+1.3x3 x 1 , x 2 , x 3 x_1,x_2,x_3 x1,x2,x3表示眼睛大小,脸型,鼻子高度(当然,这里是打个比方QAQ)。事实上,很多情况下,协变量的线性组合不能准确衡量它们对某个目标值的影响! 关于这点例子很多(例XOR问题),就不一一介绍了。

问题根源就是在 θ ⋅ x \theta \cdot x θx,我们把它记为 r r r。那么我们可不可以把它表示为非线性组合呢?但是好像它的数学表达式公式不太好给出,无论我们怎么表示 r r r,其目标都是使 p l ( θ ) pl(\theta) pl(θ)最小。这个时候,神经网络的作用就显现出来了,它对于表示一组协变量的非线性组合简直太擅长了!假设网络的输入为一组协变量 x = ( x 1 , x 2 , . . . , x n ) x=(x_1,x_2,...,x_n) x=(x1,x2,...,xn),那么网络的输出表示为 r ^ w , b \hat r_{w,b} r^w,b为神经网络的参数。然后,损失函数就很显而易见了:
L = − l o g ( p l ( r ^ w , b ) ) = − l o g ( ∏ i = 1 k e r ^ w , b i ∑ j ∈ R ( t i ) e r ^ w , b j ) L = -log(pl(\hat r_{w,b})) = -log(\prod_{i=1}^{k}\frac{e^{\hat r_{w,b}^i}}{\sum_{j \in R(t_i)}e^{\hat r_{w,b}^j}}) L=log(pl(r^w,b))=log(i=1kjR(ti)er^w,bjer^w,bi)
h ( x ) h(x) h(x)表示为网络的输出,那么上式可以化简为:
L = − ∑ i = 1 k [ h i ( x ) − l o g ( ∑ j ∈ R ( t i ) e h j ( x ) ) ] L = -\sum_{i=1}^{k}[h_i(x)-log(\sum_{j\in R(t_i)}e^{h_j(x)})] L=i=1k[hi(x)log(jR(ti)ehj(x))]
剩下的工作就交给神经网络去训练样本学习到这样的非线性组合 &lt; r ^ w , b = f ( x , θ ) &lt;\hat r_{w,b}=f(x,\theta) <r^w,b=f(x,θ)。其实思路还是很好懂的嘛。

DeepSurv网络框架实现

DeepSurv 的工作就是实现了上面介绍的所有内容(最重要的是损失函数),还实现了一些其他的功能(比如划分风险人群)。下面介绍一下这个框架的实现。

这里是DeepSurv类下面定义的方法:

class DeepSurv:
    def __init__()
    # 计算Loss function值
    def _negative_log_likelihood()
    # 得到当前网络的loss值同时更新网络参数
    def _get_loss_updates()
    # 得到可调用的函数:训练集上,网络进行一次正向和反向传播
    #                验证集上,一遍正向传播,计算Loss function值
    def _get_train_valid_fn()
    # 计算评估指标:C Index
    def get_concordance_index()
    def _standardize_x()
    def prepare_data()
    def train()
    def to_json()
    def save_model()
    def save_weights()
    def load_weights()
    # 得到网络的输出值
    def risk()
    def predict_risk()
    # 划分风险人群
    def recommend_treatment()
    def plot_risk_surface()

初始化函数:初始化网络结构,并且记录一些参数

def __init__(self, n_in,
    learning_rate, hidden_layers_sizes = None,
    lr_decay = 0.0, momentum = 0.9,
    L2_reg = 0.0, L1_reg = 0.0,
    activation = "rectify",
    dropout = None,
    batch_norm = False,
    standardize = False,
    ):

按照给定hidden_layers_sizes的搭建指定的网络结构:
输入层:network = lasagne.layers.InputLayer(shape=(None,n_in),input_var = self.X)
隐藏层:network = lasagne.layers.DenseLayer(network, num_units = n_layer, nonlinearity activation_fn, W = W_init)(参数决定该层是否dropout或者BatchNorm)
输出层:network = lasagne.layers.DenseLayer(network, num_units = 1, nonlinearity = lasagne.nonlinearities.linear, W = lasagne.init.GlorotUniform())

训练函数:在给定的训练数据上进行训练,并且在验证集上进行评估

def train(self,
    train_data, valid_data= None,
    n_epochs = 500,
    validation_frequency = 250,
    patience = 2000, improvement_threshold = 0.99999, patience_increase = 2,
    logger = None,
    update_fn = lasagne.updates.nesterov_momentum,
    verbose = True,
    **kwargs):

训练函数里的内容就是通用的一套了:

  • 准备好训练数据
  • 每个epoch迭代训练网络
  • 计算Loss,反向传播更新网络参数

具体地,源代码里还有很多细节的地方,自己亲身学习一下还不错啊!

原博客作者另外还写了一篇实战的总结:【论文笔记】Deep Survival: A Deep Cox Proportional Hazards Network ,值得借鉴。

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值