超分之RVRT

本文介绍了一种名为RVRT的新型视频超分辨率(VSR)模型,它结合了循环网络和Transformer的优点。RVRT通过将序列划分为clips,并在每个clips内部使用Transformer进行并行处理,同时使用Guided Deformable Attention (GDA)进行视频对齐,以解决长距离特征捕获和计算效率的问题。实验表明,RVRT在视频超分辨率任务上实现了最优的性能与计算资源之间的平衡。

在这里插入图片描述

这篇文章是今年6月份刚出的VSR文章,其推出了一种将循环网络结构和Transformer结构相结合的一种模型——Recurrent Video Restoration Transformer(RVRT)。RVRT完成了Recurrent-based模型和Vision-Transformer模型之间的平衡,从而实现了模型大小、计算效率、表现力之间的trade-off。

Note:

  1. 本文只探究VSR部分。

参考文档:
源码

Abstract

  1. VSR方法大致分为Sliding-windows、Recurrent、Transformer三大类,其中后面两个是目前已被证明更好的方法。现有的视频超分陷入了2个极端,即要么就是Recurrent-based模型,要么就是Transformer-based模型。
  2. Recurrent-based模型例如RLSPBasicVSRBasicVSR++等通过一帧接一帧的方式对LR图像进行超分。虽然基于RNN会使得模型共享从而产生较小的模型量;此外每次处理一帧也使得运算效率很高;但是循环网络模型天然会遭受梯度爆炸/消失、信息衰减、噪声放大等问题,因此其无法捕捉长距离上的特征信息。
  3. Vision Transformer-based模型例如VSRTVRT等对多帧同时并行提取特征信息,但是会造成较大的模型参数和显存消耗。

基于两种模型的优缺点,本文作者提出了一种将Recurrent-based和Vision Transformer=based结合的模型——RVRT。

  1. RVRT将整个序列 L R S LRS LRS分成一段段,每一段作为一个clips,即整个序列由 T N \frac{T}{N} NT个clips组成,每个clips里面有 N N N帧。由于 N < T N < T N<T,故RVRT相当于将整个序列缩短了,这样才能弥补Recurrent-based的缺陷。
  2. 每个clips内部使用Transformer并行化处理;clips之间使用Guided Deformable Attention(GDA)来做对齐;而整个模型框架还是基于循环网络的,因此可以说RVRT在局部上使用并行结构来提取特征,全局上使用循环结构,两者相互结合,实现了模型大小参数量、计算效率、表现力的trade-off
  3. RVRT是VRT的升级版,都是同一批作者,二者都适合于所有的视频恢复任务,如超分、去噪、去模糊、去块等。
  4. RVRT在通用的数据集上实现了SOTA的表现力,证明了Recurrent和Transformer结合起来实现视频恢复的方案是可行的。

1 Introduction

首先我们先对目前主流的2种VSR模型的优缺点进行简要阐述:
Recurrent model \colorbox{orange}{Recurrent model} Recurrent model

  1. 优点:因为RNN结构是共享模型的,同一套参数在不同时间段被重复使用,因此其模型参数量较小,计算效率较高。
  2. 缺点:缺乏长距离建模的能力,这是由于当序列长度较长时候,循环结构会发生梯度消失/爆炸现象,以及天然会存在信息丢失和噪声放大问题;此外无法并行化提取所有帧信息也是比较低效的。

Vision Transformer model \colorbox{deepskyblue}{Vision Transformer model} Vision Transformer model

  1. 优点:可以直接并行化对所有帧同时提取特征信息,简单来说就是 T T T帧进去, T T T帧出来;可以直接融合所有帧的特征信息,因为Transformer可以根据注意力机制来挑选一些更有用的特征。
  2. 缺点:模型大、计算复杂度高、训练难度大、显存占用较大。

因此作者打算实现一个同时具备上述2个模型优点的混合模型——RVRT。
RVRT在以Vision Transformer为基础之上引入循环结构来减小模型参数量;此外RVRT将一定数量的相邻帧合并成1个clip,从而降低了输入帧的有效个数,缓解了Recurrent模型无法捕捉长距离特征信息、信息衰减问题。

