CLRS算法推理基准项目教程
1. 项目介绍
CLRS算法推理基准(CLRS Algorithmic Reasoning Benchmark)是一个新兴的机器学习领域,旨在将神经网络与经典算法相结合。该项目提供了一系列经典算法的实现,这些算法选自《算法导论》第三版(Introduction to Algorithms by Cormen, Leiserson, Rivest and Stein)。CLRS项目的目标是通过提供这些算法的实现,来评估算法推理的能力。
2. 项目快速启动
安装CLRS
你可以通过pip安装CLRS,可以从PyPI安装:
pip install dm-clrs
或者直接从GitHub安装(更新更频繁):
pip install git+https://github.com/google-deepmind/clrs.git
如果你希望避免依赖冲突,可以在虚拟环境中安装:
python3 -m venv clrs_env
source clrs_env/bin/activate
pip install git+https://github.com/google-deepmind/clrs.git
运行示例模型
安装完成后,你可以运行示例基准模型:
python3 -m clrs.examples.run
首次运行时,数据集将被下载并存储在--dataset_path(默认路径为/tmp/CLRS30)。
3. 应用案例和最佳实践
使用图神经网络(GNN)进行训练
CLRS项目提供了一个完整的图神经网络(GNN)示例,使用JAX和DeepMind的JAX生态库。该示例允许在单个处理器上训练多个算法,如“A Generalist Neural Algorithmic Learner”中所述。
自定义处理器
如果你希望实现一个新的处理器,可以在processors.py文件中添加,并通过get_processor_factory方法使其可用。处理器的__call__方法应如下所示:
__call__(self, node_fts, edge_fts, graph_fts, adj_mat, hidden, nb_nodes, batch_size)
其中,node_fts、edge_fts和graph_fts是形状为batch_size x nb_nodes x H、batch_size x nb_nodes x nb_nodes x H和batch_size x H的浮点数组,分别表示节点、边和图的编码特征。adj_mat是一个布尔数组,表示连接性,hidden是前一步处理器的输出。
4. 典型生态项目
相关算法
CLRS-30基准包括以下30个算法:
- 排序:插入排序、冒泡排序、堆排序、快速排序
- 搜索:最小值、二分搜索、快速选择
- 分治:最大子数组(Kadane变体)
- 贪心:活动选择、任务调度
- 动态规划:矩阵链乘法、最长公共子序列、最优二叉搜索树
- 图:深度优先搜索、广度优先搜索、拓扑排序、割点、桥、Kosaraju强连通分量算法、Kruskal最小生成树算法、Prim最小生成树算法、Bellman-Ford单源最短路径算法、Dijkstra单源最短路径算法、有向无环图单源最短路径、Floyd-Warshall全对最短路径
- 字符串:朴素字符串匹配、Knuth-Morris-Pratt(KMP)字符串匹配器
- 几何:线段相交、Graham扫描凸包算法、Jarvis步进凸包算法
基线模型
CLRS项目提供了以下GNN基线处理器的JAX实现:
- Deep Sets
- End-to-End Memory Networks
- Graph Attention Networks
- Graph Attention Networks v2
- Message-Passing Neural Networks
- Pointer Graph Networks
这些基线模型可以帮助你快速上手并理解如何在CLRS基准上进行算法推理。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



