请你在原有的基础上进行刚才的修改:import pandas as pd
import numpy as np
import itertools
from collections import Counter
import seaborn as sns
import matplotlib.pyplot as plt
import os
import sys
def process_custom_gff(file_path):
"""处理自定义格式的GFF文件"""
# 检查文件是否存在
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
print(f"📖 正在读取GFF文件: {file_path}")
# 读取自定义格式的GFF文件
try:
df = pd.read_csv(
file_path,
sep="\t",
header=None,
names=["chr", "start", "end", "strand", "TFs"]
)
except Exception as e:
raise ValueError(f"读取文件失败: {str(e)}")
# 处理TF列表:拆分、去重、移除可能的空白符
def clean_tf_list(tf_str):
"""清洗并处理TF字符串"""
if not isinstance(tf_str, str):
return []
tfs = [tf.strip() for tf in tf_str.split(",")]
# 移除链信息中的逗号(如"+,-")
tfs = [tf for tf in tfs if tf not in ['+', '-', ',']]
return list(set(tfs))
print("🧹 清洗和预处理TF数据...")
df["TF_list"] = df["TFs"].apply(clean_tf_list)
return df
def generate_co_matrix(df):
"""生成转录因子共现矩阵"""
print("🧮 正在计算转录因子共现次数...")
# 统计TF对的共现次数(仅无序对)
pair_counter = Counter()
for i, tf_list in enumerate(df["TF_list"]):
# 每处理1000行打印一次进度
if (i + 1) % 1000 == 0:
print(f" 已处理 {i+1}/{len(df)} 行...")
# 只处理有2个以上TF的条目
if len(tf_list) >= 2:
# 生成所有组合(无序对)
for tf1, tf2 in itertools.combinations(sorted(tf_list), 2):
# 使用排序确保(a, b)和(b, a)被视为相同
ordered_pair = tuple(sorted([tf1, tf2]))
pair_counter[ordered_pair] += 1
# 获取所有唯一的TF
print("🔄 收集唯一转录因子...")
all_tfs = sorted(set(tf for sublist in df["TF_list"] for tf in sublist))
print(f"🔢 创建 {len(all_tfs)} x {len(all_tfs)} 共现矩阵...")
# 创建全零矩阵
co_matrix = pd.DataFrame(
np.zeros((len(all_tfs), len(all_tfs)), dtype=int),
index=all_tfs,
columns=all_tfs
)
# 填充矩阵
print("🖊️ 填充共现矩阵...")
for (tf1, tf2), count in pair_counter.items():
co_matrix.loc[tf1, tf2] = count
co_matrix.loc[tf2, tf1] = count # 使矩阵对称
return co_matrix
def visualize_matrix(co_matrix, output_file):
"""可视化共现矩阵并保存图片"""
print(f"🎨 正在生成热力图并保存到: {output_file}")
plt.figure(figsize=(15, 13))
# 使用掩码仅隐藏对角线(自身共现)
mask = np.zeros_like(co_matrix, dtype=bool)
np.fill_diagonal(mask, True)
# 创建热力图 - 显示所有TF
sns.heatmap(
co_matrix,
mask=mask, # 只隐藏对角线
annot=True,
fmt="d",
cmap="viridis",
linewidths=0.5,
cbar_kws={"label": "Co-occurrence Count"},
annot_kws={"size": 8}
)
plt.title("Transcription Factor Co-occurrence Matrix", fontsize=16)
# 调整标签大小和旋转角度
plt.xticks(rotation=45, ha="right", fontsize=9)
plt.yticks(fontsize=9)
# 添加标签以减少重叠的技巧
if len(co_matrix) > 30:
# 每隔一个标签显示一次
plt.xticks(range(len(co_matrix))[::2], co_matrix.columns[::2])
plt.yticks(range(len(co_matrix))[::2], co_matrix.index[::2])
plt.tight_layout()
plt.savefig(output_file, dpi=300, bbox_inches="tight")
plt.close()
print("✅ 热力图保存成功!")
def calculate_tf_frequency(df):
"""计算每个TF的出现频率"""
print("📊 正在计算转录因子频率...")
tf_freq = Counter()
for i, tf_list in enumerate(df["TF_list"]):
# 每处理1000行打印一次进度
if (i + 1) % 1000 == 0:
print(f" 已处理 {i+1}/{len(df)} 行...")
for tf in tf_list:
tf_freq[tf] += 1
freq_df = pd.DataFrame.from_dict(tf_freq, orient="index", columns=["frequency"])
freq_df = freq_df.sort_values("frequency", ascending=False)
return freq_df
def generate_co_occurrence_network(co_matrix, min_count=1, output_file="tf_network.gexf"):
"""将共现矩阵转换为网络图格式(GEXF)"""
print(f"🌐 正在生成共现网络图 (min_count={min_count})...")
try:
import networkx as nx
except ImportError:
print("❌ 需要安装networkx库: pip install networkx")
return
G = nx.Graph()
# 添加节点
for tf in co_matrix.index:
G.add_node(tf)
# 添加边(只添加满足最小共现次数的边)
for i, tf1 in enumerate(co_matrix.index):
for j, tf2 in enumerate(co_matrix.columns):
if i < j: # 避免重复添加边
count = co_matrix.loc[tf1, tf2]
if count >= min_count:
G.add_edge(tf1, tf2, weight=count)
print(f" 网络包含 {G.number_of_nodes()} 个节点和 {G.number_of_edges()} 条边")
# 保存为GEXF格式
nx.write_gexf(G, output_file)
print(f"💾 网络图已保存至: {output_file}")
def main():
# 使用绝对路径
input_file = "/share/home/xiaoshunpeng/CTCF.gff"
# 输出文件路径(与输入文件同目录)
output_dir = os.path.dirname(input_file)
output_csv = os.path.join(output_dir, "TF_co_occurrence_matrix.csv")
output_image = os.path.join(output_dir, "TF_co_occurrence_heatmap.png")
freq_output = os.path.join(output_dir, "TF_frequency.csv")
network_output = os.path.join(output_dir, "TF_network.gexf")
try:
df = process_custom_gff(input_file)
except Exception as e:
print(f"❌ 错误: {str(e)}")
print("请检查文件路径和格式是否正确")
sys.exit(1)
# 检查是否提取到TF数据
total_tf_occurrences = sum(len(tfs) for tfs in df["TF_list"])
if total_tf_occurrences == 0:
print("⚠️ 警告: 未在GFF文件中找到转录因子数据")
print("请检查文件格式,确保第五列包含逗号分隔的转录因子列表")
print("文件前5行预览:")
print(df.head())
sys.exit(1)
print(f"✅ 发现 {len(df)} 个区域,共 {total_tf_occurrences} 次转录因子出现")
# 计算TF频率
freq_df = calculate_tf_frequency(df)
freq_df.to_csv(freq_output)
print(f"💾 TF频率数据已保存至: {freq_output}")
# 打印最常见的TF
print("\n🔝 前10个最常出现的转录因子:")
print(freq_df.head(10))
# 生成共现矩阵
co_matrix = generate_co_matrix(df)
# 保存矩阵
co_matrix.to_csv(output_csv)
print(f"💾 共现矩阵已保存至: {output_csv}")
# 可视化并保存
visualize_matrix(co_matrix, output_image)
# 生成网络图
generate_co_occurrence_network(co_matrix, min_count=5, output_file=network_output)
# 打印统计信息
num_tfs = len(co_matrix)
num_pairs = (num_tfs * (num_tfs - 1)) // 2
max_co_occur = co_matrix.values.max()
max_tf_pair = co_matrix.stack().idxmax()
print("\n📈 统计摘要:")
print(f"- 唯一转录因子数量: {num_tfs}")
print(f"- 可能的TF对数量: {num_pairs}")
print(f"- 实际观测到的TF对数量: {len([x for x in co_matrix.values.flatten() if x > 0]) // 2}")
print(f"- 最高共现次数: {max_co_occur} (TF对: {max_tf_pair})")
print("✨ 处理完成!")
if __name__ == "__main__":
main()