具体而言,RVRT的核心为3部分:①循环结构;②Recurrent Feature Refinement(Vision Transformer部分);③Guided Deformable Attention。下面我们简要介绍一下这三部分。
Recurrent Structure \colorbox{yellow}{Recurrent Structure} Recurrent Structure
循环结构可以构建长序列在时间维度上的相关性。但为了避免序列太长导致RNN结构天然的问题,RVRT将整个序列分成 T / N T/N T/N个clips,每个clips里面包含了 N N N帧。其循环结构基本和BasicVSR++类似:所有的clips共用同一个网络——RFR,特征传播沿着两个方向进行,只不过RVRT的隐藏状态更大,因为它是一个clip;此外RFR还有 L L L层的传播迭代用来让对齐更加准确。相比传统的Recurrent-based结构,RVRT不容易产生梯度消失/爆炸、过多的信息衰减以及噪声放大问题。

Recurrent Feature Refinement \colorbox{springgreen}{Recurrent Feature Refinement} Recurrent Feature Refinement
RFR可以说是一整个循环结构,也可以说是单独一个特征校正模块,既然前者已经叙说过了,这里就单指代后者。RFR结构主要由Swin-T结构组成,主要用于对clips内的 N < T N < T N<T帧进行特征提取,在空间中提取到所有的信息,由于只在clip内部进行,因此模型的复杂度被有效降低了。

Guided Deformable Attention \colorbox{mediumorchid}{Guided Deformable Attention} Guided Deformable Attention
一般Transformer用于特征提取比较多,自从VRT出现MMA之后,RVRT也出了Transformer用于对齐的功能。GDA主要利用DA的机制来对齐相邻的clips。一来其可以有效降低Transformer的计算量;二来其比传统的CNN-based对齐方式,比如flow-based方法每次对齐只依赖于1个采样点,而DAT可以基于多个采样点来生产对齐结果;三来相比MMA只能基于局部空间建模,DAT可以在全局空间中建模。


小结一下

  1. RVRT可以利用Transformer在帧数较少的clips内部中的局部空间上进行同步并行提取特征信息;此外利用Recurrent结构进行时许上的相关性建模可以减轻模型复杂度;将帧数分成clip-by-clip的形式不但可以减轻梯度消失/爆炸的问题,也可以减轻Transformer计算量。Recurrent+Transformer的结构可以同时占据两者的优势。
  2. RVRT提出了基于Transformer的对齐方式——GDA,其用于clips-to-clips的对齐。
  3. RVRT在一系列benchmark上展现了SOTA的水平,实现了模型大小、显存占用量、runtime和表现力之间的trade-off。

2 Related Work

3 Methodology

3.1 Overall Architecture

在这里插入图片描述
如上图所示就是RVRT的pipeline,其主要由3部分组成:①浅层特征提取模块;②RFR模块;③超分重建模块。
浅层特征提取:对于超分任务由1个卷积层和RSTB(Residual Swin Transformer Blocks)组成,来提取浅层特征。
RFR:RVRT的核心部分,其利用Recurrent机制来做全局范围内的或者说时间维度上的相关性建模以及利用Transformer来做局部空间范围内的相关性建模,从而实现了模型复杂度和表现力的trade-off;此外RFR中使用了GDA来做clips之间的对齐。
超分重建:这部分也是使用了1个卷积层和RSTB组成,外加Pixelshuffle来做上采样生成最终的SR图像。

损失函数:
Charbonnier函数: L = ∣ ∣ I S R − I L R ∣ ∣ 2 2 + ϵ 2 , ( ϵ = 1 0 − 3 ) \mathcal{L} = \sqrt{||I^{SR} - I^{LR}||^2_2 + \epsilon^2},(\epsilon=10^{-3}) L=∣∣ISRILR22+ϵ2 ,(ϵ=103)

3.2 Recurrent Feature Refinement

RFR这个结构和BasicVSR++很像——它的结构主要由双向传播与传播迭代组成;具体结构如下图所示:
在这里插入图片描述
别看它看起来复杂,其实就是个BasicVSR加了 L L L次传播迭代,只不过特征校正和对齐的方式变了。
具体而言,其横向是传播迭代,纵向是双向或者单向的特征传播;此外最重要的一点是,传统的Recurrent-based模型都是基于单张特征图之间的对齐,而RFR则是clips-to-clips之间的对齐。

