import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
import warnings
from scipy.cluster.hierarchy import linkage, dendrogram, leaves_list
import itertools
import itertools
import matplotlib.transforms as mtrans
from matplotlib.font_manager import FontProperties # 用于控制字体样式
# Silence library warnings for the whole run of this script.
warnings.filterwarnings('ignore')

# --- Matrix files -----------------------------------------------------------
# Raw OR / -log10(p) matrices (inputs of id_name_mapping below).
path_or, path_p = (
    "gene/.ipynb_checkpoints/scz_for_result/overlap_or_conditional_filter_2.csv",
    "gene/.ipynb_checkpoints/scz_for_result/overlap_neglog10p_filter_2.csv",
)
# Matrices after ID -> name mapping (outputs of id_name_mapping).
id_name_or_path, id_name_p_path = (
    'gene/.ipynb_checkpoints/scz_for_result/id_name_filter2_or.csv',
    'gene/.ipynb_checkpoints/scz_for_result/id_name_filter2_p.csv',
)
# Mapped matrices reordered by the cluster order files (outputs of genes_cell_order).
id_name_or_order, id_name_p_order = (
    'gene/.ipynb_checkpoints/scz_for_result/id_name_or_order2.csv',
    'gene/.ipynb_checkpoints/scz_for_result/id_name_p_order2.csv',
)

# --- Gene ID lists used to mark gene labels with "*" / "**" -----------------
txt_file1 = "gene/SCZ_AD_genes/SCZ_EGS_Singh_00001.txt"  # first gene ID list
txt_file2 = "gene/SCZ_AD_genes/SCZ_GWAS_Trubetskoy_prioritized.txt"  # second gene ID list
gene_id_name_path = 'gene/.ipynb_checkpoints/id_name/ID2Symbol.txt'  # tab-separated id/name table

# Gene names to mark with "**" (list 1) and "*" (list 2); filled in later
# from the ID lists via the ID -> name mapping.
list1_gene_names_1, list1_gene_names_2 = [], []
def get_gene_id_lists(txt1_path, txt2_path):
    """Compare two one-ID-per-line text files.

    Lines are stripped of surrounding whitespace, blank lines are ignored and
    duplicates collapse (the IDs are collected into sets).

    Returns:
        A pair ``(common, only_first)`` of lists: IDs present in both files,
        and IDs present only in the first file. List order is unspecified.
    """
    def _read_id_set(path):
        # One stripped, non-empty ID per line.
        with open(path, 'r', encoding='utf-8') as handle:
            return {entry for entry in (raw.strip() for raw in handle) if entry}

    first_ids = _read_id_set(txt1_path)
    second_ids = _read_id_set(txt2_path)
    shared = list(first_ids.intersection(second_ids))
    first_only = list(first_ids.difference(second_ids))
    return shared, first_only
# ---------------------- Helper: convert gene IDs to gene names (same ID2Symbol mapping file) ----------------------
def convert_gene_id_to_name(gene_id_list, mapping_file_path):
    """Translate gene IDs to gene names, preserving input order.

    Args:
        gene_id_list: Iterable of gene IDs (str or int); compared as strings,
            so ``1`` and ``"1"`` match the same mapping row.
        mapping_file_path: Tab-separated file with ``id`` and ``name`` columns.

    Returns:
        A list with one entry per input ID: the mapped name, or ``str(id)``
        when the ID is not in the mapping (a warning is printed for those).
    """
    mapping_df = pd.read_csv(mapping_file_path, sep='\t')
    # Vectorised build; iterrows() materialised a Series per row, which is
    # very slow on genome-sized ID tables. Later duplicates still win.
    mapping_dict = dict(zip(mapping_df['id'].astype(str), mapping_df['name']))
    converted_names = []
    for gene_id in gene_id_list:
        gene_id_str = str(gene_id)  # unify types for the lookup
        if gene_id_str in mapping_dict:
            converted_names.append(mapping_dict[gene_id_str])
        else:
            converted_names.append(gene_id_str)  # keep the unmatched ID as-is
            print(f"警告:基因ID {gene_id} 未在映射表中找到对应的基因名,将保留原ID")
    return converted_names
