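# Note: this script does not handle package-name aliases (e.g. uwsgi vs. uWSGI).
# When pruning unnecessary libraries from the result (e.g. tqdm or tensorboard,
# which only visualise progress or results), remove such duplicate aliases by
# hand as well. If you know a better approach, your suggestions are welcome.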
# 把当前目录及其子目录下所有的requirements.txt文件中的内容汇总到一个文件中
# 同名库则以版本最新的为准,如果没有指定版本则认为是最新版本
import os
def get_lib(path):
    """Yield requirement lines from every ``requirements.txt`` under *path*.

    The tree is traversed with ``os.walk``. A directory reported as exactly
    ``'.'`` is skipped, so when *path* is ``'.'`` the top-level
    ``requirements.txt`` is ignored and only files in subdirectories are read.
    Blank or whitespace-only lines and comment lines (``#`` as the first
    non-blank character) are dropped.

    Args:
        path: Directory to search recursively.

    Yields:
        Each remaining line exactly as read from the file (a trailing newline,
        when present, is kept).
    """
    for root, _dirs, files in os.walk(path):
        if root == '.' or 'requirements.txt' not in files:
            continue
        # pip treats requirements files as UTF-8 by default, so do not depend
        # on the platform's locale encoding.
        with open(os.path.join(root, 'requirements.txt'), 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                # Test the stripped text so that indented comments and lines
                # holding only spaces/tabs are skipped as well.
                if not stripped or stripped.startswith('#'):
                    continue
                yield line
# Version-specifier operators recognised when splitting a requirement line.
comparison_operators = ['==', '>=', '<=', '>', '<']


def _strip_inline_comment(line):
    """Return *line* without a trailing comment (a ``#`` preceded by whitespace)."""
    for idx, ch in enumerate(line):
        # A '#' glued to other text (e.g. a URL fragment '#egg=name') is kept.
        if ch == '#' and (idx == 0 or line[idx - 1].isspace()):
            return line[:idx]
    return line


def _split_requirement(requirement):
    """Split a requirement into ``(name, version, operator)``.

    The operator that appears first in the text is used (the longer one when
    two start at the same position, e.g. ``>=`` over ``>``), and the text is
    split only at that occurrence. When no operator is present the result is
    ``(name, 'new', '')``, i.e. the requirement is unpinned.
    """
    matches = [
        (requirement.find(op), op)
        for op in comparison_operators
        if op in requirement
    ]
    if not matches:
        return requirement.strip(), 'new', ''
    pos, op = min(matches, key=lambda match: (match[0], -len(match[1])))
    return requirement[:pos].strip(), requirement[pos + len(op):].strip(), op


def _version_key(version):
    """Return the leading numeric release of *version* as a comparable list.

    Parsing stops at the first dot-separated part that is not purely digits
    (``rc1`` suffixes, ``*``, extra specifiers after a comma, markers, ...),
    keeping any leading digits of that part. Trailing zeros are removed so
    that ``1.0`` and ``1`` compare equal, which gives the same ordering as
    zero-padding the shorter version.
    """
    numbers = []
    for part in version.split('.'):
        part = part.strip()
        digits = part[:len(part) - len(part.lstrip('0123456789'))]
        if not digits:
            break
        numbers.append(int(digits))
        if digits != part:
            break
    while numbers and numbers[-1] == 0:
        numbers.pop()
    return numbers


# {lib_name: [lib_version, operator]}; lib_version is 'new' (with operator '')
# for an unpinned requirement, which always wins over any pinned version.
hash_map = {}
for lib in get_lib('.'):
    lib_name, lib_version, op = _split_requirement(_strip_inline_comment(lib))
    if not lib_name:
        # Nothing left after removing whitespace/comments: not a requirement.
        continue
    known = hash_map.get(lib_name)
    if known is None:
        hash_map[lib_name] = [lib_version, op]
    elif known[0] == 'new':
        # Already unpinned; nothing can be newer.
        continue
    elif lib_version == 'new' or _version_key(lib_version) > _version_key(known[0]):
        # Keep the newest version; on a tie the first one seen is kept.
        hash_map[lib_name] = [lib_version, op]
# Write the merged requirements, one per line, to temp/requirements.txt.
# open(..., 'w') does not create missing parent directories, so make sure the
# output directory exists instead of failing with FileNotFoundError.
os.makedirs('temp', exist_ok=True)
with open('temp/requirements.txt', 'w', encoding='utf-8') as f:
    for lib, version_info in hash_map.items():
        if version_info[1]:
            # Pinned: name + operator + version.
            f.write(f'{lib}{version_info[1]}{version_info[0]}\n')
        else:
            # Unpinned: the bare name only.
            f.write(f'{lib}\n')