Note:

  1. 但要注意的是RFR没有coupled-propagation。
  2. 此外BasicVSR系列的对齐使用flow-based或者flow-guided这些CNN-based方法,而RVRT使用的是Transformer-based方法——Guided Deformable Attention来做对齐。
  3. BasicVSR系列的特征校正模块都是使用CNN-Based例如残差块堆积的方式;而RVRT使用Transformer-based方法——MRSTB,一种基于Swin-T的结构,该结构将一个clips里面的 N N N帧同时并行提取特征而不像传统Recurrent-based模型中采取frames-by-frames的样式。
  4. 图中 F t i − 2 F_t^{i-2} Fti2表示第 i − 2 i-2 i2层中第 t t t个clips,其中每个clips内含有 N N N个连续帧。

接下来我们具体展开对RFR模型进行叙述。
设第 i i i层(即共 L L L个RFR块中的第 i i i个)的特征表示为 F i ∈ R T × H × W × C F^i\in\mathbb{R}^{T\times H\times W\times C} FiRT×H×W×C;首先对它进行split为 T N \frac{T}{N} NT个clips,每个clips内含 N N N个连续帧,记为 F i ∈ R T N × N × H × W × C F^i\in\mathbb{R}^{\frac{T}{N}\times N\times H\times W\times C} FiRNT×N×H×W×C,则每个clips可表示为 F 1 i , F 2 i , ⋯   , F T N i ∈ R N × H × W × C F_1^i, F_2^i, \cdots, F^i_{\frac{T}{N}}\in\mathbb{R}^{N\times H\times W\times C} F1i,F2i,,FNTiRN×H×W×C;每个clips内部的 N N N个连续帧表示为 F t , 1 i , F t , 2