def id_name_mapping(path, outpath, *,
                    gene_mapping_path='gene/.ipynb_checkpoints/id_name/ID2Symbol.txt',
                    cell_mapping_path='gene/gene_mapping_x.xlsx'):
    """Rename a space-separated gene x cell matrix from numeric IDs to names.

    Args:
        path: Space-separated input with a ``dataset`` column of gene IDs and
            integer-like cell-type IDs as the remaining column headers.
        outpath: Where the renamed matrix is written (space-separated,
            ``dataset`` as index).
        gene_mapping_path: Tab-separated ``id``/``name`` table for row IDs.
        cell_mapping_path: Excel ``id``/``name`` table for column IDs.

    IDs without a mapping keep their original value. They used to become
    NaN, and several NaN row labels made the later ``reindex`` fail on
    duplicate labels.
    """
    df = pd.read_csv(path, header=0, sep=' ')
    gene_map_df = pd.read_csv(gene_mapping_path, sep='\t')
    gene_map = dict(zip(gene_map_df['id'], gene_map_df['name']))
    df['dataset'] = df['dataset'].map(lambda gene_id: gene_map.get(gene_id, gene_id))
    df = df.set_index('dataset')

    cell_map_df = pd.read_excel(cell_mapping_path)
    df.columns = df.columns.astype(int)
    # Only rename columns that actually occur in the matrix.
    cell_map = {
        cell_id: cell_name
        for cell_id, cell_name in zip(cell_map_df['id'], cell_map_df['name'])
        if cell_id in df.columns
    }
    df = df.rename(columns=cell_map)
    df.to_csv(outpath, sep=' ')
def genes_cell_order(path, outpath, *,
                     gene_order_path="gene/.ipynb_checkpoints/clusters/gene_cluster_order_2.csv",
                     cell_order_path="gene/.ipynb_checkpoints/clusters/cell_cluster_order.csv"):
    """Reorder a space-separated matrix to precomputed gene and cell orders.

    Args:
        path: Space-separated matrix; first column is the row index.
        outpath: Where the reordered matrix is written (space-separated).
        gene_order_path: CSV with a ``gene_order`` column giving row order.
        cell_order_path: CSV with a ``cell_order`` column giving column order.

    Rows/columns listed in an order file but absent from the matrix appear as
    all-NaN; labels absent from the order files are dropped (``reindex``).
    """
    df = pd.read_csv(path, header=0, sep=' ', index_col=0)
    target_gene_order = pd.read_csv(gene_order_path)["gene_order"].tolist()
    target_cell_order = pd.read_csv(cell_order_path)["cell_order"].tolist()
    df_reordered = df.reindex(index=target_gene_order, columns=target_cell_order)
    df_reordered.to_csv(outpath, sep=' ')
def read_matrix_without_group(path):
    """Load a space-separated matrix indexed by its ``dataset`` column.

    If no column is called ``dataset``, the first column is renamed to it.

    Returns:
        ``{"mat": DataFrame}`` with ``dataset`` as the index.
    """
    frame = pd.read_csv(path, keep_default_na=True, sep=' ', header=0)
    if "dataset" not in frame.columns:
        frame = frame.rename(columns={frame.columns[0]: "dataset"})
    return {"mat": frame.set_index("dataset")}
# ---------------------------------------------------------------------------
# 1) Inputs: map IDs to names, reorder, and load both matrices.
# ---------------------------------------------------------------------------
common_ids, only_first_ids = get_gene_id_lists(txt_file1, txt_file2)
id_name_mapping(path_or, id_name_or_path)
id_name_mapping(path_p, id_name_p_path)
genes_cell_order(id_name_or_path, id_name_or_order)
# Reorder the p-value matrix from its OWN mapped file. Passing the OR file
# here made the "-log10(p)" matrix a copy of the odds ratios.
genes_cell_order(id_name_p_path, id_name_p_order)
or_list = read_matrix_without_group(id_name_or_order)
p_list = read_matrix_without_group(id_name_p_order)
# Genes in both ID lists are marked "**"; genes only in the first list "*".
list1_gene_names_1 = convert_gene_id_to_name(common_ids, gene_id_name_path)
list1_gene_names_2 = convert_gene_id_to_name(only_first_ids, gene_id_name_path)
print(f"\n转换后需要加**的基因名列表(list1):{list1_gene_names_1}")
print(f"\n转换后需要加*的基因名列表(list2):{list1_gene_names_2}")

