Development Plan for a TikTok-Style Short-Video Recommendation System
System Architecture Design
graph TD
A[Data Sources] --> B[Data Collection]
B --> C[Data Storage]
C --> D[Offline Processing]
C --> E[Real-Time Processing]
D --> F[Feature Engineering]
E --> F
F --> G[Retrieval Layer]
F --> H[Pre-Ranking Layer]
F --> I[Ranking Layer]
G --> J[Candidate Generation]
H --> K[Preliminary Ranking]
I --> L[Final Ranking]
L --> M[Re-Ranking]
M --> N[Recommendation Service]
N --> O[API Gateway]
O --> P[Clients]
subgraph DataSources[Data Sources]
A1[User behavior logs]
A2[Video content metadata]
A3[Social graph data]
A4[External trending data]
end
subgraph DataCollection[Data Collection]
B1[Kafka] -->|real-time streams| E
B2[Flume] -->|batch logs| D
B3[Sqoop] -->|relational data| D
end
subgraph DataStorage[Data Storage]
C1[HDFS] -->|offline data| D
C2[HBase] -->|real-time features| E
C3[Redis] -->|cached features| N
C4[Neo4j] -->|social graph| A3
end
subgraph OfflineProcessing[Offline Processing]
D1[Spark] -->|feature computation| F
D2[Hive] -->|offline model training| MS[Model Serving]
end
subgraph RealTimeProcessing[Real-Time Processing]
E1[Flink] -->|real-time feature updates| F
E2[Storm] -->|real-time behavior analysis| J
end
subgraph RecommendationService[Recommendation Service]
N1[Spring Cloud] -->|API serving| O
N2[gRPC] -->|high-performance calls| P
end
Core Module Implementation
1. Data Collection and Storage
Multi-Source Data Collection
// User behavior log producer (Kafka)
public class UserBehaviorProducer {
    private static final Logger logger = LoggerFactory.getLogger(UserBehaviorProducer.class);
    private static final String BOOTSTRAP_SERVERS = "kafka1:9092,kafka2:9092";
    private static final String TOPIC = "user_behavior";
    private static final ObjectMapper MAPPER = new ObjectMapper();

    // Reuse one producer instance; creating a KafkaProducer per event is expensive
    private final Producer<String, String> producer;

    public UserBehaviorProducer() {
        Properties props = new Properties();
        props.put("bootstrap.servers", BOOTSTRAP_SERVERS);
        props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer");
        props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer");
        producer = new KafkaProducer<>(props);
    }

    public void sendBehaviorEvent(UserBehaviorEvent event) {
        try {
            String eventJson = MAPPER.writeValueAsString(event);
            // Key by userId so one user's events land on the same partition (preserves ordering)
            producer.send(new ProducerRecord<>(TOPIC, event.getUserId(), eventJson));
        } catch (JsonProcessingException e) {
            logger.error("Failed to serialize event", e);
        }
    }
}
// Video content metadata collection (custom Flume source)
public class VideoMetadataSource extends AbstractSource implements Configurable, PollableSource {
    private String hdfsPath;
    private FileSystem fs;

    @Override
    public void configure(Context context) {
        hdfsPath = context.getString("hdfs.path");
        try {
            fs = FileSystem.get(new Configuration());
        } catch (IOException e) {
            throw new FlumeException("Failed to open HDFS filesystem", e);
        }
    }

    @Override
    public Status process() throws EventDeliveryException {
        try {
            // Pull video metadata files (JSON) from HDFS
            FileStatus[] statuses = fs.listStatus(new Path(hdfsPath));
            for (FileStatus status : statuses) {
                try (FSDataInputStream in = fs.open(status.getPath())) {
                    byte[] metadata = new byte[(int) status.getLen()];
                    in.readFully(metadata);
                    // Hand the event to the channel (a downstream sink forwards it to Kafka)
                    getChannelProcessor().processEvent(EventBuilder.withBody(metadata));
                }
            }
            return Status.READY;
        } catch (IOException e) {
            throw new EventDeliveryException("Failed to read metadata from HDFS", e);
        }
    }
}
Data Storage Design
Data Type | Storage | Purpose | Examples |
---|---|---|---|
User behavior logs | HDFS + Kafka | Offline training + real-time feature updates | View history, likes, comments, shares |
Video metadata | HBase + Hive | Content retrieval + feature extraction | Tags, duration, author, music, resolution |
Social relations | Neo4j + Redis | Social recommendations + follower relationship chains | Following lists, follower lists, mutual friends |
Real-time features | Redis + HBase | Online feature lookup at serving time | User interest vectors, video popularity, recent viewing preferences |
Model parameters | HDFS + TensorFlow Serving | Model deployment + online inference | DeepFM weights, DIN embedding vectors |
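The "real-time features" row implies storing dense vectors as Redis byte strings. One common, compact encoding is fixed-width little-endian float32, sketched below with only the standard library (the key scheme and vector values are illustrative, not part of the design above):

```python
import struct

def pack_vector(vec):
    # Serialize a feature vector as little-endian float32 bytes (4 bytes per dim)
    return struct.pack(f"<{len(vec)}f", *vec)

def unpack_vector(buf):
    return list(struct.unpack(f"<{len(buf) // 4}f", buf))

# Round-trip a user interest vector (values are placeholders); the resulting
# blob would be stored under a key such as user:123:interest
vec = [0.5, -1.0, 2.0, 0.25]
blob = pack_vector(vec)
assert unpack_vector(blob) == vec
```

Float32 halves the footprint of float64 at a precision loss that is usually irrelevant for embedding lookups.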
2. Feature Engineering
User Feature Extraction
// Compute user interest vectors with Spark (collaborative filtering)
val userBehavior = spark.read.parquet("hdfs:///data/user_behavior")
  .select("user_id", "item_id", "behavior_type", "timestamp")
// Build the user-item interaction matrix
val userItemMatrix = userBehavior
  .groupBy("user_id", "item_id")
  .agg(count("*").as("interaction_count"))
// Factorize with ALS to obtain latent vectors
// (note: ALS requires numeric ids, so string user/item ids must be indexed to integers first)
val als = new ALS()
  .setUserCol("user_id")
  .setItemCol("item_id")
  .setRatingCol("interaction_count")
  .setRank(64)
  .setMaxIter(10)
  .setRegParam(0.01)
val model = als.fit(userItemMatrix)
val userFactors = model.userFactors // 64-dimensional user latent vectors
Video Feature Extraction
# Extract video frame features with ResNet (PyTorch)
import json

import happybase
import numpy as np
import torch
from torchvision import models, transforms

# Load a pretrained model and drop the final fully connected layer
model = models.resnet50(pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))
model.eval()

# Frame preprocessing
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Extract features from a video's key frames
def extract_video_features(video_path):
    frames = extract_frames(video_path)  # key-frame extraction, e.g. with OpenCV
    features = []
    for frame in frames:
        tensor = transform(frame).unsqueeze(0)
        with torch.no_grad():
            feature = model(tensor).squeeze().numpy()
        features.append(feature)
    return np.mean(features, axis=0)  # average pooling over frames

# Store video features in HBase
def save_video_features(video_id, features):
    connection = happybase.Connection('hbase-host')
    table = connection.table('video_features')
    table.put(video_id, {
        b'cf:features': features.astype(np.float32).tobytes(),
        b'cf:metadata': json.dumps({'duration': 60, 'author_id': 'author123'}).encode()
    })
Contextual Features
// Real-time contextual feature extraction (Flink)
public class ContextFeatureExtractor
        extends KeyedProcessFunction<String, UserBehaviorEvent, ContextFeatures> {
    private transient ValueState<ContextState> state;

    @Override
    public void open(Configuration parameters) {
        ValueStateDescriptor<ContextState> descriptor =
            new ValueStateDescriptor<>("context-state", ContextState.class);
        state = getRuntimeContext().getState(descriptor);
    }

    @Override
    public void processElement(
            UserBehaviorEvent event,
            Context ctx,
            Collector<ContextFeatures> out) throws Exception {
        ContextState currentState = state.value();
        if (currentState == null) {
            currentState = new ContextState();
            currentState.setUserId(event.getUserId());
            currentState.setTimeWindow(new TimeWindow(ctx.timestamp(), ctx.timestamp() + 60000)); // 1-minute window
        }
        // Update the context state with the new event
        currentState.update(event);
        state.update(currentState);
        // Emit contextual features (epoch-millis arithmetic yields UTC buckets;
        // convert to the user's local timezone in production)
        ContextFeatures features = new ContextFeatures();
        features.setHourOfDay(ctx.timestamp() / 3600000 % 24);
        features.setDayOfWeek(ctx.timestamp() / 86400000 % 7); // 0 = Thursday (the epoch day)
        features.setRecentViewedItems(currentState.getRecentItems());
        features.setDeviceType(event.getDeviceType());
        out.collect(features);
    }
}
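The epoch-millisecond arithmetic in processElement is worth a sanity check: it buckets in UTC, and day-of-week 0 corresponds to Thursday because 1970-01-01 was a Thursday. A short Python check with an arbitrary date:

```python
from datetime import datetime, timezone

# 2024-07-15 09:30 UTC was a Monday
ts = int(datetime(2024, 7, 15, 9, 30, tzinfo=timezone.utc).timestamp() * 1000)
hour_of_day = ts // 3_600_000 % 24   # 9 (UTC hour)
day_of_week = ts // 86_400_000 % 7   # 4, since 0 = Thursday and Monday is 4 days later
```

If the model should learn local-time patterns (bedtime vs. commute), the timestamp must be shifted by the user's timezone offset before bucketing.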
3. Multi-Stage Recommendation Algorithms
Retrieval Layer
Goal: quickly narrow millions of videos down to a candidate set on the order of thousands
Methods:
- Collaborative-filtering recall: items similar to the user's interaction history (ItemCF)
- Content-matching recall: cosine similarity between the user's interest vector and video feature vectors
- Hash-index recall: hash-bucket video tags and author IDs for fast candidate lookup
// ItemCF recall (Spark)
val userItemInteractions = spark.read.parquet("hdfs:///data/user_behavior")
  .filter("behavior_type IN ('view', 'like', 'purchase')")
  .groupBy("user_id", "item_id")
  .agg(count("*").as("weight"))

// Build the item co-occurrence (similarity) table via a self-join on user_id:
// two items are similar when the same users interacted with both
val itemSimilarity = userItemInteractions.as("a")
  .join(userItemInteractions.as("b"), Seq("user_id"))
  .filter(col("a.item_id") =!= col("b.item_id"))
  .groupBy(col("a.item_id").as("item_id"), col("b.item_id").as("cooccurrence_item_id"))
  .agg(sum(col("a.weight") * col("b.weight")).as("weight"))
  .cache()

// Generate a candidate set for one user: sum similarity weights over the user's history
def getRecallItems(userId: String, topN: Int = 1000): Array[String] = {
  userItemInteractions
    .filter(col("user_id") === userId)
    .join(itemSimilarity, "item_id")
    .groupBy("cooccurrence_item_id")
    .agg(sum("weight").as("total_weight"))
    .orderBy(desc("total_weight"))
    .limit(topN)
    .select("cooccurrence_item_id")
    .as[String]
    .collect()
}
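The content-matching recall listed above is just a nearest-neighbor search under cosine similarity. A minimal pure-Python sketch of the idea (exhaustive scan; a production system would use an ANN index such as Faiss, and the vectors and IDs here are made up):

```python
import math

def cosine(a, b):
    # Cosine similarity between two equal-length vectors
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0

def content_recall(user_vector, video_vectors, top_n=1000):
    # video_vectors: {video_id: feature_vector}
    scored = [(vid, cosine(user_vector, vec)) for vid, vec in video_vectors.items()]
    scored.sort(key=lambda x: -x[1])
    return [vid for vid, _ in scored[:top_n]]

videos = {"v1": [1.0, 0.0], "v2": [0.0, 1.0], "v3": [1.0, 1.0]}
top2 = content_recall([1.0, 0.2], videos, top_n=2)  # ["v1", "v3"]
```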
Pre-Ranking Layer
Goal: roughly order the candidate set and filter out low-quality videos
Method: lightweight models (e.g. LR, LightGBM)
Features: user interest vector, video popularity (play/like counts), content tag match, author follower count
# LightGBM pre-ranking model (Python; assumes an active SparkSession `spark`)
import lightgbm as lgb
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

# Feature engineering
def build_features(user_features, video_features, context_features):
    features = {
        "user_age": user_features["age"],
        "user_gender": user_features["gender"],
        "video_play_count": video_features["play_count"],
        "video_like_count": video_features["like_count"],
        "tag_match": float(cosine_similarity(
            [user_features["interest_vector"]], [video_features["tag_vector"]])[0, 0]),
        "author_fans": video_features["author_fans"],
        "hour_of_day": context_features["hour_of_day"],
    }
    return pd.DataFrame([features])

# Model training (samples are joined offline; pull them to the driver as pandas)
train_data = spark.read.parquet("hdfs:///data/train_data").toPandas()
X_train = pd.concat(
    [build_features(r["user_features"], r["video_features"], r["context_features"])
     for _, r in train_data.iterrows()],
    ignore_index=True)
y_train = train_data["label"]  # 0/1: clicked or not
model = lgb.LGBMClassifier(
    num_leaves=31,
    learning_rate=0.05,
    n_estimators=100,
    metric="auc")
model.fit(X_train, y_train)

# Pre-ranking prediction
def pre_rank(user_id, candidates):
    user_features = get_user_features(user_id)
    scores = []
    for video_id in candidates:
        video_features = get_video_features(video_id)
        context_features = get_context_features()
        X = build_features(user_features, video_features, context_features)
        score = model.predict_proba(X)[0][1]
        scores.append((video_id, score))
    return sorted(scores, key=lambda x: -x[1])[:200]  # keep the top 200
Ranking Layer
Goal: precise ordering that optimizes CTR/CVR
Method: deep models (e.g. DeepFM, DIN, Transformer)
Features: long-term user interests (embeddings), short-term behavior sequences (RNN/Transformer), multimodal video features (visual + text)
# DeepFM ranking model (TensorFlow)
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Dense, Concatenate, Flatten, Layer

# Keras has no built-in FM layer; this minimal one computes the second-order term
# 0.5 * ((sum_i v_i)^2 - sum_i v_i^2) over stacked field embeddings (batch, fields, k)
class FM(Layer):
    def call(self, inputs):
        summed = tf.reduce_sum(inputs, axis=1)                  # (batch, k)
        squared_sum = tf.square(summed)                         # (batch, k)
        sum_squared = tf.reduce_sum(tf.square(inputs), axis=1)  # (batch, k)
        return 0.5 * tf.reduce_sum(squared_sum - sum_squared, axis=1, keepdims=True)

# Feature config (vocab sizes must be ints; the FM part needs one shared embedding dim)
EMB_DIM = 16
user_features = {
    "user_id": {"vocab_size": 1_000_000},
    "gender": {"vocab_size": 3},
    "age_group": {"vocab_size": 10},
}
video_features = {
    "video_id": {"vocab_size": 500_000},
    "category": {"vocab_size": 100},
    "author_id": {"vocab_size": 100_000},
}

# Build the DeepFM model: every field feeds both the FM and the DNN tower
def build_deepfm_model():
    configs = {**user_features, **video_features}
    # Sparse-feature inputs, one per field
    inputs = {name: Input(shape=(1,), dtype="int32", name=name) for name in configs}
    # One embedding per field, each producing (batch, 1, k)
    emb = {name: Embedding(cfg["vocab_size"], EMB_DIM)(inputs[name])
           for name, cfg in configs.items()}
    # FM part: stack the field embeddings to (batch, fields, k)
    fm_input = Concatenate(axis=1)(list(emb.values()))
    fm_output = FM()(fm_input)  # (batch, 1)
    # DNN part: flattened concatenation of the same embeddings
    dnn_input = Concatenate(axis=1)([Flatten()(e) for e in emb.values()])
    dnn_output = Dense(128, activation="relu")(dnn_input)
    dnn_output = Dense(64, activation="relu")(dnn_output)
    # Output layer: combine the FM and DNN signals
    output = Dense(1, activation="sigmoid")(Concatenate()([fm_output, dnn_output]))
    model = tf.keras.Model(inputs=list(inputs.values()), outputs=output)
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=[tf.keras.metrics.AUC(name="auc")])
    return model
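The FM layer above leans on the standard identity sum over pairs ⟨v_i, v_j⟩ = ½ Σ_d [(Σ_i v_{i,d})² − Σ_i v_{i,d}²], which reduces the pairwise interaction from O(fields² · k) to O(fields · k). A pure-Python check on toy field embeddings:

```python
from itertools import combinations

def fm_second_order(embs):
    # embs: list of field embedding vectors, all of dimension k
    k = len(embs[0])
    sum_vec = [sum(e[d] for e in embs) for d in range(k)]
    sum_sq = [sum(e[d] ** 2 for e in embs) for d in range(k)]
    return 0.5 * sum(sum_vec[d] ** 2 - sum_sq[d] for d in range(k))

def pairwise_dots(embs):
    # Brute-force sum of dot products over all field pairs
    return sum(sum(x * y for x, y in zip(a, b)) for a, b in combinations(embs, 2))

embs = [[1.0, 2.0], [3.0, 4.0], [0.5, -1.0]]
assert fm_second_order(embs) == pairwise_dots(embs)
```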
Re-Ranking Layer
Goal: optimize diversity, coverage, and user experience
Methods:
- Diversity control: deduplicate by category/tag (e.g. at most 2 videos per author)
- Novelty boosting: surface long-tail videos (under 10k plays but with high quality scores)
- Scenario adaptation: adjust the strategy for usage contexts such as bedtime or commuting
// Re-ranking strategy (Java)
public class ReRanker {
    private static final int MAX_CATEGORY = 5; // keep at most 5 distinct categories in the first pass
    private static final int MAX_AUTHOR = 2;   // at most 2 videos per author
    private static final int MAX_RESULT = 30;

    public List<Video> reRank(List<Video> candidates) {
        // 1. Sort by the ranking-layer score
        candidates.sort((a, b) -> Double.compare(b.getScore(), a.getScore()));
        // 2. Diversity control: keep the top-scored video of each category
        Map<String, Video> categoryMap = new LinkedHashMap<>();
        for (Video video : candidates) {
            String category = video.getCategory();
            if (!categoryMap.containsKey(category) && categoryMap.size() < MAX_CATEGORY) {
                categoryMap.put(category, video);
            }
        }
        // 3. Novelty: mix in long-tail videos (low play counts within the candidate set)
        List<Video> coldVideos = getColdVideos(candidates);
        List<Video> result = new ArrayList<>(categoryMap.values());
        result.addAll(coldVideos.stream().limit(10).collect(Collectors.toList()));
        // 4. Cap the number of videos per author
        Map<String, Integer> authorCount = new HashMap<>();
        List<Video> finalResult = new ArrayList<>();
        for (Video video : result) {
            int count = authorCount.getOrDefault(video.getAuthorId(), 0);
            if (count < MAX_AUTHOR && finalResult.size() < MAX_RESULT) {
                authorCount.put(video.getAuthorId(), count + 1);
                finalResult.add(video);
            }
        }
        return finalResult;
    }
}
4. Real-Time Recommendation Service
Online Inference Service
// Recommendation API (Spring Boot)
@RestController
@RequestMapping("/recommend")
public class RecommendationController {
    @Autowired
    private FeatureStore featureStore;
    @Autowired
    private RecallService recallService;
    @Autowired
    private RankingService rankingService;
    @Autowired
    private ReRanker reRanker;

    @GetMapping("/{userId}")
    public ResponseEntity<List<Video>> getRecommendations(
            @PathVariable String userId,
            @RequestParam(defaultValue = "30") int count) {
        // 1. Fetch real-time features
        UserFeatures userFeatures = featureStore.getUserFeatures(userId);
        ContextFeatures context = featureStore.getContextFeatures(userId);
        // 2. Retrieve the candidate set
        List<String> candidateIds = recallService.getRecallItems(userId, 1000);
        // 3. Pre-rank + rank
        List<Video> recommendations = rankingService.rank(
            userFeatures, candidateIds, context, count);
        // 4. Re-rank
        recommendations = reRanker.reRank(recommendations);
        return ResponseEntity.ok(recommendations);
    }
}
Performance Optimization
// Multi-level caching (Caffeine local cache + Redis)
public class VideoCache {
    private static final Cache<String, Video> localCache = Caffeine.newBuilder()
        .maximumSize(1000)
        .expireAfterWrite(5, TimeUnit.MINUTES)
        .build();
    private static final RedisCacheClient redisCache = new RedisCacheClient("redis-host", 6379);

    public Video getVideo(String videoId) {
        // 1. Local cache
        Video video = localCache.getIfPresent(videoId);
        if (video != null) {
            return video;
        }
        // 2. Redis cache
        video = redisCache.get(videoId, Video.class);
        if (video != null) {
            localCache.put(videoId, video);
            return video;
        }
        // 3. Database lookup, then backfill both cache tiers
        video = videoService.getVideoFromDB(videoId);
        if (video != null) {
            redisCache.put(videoId, video, 30, TimeUnit.MINUTES); // cache for 30 minutes
            localCache.put(videoId, video);
        }
        return video;
    }
}
5. Cold-Start Solutions
New-User Recommendations
public class NewUserRecommender {
// Initial recommendations based on demographics
public List<Video> getDemographicRecommendations(User user) {
String sql = "SELECT video_id, score FROM demographic_recommendations " +
"WHERE age_group = ? AND gender = ? AND city = ? " +
"ORDER BY score DESC LIMIT 30";
return jdbcTemplate.query(sql,
new Object[]{user.getAgeGroup(), user.getGender(), user.getLocation()},
new VideoRowMapper());
}
// Recommendations based on trending content
public List<Video> getHotRecommendations() {
String sql = "SELECT video_id, play_count FROM videos " +
"ORDER BY play_count DESC LIMIT 30";
return jdbcTemplate.query(sql, new VideoRowMapper());
}
}
New-Video Recommendations
public class NewVideoRecommender {
// Recommendations based on content similarity
public List<Video> getContentSimilarRecommendations(Video newVideo) {
// Similarities are precomputed from video feature vectors
String sql = "SELECT similar_video_id, similarity FROM video_similarity " +
"WHERE video_id = ? " +
"ORDER BY similarity DESC LIMIT 30";
return jdbcTemplate.query(sql,
new Object[]{newVideo.getVideoId()},
(rs, rowNum) -> {
Video video = new Video();
video.setVideoId(rs.getString("similar_video_id"));
video.setScore(rs.getDouble("similarity"));
return video;
});
}
// Recommendations based on the author's other videos
public List<Video> getAuthorBasedRecommendations(Video newVideo) {
String authorId = newVideo.getAuthorId();
String sql = "SELECT video_id FROM videos " +
"WHERE author_id = ? AND video_id != ? " +
"ORDER BY publish_time DESC LIMIT 30";
return jdbcTemplate.query(sql,
new Object[]{authorId, newVideo.getVideoId()},
new VideoRowMapper());
}
}
6. Monitoring and Optimization
Key Monitoring Metrics
Metric Type | Metric | Target / Threshold |
---|---|---|
Performance | Recommendation latency (ms) | < 200 ms |
Performance | QPS (requests per second) | > 100,000 |
Engagement | CTR (click-through rate) | target: 8%-12% |
Engagement | Completion rate (share of videos watched to the end) | target: 30%-40% |
Engagement | Interaction rate (likes + comments + shares) | target: 5%-8% |
Quality | Diversity (distinct categories covered) | > 10 |
Quality | Share of long-tail videos | 10%-20% |
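The thresholds above are straightforward to wire into an automated health check. A minimal sketch (the counter names and values are invented; real counters would come from the logging pipeline):

```python
def recommendation_metrics(events):
    # events: dict of raw counters aggregated over a monitoring window
    return {
        "ctr": events["clicks"] / events["impressions"],
        "completion_rate": events["completed_plays"] / events["plays"],
        "interaction_rate": (events["likes"] + events["comments"] + events["shares"])
                            / events["plays"],
    }

events = {"impressions": 10_000, "clicks": 950, "plays": 900,
          "completed_plays": 320, "likes": 40, "comments": 10, "shares": 5}
m = recommendation_metrics(events)
# CTR 9.5% sits inside the 8%-12% target band above
```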
A/B Testing Framework
-- Experiment evaluation (CTR comparison)
SELECT
experiment_group,
COUNT(DISTINCT user_id) AS users,
SUM(CASE WHEN event_type = 'click' THEN 1 ELSE 0 END) AS clicks,
SUM(CASE WHEN event_type = 'impression' THEN 1 ELSE 0 END) AS impressions,
ROUND(SUM(CASE WHEN event_type = 'click' THEN 1 ELSE 0 END) * 1.0 /
SUM(CASE WHEN event_type = 'impression' THEN 1 ELSE 0 END), 4) AS ctr
FROM recommendation_events
WHERE experiment_id = 'exp_20240715'
AND event_time BETWEEN '2024-07-15' AND '2024-07-22'
GROUP BY experiment_group;
-- Conversion funnel analysis
WITH funnel AS (
SELECT
user_id,
MAX(CASE WHEN event_type = 'impression' THEN 1 ELSE 0 END) AS impression,
MAX(CASE WHEN event_type = 'click' THEN 1 ELSE 0 END) AS click,
MAX(CASE WHEN event_type = 'purchase' THEN 1 ELSE 0 END) AS purchase
FROM recommendation_events
WHERE experiment_id = 'exp_20240715'
AND event_time BETWEEN '2024-07-15' AND '2024-07-22'
GROUP BY user_id
)
SELECT
COUNT(*) AS total_users,
SUM(impression) AS impressions,
SUM(click) AS clicks,
SUM(purchase) AS purchases,
ROUND(SUM(click)*1.0/SUM(impression), 4) AS ctr,
ROUND(SUM(purchase)*1.0/SUM(click), 4) AS cvr
FROM funnel;
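The CTR comparison above reports a point difference, but shipping an experiment also requires a significance check. A two-proportion z-test needs only the click and impression counts from the first query (stdlib-only sketch; the counts below are invented for illustration):

```python
import math

def two_proportion_ztest(clicks_a, imps_a, clicks_b, imps_b):
    # z-statistic and two-sided p-value for H0: CTR_a == CTR_b
    p_a, p_b = clicks_a / imps_a, clicks_b / imps_b
    pooled = (clicks_a + clicks_b) / (imps_a + imps_b)
    se = math.sqrt(pooled * (1 - pooled) * (1 / imps_a + 1 / imps_b))
    z = (p_a - p_b) / se
    # Normal CDF via erf
    p_value = 2 * (1 - 0.5 * (1 + math.erf(abs(z) / math.sqrt(2))))
    return z, p_value

# Treatment: 10.0% CTR, control: 8.0% CTR, 1000 impressions each
z, p = two_proportion_ztest(100, 1000, 80, 1000)
```

At these (small) sample sizes the difference is not significant at the 5% level, which is exactly the kind of conclusion the raw CTR table cannot give.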
Summary
The core of a TikTok-style short-video recommendation system lies in real-time processing of massive data, multi-stage precision ranking, and personalized experience optimization. A layered architecture (data layer → algorithm layer → serving layer), combined with algorithms such as collaborative filtering and deep learning plus real-time feature engineering and caching strategies, can support low-latency recommendations for hundreds of millions of users. Cold-start strategies and the A/B testing loop keep the system improving continuously.
Strengths of this design:
- Low latency: millisecond-level responses that meet users' immediate needs
- Strong personalization: multi-dimensional recommendations based on behavior, interests, and context
- High diversity: re-ranking strategies counteract filter bubbles
- Scalability: a distributed architecture that grows with the business
It applies to short-video platforms, live-stream recommendations, content communities, and similar scenarios, offering companies a data-driven user-growth solution.