#GEE library(multgee) library(dplyr) library(stringr) library(openxlsx) df_long2 <- df_long %>% # 受试者ID列名请替换成你真实的,比如 bian_hao;这里假设叫 bian_hao rename(id = bian_hao) %>% filter(!is.na(StoolType)) %>% mutate( group = factor(group, levels = c("3","6","1")), # 3 为参照组 StoolType = ordered(StoolType), # 确保有序 Timepoint = factor(Timepoint, ordered = TRUE) # 确保 V0 是第一层 ) df_long2 <- df_long2 %>% mutate( age_days_V0_z = scale(age_days_V0), mu_qin_yunqian_bmi_V0_z = scale(mu_qin_yunqian_bmi_V0), yunzhou_V0_z = scale(yunzhou_V0), chu_sheng_ti_zhong_V0_z = scale(chu_sheng_ti_zhong_V0) ) df_long2 <- df_long2 %>% mutate( Timepoint = factor(Timepoint, ordered = FALSE) # 取消有序 ) contrasts(df_long2$Timepoint) <- contr.treatment(levels(df_long2$Timepoint)) # 以 V0 为参照 # ========= 2) 拟合有序 GEE 模型(cumulative logit)========= # LORstr 可选 "uniform", "category.exch";repeated 指定时间变量 fit_ordgee <- ordLORgee( formula = StoolType ~ group * Timepoint + age_days_V0_z + xing_bie_V0 + chu_sheng_ti_zhong_V0_z + fen_mian_fang_shi_V0 + region + mu_qin_yunqian_bmi_V0_z + yunzhou_V0_z + income_cat + muqin_education_V0, id = id, repeated = Timepoint, data = df_long2, link = "logit", LORstr = "uniform", add = 0.5 # 防止某些格为0导致估计不稳定 ) summary(fit_ordgee) #三组间整体检验 ## 1) 从模型中取系数 & 协方差,并确保 b 有名字 s <- summary(fit_ordgee) b <- as.numeric(s$coef[, "Estimate"]) names(b) <- rownames(s$coef) # 关键:给 b 加上行名作为 names V <- if (!is.null(fit_ordgee$robust.variance)) fit_ordgee$robust.variance else vcov(fit_ordgee) coef_names <- names(b) print(coef_names) # 看看真实行名长啥样 ## 2) 通用的整体 Wald 检验函数(无任何外部包) wald_overall <- function(b, V, terms){ idx <- match(terms, names(b)) idx <- idx[!is.na(idx)] if (length(idx) == 0) stop("没有匹配到参数名") R <- diag(length(b))[idx, , drop = FALSE] rb <- R %*% b RVRT <- R %*% V %*% t(R) W <- as.numeric(t(rb) %*% solve(RVRT) %*% rb) df <- length(idx) p <- 1 - pchisq(W, df) data.frame(chi2 = W, df = df, p = p, row.names = NULL) } ## 3) 通用匹配规则(不假设“=”或“.”,也兼容 Time 与 Timepoint) ## - 组主效应:以 "group" 开头,且不含 ":"(不是交互项) terms_group_main <- coef_names[ grepl("^group", coef_names) & !grepl(":", coef_names) ] ## - 时间主效应:以 "Time" 或 "Timepoint" 开头,且不含 ":"(不是交互) terms_time_main <- coef_names[ grepl("^(Time|Timepoint)", coef_names) & !grepl(":", coef_names) ] ## - 组×时间交互:以 "group" 开头,且后面有 ":",并且冒号后面跟的是 Time/Timepoint terms_gxt <- coef_names[ grepl("^group", coef_names) & grepl(":(Time|Timepoint)", coef_names) ] ## 4) 跑整体检验 cat("== Group 主效应 ==\n"); print(wald_overall(b, V, terms_group_main)) cat("== Time 主效应 ==\n"); print(wald_overall(b, V, terms_time_main)) cat("== Group × Time 交互 ==\n"); print(wald_overall(b, V, terms_gxt)) #整体间交互项有显著性差异 # 注意:系数是 log(累积OR),exp() 后就是 累积OR # ========= 3) 提取系数,换算为 OR + 95%CI + P ========= coef_tab <- summary(fit_ordgee)$coef # 计算95%CI(β和OR双尺度),保留更多小数 beta_CI_low <- coef_tab[, "Estimate"] - qnorm(0.975) * coef_tab[, "san.se"] beta_CI_high <- coef_tab[, "Estimate"] + qnorm(0.975) * coef_tab[, "san.se"] out_coef <- data.frame( term = rownames(coef_tab), beta = round(coef_tab[, "Estimate"], 6), se = round(coef_tab[, "san.se"], 6), z = round(coef_tab[, "san.z"], 6), p_value = signif(coef_tab[, "Pr(>|san.z|)"], 6), # beta 尺度的置信区间 beta_CI_low = round(beta_CI_low, 6), beta_CI_high = round(beta_CI_high, 6), # OR 尺度 OR = round(exp(coef_tab[, "Estimate"]), 6), CI_low = round(exp(beta_CI_low), 6), CI_high = round(exp(beta_CI_high), 6) ) %>% mutate( term = str_replace_all(term, "group", "group="), term = str_replace_all(term, "Timepoint", "Time="), term = str_replace_all(term, ":", " × ") ) # 查看结果 print(out_coef) # 保存总体系数表(主效应 + 交互项) write.xlsx(out_coef, "gee_ordinal_overall_coefficients_调整协变量.xlsx", rowNames = FALSE) #与上面的两两每时点比较类似: 更严谨的线性组合(含协方差)做两两每时点的组间比较—— lincom_cov <- function(model, L_named){ s <- summary(model) b <- s$coef[, "Estimate"]; V <- if (!is.null(model$robust.variance)) model$robust.variance else vcov(model) L <- rep(0, length(b)); names(L) <- rownames(s$coef) hit <- intersect(names(L_named), names(L)); L[hit] <- L_named[hit] est <- sum(L * b) se <- sqrt(as.numeric(t(L) %*% V %*% L)) OR <- exp(est); CI <- exp(c(est - 1.96*se, est + 1.96*se)) z <- est / se; p <- 2*pnorm(-abs(z)) c(beta=est, se=se, OR=OR, CI_low=CI[1], CI_high=CI[2], z=z, p_value=p) } # —— 构造“组A vs 组B @ 指定时间点”的对比向量 —— # 记号:A/B ∈ c("1","3","6");tp ∈ levels(df_long2$Timepoint)(如 "V0","V1",...) # 规则:A_vs_B(tp) = [ groupA + groupA:Time_tp ] - [ groupB + groupB:Time_tp ] build_contrast <- function(A, B, tp, coef_names){ # 主效应名(可能不存在:如果该组恰好是 group 的参照,则主效应=0) gA <- paste0("group", A); gB <- paste0("group", B) # 交互项名(在 treatment 对比下应是 "groupX:TimepointVt" 这种) iA <- paste0("group", A, ":Timepoint", tp) iB <- paste0("group", B, ":Timepoint", tp) # 只对存在于模型的项赋值(不存在的等于0) pick <- function(nm) nm[nm %in% coef_names] setNames(c(rep(1, length(pick(c(gA,iA)))), rep(-1, length(pick(c(gB,iB))))), c(pick(c(gA,iA)), pick(c(gB,iB)))) } # —— 批量生成“三组两两 × 各时间点”的结果表 —— pairwise_by_time <- function(model, data, timevar = "Timepoint", groups = c("1","3","6")){ coef_names <- rownames(summary(model)$coef) tps <- levels(data[[timevar]]) out <- list() for(tp in tps){ for(i in 1:(length(groups)-1)){ for(j in (i+1):length(groups)){ A <- groups[i]; B <- groups[j] L <- build_contrast(A, B, tp, coef_names) res <- lincom_cov(model, L) out[[paste(tp, paste0(A," vs ",B), sep=" | ")]] <- data.frame(time=tp, contrast=paste0(A," vs ",B), t(res)) } } } dplyr::bind_rows(out) } # 然后这样运行: pw_tab <- pairwise_by_time(fit_ordgee, data = df_long2, timevar = "Timepoint", groups = c("1","3","6")) pw_tab # 保存“各时间点简单效应OR”表 write.xlsx(pw_tab, "gee_ordinal_simple_effects_by_time.xlsx", rowNames = FALSE) #随时间点变化的两两比较 library(emmeans) library(geepack) # 拟合:ordinal 作为近似连续(常见于临床纵向分析) fit_geeglm <- geeglm( as.numeric(StoolType) ~ group * Timepoint, id = id, data = df_long2, family = gaussian, # 或者poisson, 取决于数据分布 corstr = "exchangeable" ) emm <- emmeans(fit_geeglm, ~ group | Timepoint) pairs(emm, adjust = "bonferroni") library(ggplot2) emm_df <- as.data.frame(pairs(emm, adjust = "bonferroni")) p <- ggplot(emm_df, aes(x = Timepoint, y = estimate, group = contrast, color = contrast)) + geom_line(size = 1.1) + geom_point(size = 3) + geom_hline(yintercept = 0, linetype = "dashed") + labs(y = "Estimated mean difference (Stool consistency score)", x = "Visit (Time point)", title = "Pairwise comparison of stool consistency over time (GEE marginal means)", subtitle = "Negative estimate = Harder stool than reference group") + theme_minimal(base_size = 13) ggsave("GEE_margin_meandiff.png", p, width = 10, height = 5, dpi = 300) ###精确比较谁与母乳组更接近 library(dplyr) library(openxlsx) ## 1) 先把三个对比(1vs3, 6vs3, 1vs6)在每个时间点的OR算出来 ---- coef_names <- rownames(summary(fit_ordgee)$coef) tp_levels <- levels(df_long2$Timepoint) need_terms <- function(gr, tp, coef_names) { main <- paste0("group", gr) inter <- coef_names[grepl(paste0("^group", gr, ":"), coef_names) & grepl(as.character(tp), coef_names, fixed = TRUE)] c(main, inter) } lincom <- function(model, coefs) { coef_tab <- summary(model)$coef b <- coef_tab[, "Estimate"] se_vec <- coef_tab[, "san.se"] all_nm <- names(b) w <- rep(0, length(all_nm)); names(w) <- all_nm set_nm <- intersect(names(coefs), names(w)) w[set_nm] <- coefs[set_nm] est <- sum(w * b) V <- if (!is.null(model$robust.variance)) model$robust.variance else vcov(model) se <- sqrt(as.numeric(t(w) %*% V %*% w)) OR <- exp(est) CI_low <- exp(est - 1.96*se) CI_high <- exp(est + 1.96*se) z <- est / se p <- 2*pnorm(-abs(z)) c(beta=est, se=se, OR=OR, CI_low=CI_low, CI_high=CI_high, z=z, p_value=p) } # 统计每个时间点的样本量用于加权(丢失随访时更稳健) n_by_time <- df_long2 %>% group_by(Timepoint) %>% summarise(n_tp = n_distinct(id), .groups="drop") rows <- list() for (tp in tp_levels) { # 6 vs 3 t6 <- need_terms("6", tp, coef_names); v6 <- setNames(rep(1, length(t6)), t6) r63 <- lincom(fit_ordgee, v6) rows[[paste0("6_vs_3@", tp)]] <- c(group_comp="6 vs 3", time=tp, r63) # 1 vs 3 t1 <- need_terms("1", tp, coef_names); v1 <- setNames(rep(1, length(t1)), t1) r13 <- lincom(fit_ordgee, v1) rows[[paste0("1_vs_3@", tp)]] <- c(group_comp="1 vs 3", time=tp, r13) # 1 vs 6 = (1 vs 3) - (6 vs 3) v16 <- setNames(rep(0, length(coef_names)), coef_names) v16[t1] <- 1 v16[t6] <- -1 r16 <- lincom(fit_ordgee, v16) rows[[paste0("1_vs_6@", tp)]] <- c(group_comp="1 vs 6", time=tp, r16) } simple_OR <- bind_rows(lapply(rows, \(x) as.data.frame(t(x))), .id="contrast") %>% mutate(across(c(beta,se,OR,CI_low,CI_high,z,p_value), as.numeric)) %>% left_join(n_by_time, by = c("time" = "Timepoint")) # 2) 构造“与母乳距离”的指标:D3 = mean_w(|log OR_1vs3|), D6 = mean_w(|log OR_1vs6|) dist_OR <- simple_OR %>% filter(group_comp %in% c("1 vs 3","1 vs 6")) %>% mutate(abs_logOR = abs(beta)) %>% group_by(group_comp) %>% summarise( D = weighted.mean(abs_logOR, w = n_tp), .groups = "drop" ) %>% tidyr::pivot_wider(names_from = group_comp, values_from = D) %>% mutate(delta = `1 vs 3` - `1 vs 6`) # <0 表示 3 更接近母乳;>0 表示 6 更接近 dist_OR # 3) 按 id 聚类自助法(bootstrap)给 delta 置信区间 ---- set.seed(2025) B <- 100 # 可按需要加大 ids <- unique(df_long2$id) boot_delta <- replicate(B, { samp_ids <- sample(ids, replace = TRUE) dat_b <- df_long2 %>% semi_join(tibble(id=samp_ids), by="id") # 重新拟合 ordLORgee fit_b <- ordLORgee( StoolType ~ group * Timepoint, id = id, repeated = Timepoint, data = dat_b, link = "logit", LORstr = "uniform", add = 0.5 ) coef_names_b <- rownames(summary(fit_b)$coef) # 计算1vs3 & 1vs6在各时间点的 |log OR| tp_b <- levels(dat_b$Timepoint) rows_b <- list() for (tp in tp_b) { t6b <- need_terms("6", tp, coef_names_b); v6b <- setNames(rep(1, length(t6b)), t6b) t1b <- need_terms("1", tp, coef_names_b); v1b <- setNames(rep(1, length(t1b)), t1b) v16b <- setNames(rep(0, length(coef_names_b)), coef_names_b); v16b[t1b] <- 1; v16b[t6b] <- -1 r13b <- lincom(fit_b, v1b) r16b <- lincom(fit_b, v16b) rows_b[[tp]] <- data.frame(time=tp, abs_logOR_13 = abs(r13b["beta"]), abs_logOR_16 = abs(r16b["beta"])) } tab_b <- bind_rows(rows_b) # 该次抽样的时间点权重 wt_b <- dat_b %>% group_by(Timepoint) %>% summarise(n_tp = n_distinct(id), .groups="drop") tab_b <- tab_b %>% left_join(wt_b, by = c("time"="Timepoint")) D3_b <- weighted.mean(tab_b$abs_logOR_13, w = tab_b$n_tp) D6_b <- weighted.mean(tab_b$abs_logOR_16, w = tab_b$n_tp) D3_b - D6_b }) ci_OR <- quantile(boot_delta, probs = c(0.025, 0.975), na.rm = TRUE) list(delta_point = dist_OR$delta, delta_CI = ci_OR) # 解释:delta < 0 且 95%CI 不跨 0 => 组3 更接近母乳;若 >0 => 组6 更接近;若跨0 => 难分伯仲 # 可保存 write.xlsx(list( "OR_by_time" = simple_OR, "Distance_OR" = dist_OR, "Bootstrap_delta" = data.frame( delta_point = dist_OR$delta, CI_low = ci_OR[1], CI_high = ci_OR[2] ) ), "closeness_to_breastfed_OR.xlsx", rowNames = FALSE) library(ggplot2) df_delta <- data.frame( group = c("3 vs 1", "6 vs 1"), distance = c(dist_OR$`1 vs 3`, dist_OR$`1 vs 6`) ) ggplot(df_delta, aes(x = group, y = distance, fill = group)) + geom_bar(stat = "identity", width = 0.5) + geom_text(aes(label = round(distance, 2)), vjust = -0.5, size = 5) + geom_hline(yintercept = 0, linetype = "dashed") + labs(title = "Weighted Distance to Breastfed Group", subtitle = "Smaller = more similar to breastfeeding pattern", y = "Weighted |log(OR)|", x = "") + theme_minimal(base_size = 13) #哪一个柱子(配方组)更矮 → 更接近母乳。 #预测概率图绘制:不同组别在各时间点上,排便频率等级的分布或累积概率。 library(dplyr) library(tidyr) library(ggplot2) # ==== 1) 提取模型预测的“线性预测值” ==== # multgee 没有直接的 predict() 方法,但我们可以手动计算预测概率 # 思路:用模型系数、时间点、组别,计算每组每时间点的累积概率 # —— 取系数与阈值 —— coef_tab <- summary(fit_ordgee)$coef b <- coef_tab[, "Estimate"]; names(b) <- rownames(coef_tab) cut <- as.numeric(b[grepl("^beta", names(b))]) # θ1..θ{K-1} coefs <- b[!grepl("^beta", names(b))] # —— 拼每个 组×时间 的线性预测子 η —— pred_grid <- expand.grid( group = levels(df_long2$group), Timepoint = levels(df_long2$Timepoint) ) |> dplyr::rowwise() |> dplyr::mutate( lp = sum(c( 0, coefs[paste0("group", group)], coefs[paste0("Timepoint", Timepoint)], coefs[paste0("group", group, ":Timepoint", Timepoint)] ), na.rm = TRUE) ) |> dplyr::ungroup() # —— 把累积概率 Ck = P(Y<=k) 按公式 plogis(θk - η) 算出来 —— K <- length(cut) + 1L lev_labs <- levels(df_long2$StoolType) # 建议预先手动设定为从“低频”到“高频”的顺序 pred_prob <- pred_grid |> dplyr::rowwise() |> dplyr::do({ lp <- .$lp C <- plogis(cut + lp) # C1..C{K-1} p <- c(C[1], diff(C), 1 - tail(C, 1)) # 长度 K tibble::tibble( group = .$group, Timepoint = .$Timepoint, level = factor(lev_labs, levels = lev_labs, ordered = TRUE), prob = p ) }) |> dplyr::ungroup() # —— 画图(5个等级都会出现,方向由 level 的顺序决定) —— t <- ggplot(pred_prob, aes(x = Timepoint, y = prob, fill = level)) + geom_bar(stat = "identity", position = "stack") + facet_wrap(~ group, nrow = 1) + scale_y_continuous(labels = scales::percent_format(accuracy = 1)) + labs( title = "Predicted distributions of bowel movement frequency levels by feeding group and time", subtitle = "Based on the ordinal GEE (cumulative logit) model; levels ordered from low to high frequency", x = "Timepoint", y = "Predicted probability", fill = "Bowel frequency level" ) + theme_minimal(base_size = 13) ggsave("Predicted.png", t, width = 10, height = 5 , dpi = 300) 检查下这个代码做了什么 以及 对不对
最新发布
11-12
由于没有提供具体的R语言代码,下面给出一个使用`geepack`包进行广义估计方程(GEE)分析有序分类因变量的示例代码,并对代码功能进行解释,同时说明如何判断代码的正确性。 ### 示例代码 ```R # 加载所需的包 library(geepack) library(MASS) # 用于有序分类数据的处理 # 生成示例数据 set.seed(123) n <- 200 # 样本量 id <- rep(1:(n/2), each = 2) # 个体标识 x1 <- rnorm(n) x2 <- rbinom(n, 1, 0.5) # 生成有序分类因变量 y_prob <- plogis(0.5 + 0.8 * x1 + 1.2 * x2) y <- ordered(cut(y_prob, breaks = c(0, 0.3, 0.7, 1), labels = c("low", "medium", "high"))) data <- data.frame(id, x1, x2, y) # 模型拟合 model_gee <- geeglm(y ~ x1 + x2, data = data, id = id, family = "cumulative", corstr = "exchangeable") # 整体检验 anova(model_gee, test = "Wald") # 系数提取 coef(model_gee) # 两两比较(这里简单以x2不同水平下的效应比较为例) newdata1 <- data.frame(x1 = mean(x1), x2 = 0) newdata2 <- data.frame(x1 = mean(x1), x2 = 1) pred1 <- predict(model_gee, newdata1, type = "response") pred2 <- predict(model_gee, newdata2, type = "response") comparison <- pred2 - pred1 # 构造指标(以AIC为例) AIC(model_gee) # 自助法求置信区间 B <- 1000 # 自助抽样次数 coef_boot <- matrix(0, nrow = B, ncol = length(coef(model_gee))) for (i in 1:B) { boot_data <- data[sample(nrow(data), replace = TRUE), ] boot_model <- geeglm(y ~ x1 + x2, data = boot_data, id = id, family = "cumulative", corstr = "exchangeable") coef_boot[i, ] <- coef(boot_model) } ci_boot <- apply(coef_boot, 2, quantile, c(0.025, 0.975)) # 绘图(以系数的箱线图为例) boxplot(coef_boot, names = names(coef(model_gee)), main = "Bootstrap Coefficient Estimates") ``` ### 代码功能解释 1. **数据生成**:使用`set.seed`设置随机种子以保证结果可重复,生成个体标识`id`、自变量`x1`和`x2`,并根据线性预测生成有序分类因变量`y`。 2. **模型拟合**:使用`geeglm`函数拟合广义估计方程模型,指定因变量、自变量、个体标识、分布族(`cumulative`用于有序分类数据)和相关结构(`exchangeable`)。 3. **整体检验**:使用`anova`函数进行Wald检验,评估模型整体的显著性。 4. **系数提取**:使用`coef`函数提取模型的系数。 5. **两两比较**:通过`predict`函数预测不同自变量水平下的因变量概率,然后进行差值比较。 6. **构造指标**:使用`AIC`函数计算模型的赤池信息准则。 7. **自助法求置信区间**:通过多次有放回抽样,重新拟合模型并提取系数,最后计算系数的置信区间。 8. **绘图**:使用`boxplot`函数绘制自助法得到的系数估计的箱线图。 ### 判断代码正确性的方法 1. **语法检查**:RStudio等集成开发环境会自动检查代码的语法错误,确保代码没有拼写错误、括号不匹配等问题。 2. **模型收敛**:检查模型拟合的输出信息,确保模型收敛。如果模型不收敛,可能需要调整数据或模型参数。 3. **结果合理性**:检查系数估计值、检验统计量、置信区间等结果是否合理,例如系数的符号和大小是否符合预期,检验的p值是否在合理范围内。 4. **与理论结果对比**:对于一些简单的情况,可以与理论结果进行对比,确保代码实现的功能与理论一致。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值