# Restrict both matrices to the genes/cells they share, keeping the OR order.
or_mat = or_list["mat"].copy()
neglogp_mat = p_list["mat"].copy()
p_row_set = set(neglogp_mat.index)
p_col_set = set(neglogp_mat.columns)
common_rows = [row for row in or_mat.index if row in p_row_set]
common_cols = [col for col in or_mat.columns if col in p_col_set]
or_mat = or_mat.loc[common_rows, common_cols].copy()
neglogp_mat = neglogp_mat.loc[common_rows, common_cols].copy()

# ---------------------------------------------------------------------------
# 2) Long format; OR is binned to a colour, -log10(p) to a dot size.
# ---------------------------------------------------------------------------
df_or = or_mat.reset_index().melt(id_vars="dataset", var_name="cell_type", value_name="OR")
df_neglogp = neglogp_mat.reset_index().melt(id_vars="dataset", var_name="cell_type", value_name="neglog10p")
df = pd.merge(df_or, df_neglogp, on=["dataset", "cell_type"], how="outer")
# Plain assignment: chained ``fillna(..., inplace=True)`` is a silent no-op
# under pandas Copy-on-Write.
df["neglog10p"] = df["neglog10p"].fillna(0)
df["cell_type"] = pd.Categorical(df["cell_type"], categories=common_cols, ordered=True)
df["dataset"] = pd.Categorical(df["dataset"], categories=common_rows, ordered=True)

# Bins are left-closed [a, b), so +inf would fall outside the last bin. Clip
# it to the largest finite float so it lands in the top category instead.
finite_max = np.finfo(float).max
or_breaks = [-np.inf, 1, 2, 3, 4, 5, np.inf]
or_labels = ["<1", "1-2", "2-3", "3-4", "4-5", ">5"]
df["OR_cat"] = pd.cut(
    df["OR"].clip(upper=finite_max),
    bins=or_breaks,
    labels=or_labels,
    right=False,
    include_lowest=True,
    ordered=True,
)
p_breaks = [-np.inf, 1, 2, 3, 4, np.inf]
p_labels = ["<1", "1-2", "2-3", "3-4", ">4"]
df["neglog10p_cat"] = pd.cut(
    df["neglog10p"].clip(upper=finite_max),
    bins=p_breaks,
    labels=p_labels,
    right=False,
    include_lowest=True,
    ordered=True,
)
or_colors = {
    "<1": "#FFF5F5", "1-2": "#EFA2A2", "2-3": "#EF5E5E",
    "3-4": "#F64343", "4-5": "#D51B1B", ">5": "#580707FB",
}
p_sizes = {
    "<1": 1.2, "1-2": 1.75,
    "2-3": 2.5, "3-4": 3, ">4": 4,
}
# Keep the sizes as floats. ``astype(int)`` truncated 1.2 and 1.75 to 1, so
# "<1" and "1-2" dots were drawn the same size.
df["size_val"] = df["neglog10p_cat"].map(p_sizes).astype(float)

# ---------------------------------------------------------------------------
# 3) Clustered heatmap skeleton (dendrograms) + dot overlay.
# ---------------------------------------------------------------------------
or_mat_1 = or_mat.fillna(0)  # linkage cannot handle NaN
gene_linkage = linkage(or_mat_1, method="average", metric="euclidean")
gene_order = leaves_list(gene_linkage)
cell_linkage = linkage(or_mat_1.T, method="average", metric="euclidean")
cell_order = leaves_list(cell_linkage)
g = sns.clustermap(
    or_mat,
    row_linkage=gene_linkage,
    col_linkage=cell_linkage,
    cmap="coolwarm",
    center=0,
    figsize=(15, 15),
    dendrogram_ratio=(0.1, 0.1),
    cbar_pos=None,  # no default colour bar; custom legends are added below
    row_cluster=True,
    col_cluster=True,
    xticklabels=True,
    yticklabels=True,
    annot=False,
)
# Final axes layout. This must happen before label extents are measured for
# the bar chart below.
g.fig.subplots_adjust(right=0.85, left=0.1, top=0.85, bottom=0.1)
ax_heatmap = g.ax_heatmap

# clustermap draws rows/columns in dendrogram-leaf order. Cell (i, j) is the
# unit square [j, j+1] x [i, i+1], with row 0 at the top (y axis inverted).
# Every overlay must use this order and these centres, or it will not line up
# with the trees.
y_levels = [or_mat.index[i] for i in g.dendrogram_row.reordered_ind]
x_levels = [or_mat.columns[j] for j in g.dendrogram_col.reordered_ind]
n_rows, n_cols = len(y_levels), len(x_levels)
row_center = {gene: i + 0.5 for i, gene in enumerate(y_levels)}
col_center = {cell: j + 0.5 for j, cell in enumerate(x_levels)}

# White grid cells hide the heatmap colours; the dots carry the information.
for i in range(n_rows):
    for j in range(n_cols):
        ax_heatmap.add_patch(
            Rectangle((j, i), 1, 1, facecolor="white", edgecolor="grey", linewidth=0.6, zorder=1)
        )
for or_cat in or_labels:
    df_cat = df[df["OR_cat"] == or_cat]
    if df_cat.empty:
        continue
    ax_heatmap.scatter(
        [col_center[cell] for cell in df_cat["cell_type"]],
        [row_center[gene] for gene in df_cat["dataset"]],
        s=df_cat["size_val"].to_numpy() * 30,
        c=or_colors[or_cat],
        alpha=0.8,
        edgecolors="black",
        linewidth=0.3,
        label=or_cat,
        zorder=2,
    )
ax_heatmap.set_xlim(0, n_cols)
ax_heatmap.set_ylim(n_rows, 0)  # keep row 0 at the top, matching the row dendrogram

# ---------------------------------------------------------------------------
# X axis: clustered cell-type labels, nudged by a small fixed offset.
# ---------------------------------------------------------------------------
ax_heatmap.set_xticks(np.arange(n_cols) + 0.5)
labels_x = ax_heatmap.get_xticklabels()
dx = 6/72
dy = 5/72
offset = mtrans.ScaledTranslation(dx, dy, g.fig.dpi_scale_trans)
for label in labels_x:
    label.set_transform(label.get_transform() + offset)
ax_heatmap.set_xticklabels(
    x_levels, rotation=45, ha="right", va="top",
    fontproperties=FontProperties(style='italic', size=8), color="black",
)

# ---------------------------------------------------------------------------
# Y axis: italic gene names on the right, optional bold "*" / "**" marker.
# ---------------------------------------------------------------------------
ax_heatmap.set_yticks(np.arange(n_rows) + 0.5)
ax_heatmap.set_yticklabels([])
ax_heatmap.yaxis.tick_right()
ax_heatmap.yaxis.set_label_position("right")
ax_heatmap.set_ylabel('')
gene_font = FontProperties(style='italic', size=8, weight='normal')
symbol_font = FontProperties(style='italic', size=8, weight='bold')
marked_double = set(list1_gene_names_1)
marked_single = set(list1_gene_names_2)
gene_label_parts = []
for gene in y_levels:
    if gene in marked_double:
        gene_label_parts.append((gene, "**"))
    elif gene in marked_single:
        gene_label_parts.append((gene, "*"))
    else:
        gene_label_parts.append((gene, ""))

# x is in axes fraction (1 = right edge of the heatmap) and y in data units,
# so labels hug the heatmap regardless of cell width.
label_transform = ax_heatmap.get_yaxis_transform()
label_artists = []
for row_idx, (gene_name, symbol) in enumerate(gene_label_parts):
    name_text = ax_heatmap.annotate(
        str(gene_name),
        xy=(1, row_idx + 0.5), xycoords=label_transform,
        xytext=(5, 0), textcoords="offset points",
        fontproperties=gene_font, color="black", ha="left", va="center",
    )
    label_artists.append(name_text)
    if symbol:
        # Anchor the marker to the rendered right edge of the gene name
        # instead of estimating the name's width from its character count.
        label_artists.append(ax_heatmap.annotate(
            symbol,
            xy=(1, 0.5), xycoords=name_text,
            xytext=(1, 0), textcoords="offset points",
            fontproperties=symbol_font, color="black", ha="left", va="center",
        ))

# ---------------------------------------------------------------------------
# 4) Right-hand bar chart (values from the xlsx). Its left edge starts just
#    after the longest gene label (including any "*" marker), its right edge
#    runs to the figure edge, and its rows align with the heatmap rows.
# ---------------------------------------------------------------------------
right_bar_path = 'gene/gene_numbercell.xlsx'
right_bar_df = pd.read_excel(right_bar_path, header=0)  # col 0: gene, col 1: value
right_bar_dict = dict(zip(right_bar_df.iloc[:, 0], right_bar_df.iloc[:, 1]))
right_bar_values = [right_bar_dict.get(gene, 0) for gene in y_levels]  # displayed row order
bar_widths = pd.to_numeric(pd.Series(right_bar_values), errors="coerce").fillna(0).to_numpy()

g.fig.canvas.draw()  # resolve text extents under the final layout
renderer = g.fig.canvas.get_renderer()
labels_right_px = max(
    (artist.get_window_extent(renderer).x1 for artist in label_artists),
    default=ax_heatmap.get_window_extent(renderer).x1,
)
bar_gap = 0.005          # figure fraction between labels and bars
bar_right_edge = 0.99    # figure fraction; bars run to the figure edge
bar_left = g.fig.transFigure.inverted().transform((labels_right_px, 0))[0] + bar_gap
heat_box = ax_heatmap.get_position()
ax_bar = g.fig.add_axes(
    [bar_left, heat_box.y0, max(bar_right_edge - bar_left, 0.02), heat_box.height]
)
ax_bar.barh(np.arange(n_rows) + 0.5, bar_widths, height=0.8,
            color="grey", edgecolor="black", linewidth=0.3)
ax_bar.set_ylim(n_rows, 0)  # same row coordinates as the heatmap
ax_bar.tick_params(axis="y", left=False, labelleft=False)
ax_bar.tick_params(axis="x", labelsize=8)
for side in ("top", "right"):
    ax_bar.spines[side].set_visible(False)

# ---------------------------------------------------------------------------
# 5) Legends and output.
# ---------------------------------------------------------------------------
or_legend_base_size = p_sizes["1-2"] * 30 / 10
handles_or = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=or_colors[cat],
               markersize=or_legend_base_size * 1.5, label=cat,
               markeredgecolor='black', markeredgewidth=0.2)
    for cat in or_labels
]
# Line2D ``markersize`` is a diameter in points while scatter ``s`` is an area
# in points**2, so sqrt(size * 30) reproduces the plotted dot sizes.
handles_p = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='grey',
               markersize=np.sqrt(p_sizes[cat] * 30), label=cat,
               markeredgecolor='black', markeredgewidth=0.2)
    for cat in p_labels
]
legend_or = g.fig.legend(handles_or, or_labels, title="Odds ratio", title_fontsize=11,
                         loc="center right", bbox_to_anchor=(0.93, 0.05), fontsize=8)
legend_p = g.fig.legend(handles_p, p_labels, title=r'$-log_{10}(p)$', title_fontsize=11,
                        loc="center right", bbox_to_anchor=(0.99, 0.05), fontsize=8)
g.savefig("gene/.ipynb_checkpoints/png/cluster_plot.png", dpi=300, bbox_inches="tight")