Scanpy画图时保证不同标签颜色映射相同
我在可视化时发现不同的adata.obs列对应的颜色不同,这让我分析不同标签时很不直观,怎么解决这个问题呢?请看!
首先有一个adata对象,obs里面有[‘Level1_P2CRC’, ‘Level2_P2CRC’, ‘cell_type_2’, ‘cell_type_3’, ‘Level1_ct’, ‘Level2_ct’]这些不同的选项。此外,uns 属性中包含了每个列的颜色映射,如 ‘Level1_P2CRC_colors’ 等。
我们首先要确定哪些类别在多个列中共享,然后为每个共享类别选择一个标准颜色,并确保所有相关列使用此颜色,最后更新adata的颜色映射并画图。
识别共享类别
列出所有观察列,并创建一个字典 category_columns,其中键是类别名称,值是该类别出现的列列表:
adata_copy = adata.copy()
obs_cols = ['Level1_P2CRC', 'Level2_P2CRC', 'cell_type_2', 'cell_type_3', 'Level1_ct', 'Level2_ct']
category_columns = {}
for col in obs_cols:
categories = adata_copy.obs[col].cat.categories
for cat in categories:
if cat not in category_columns:
category_columns[cat] = [col]
else:
category_columns[cat].append(col)
创建当前颜色映射
对于每个观察列,我们需要获取其当前颜色映射。Scanpy 的颜色映射存储在 uns 属性中,以列名后缀 _colors 的形式存在。我们创建 color_mappings 字典,其中键是列名,值是类别到颜色的映射:
color_mappings = {}
for col in obs_cols:
categories = adata_copy.obs[col].cat.categories
colors = adata_copy.uns[f'{col}_colors']
color_dict = dict(zip(categories, colors))
color_mappings[col] = color_dict
标准化共享类别的颜色
对于每个共享类别(即 category_columns 中值列表长度大于 1 的条目),我们选择一个标准颜色。这里,我们选择该类别在第一个列中的颜色作为标准:
standard_colors = {}
for cat, cols in category_columns.items():
if len(cols) > 1:
standard_color = color_mappings[cols[0]][cat]
standard_colors[cat] = standard_color
for col in cols:
color_mappings[col][cat] = standard_color
更新 AnnData 对象的颜色映射
for col in obs_cols:
categories = adata_copy.obs[col].cat.categories
new_colors = [color_mappings[col][cat] for cat in categories]
adata_copy.uns[f'{col}_colors'] = new_colors
画图
def plot_scatter(adata, color, size):
"""
创建一个散点图函数
参数:
adata: AnnData对象,包含绘图所需的数据
color: str, 用于着色的列名
size: float, 散点大小
返回:
fig: matplotlib Figure对象
"""
import scanpy as sc
import matplotlib.pyplot as plt
# 创建散点图
fig = sc.pl.scatter(
adata,
alpha=1,
x="x",
y="y",
color=color,
size=size,
marker='s',
show=False
)
# 设置图形属性
fig.set_aspect('equal', 'box')
fig.invert_yaxis()
fig.grid(False)
# 显示图形
fig.figure.show()
return fig
plot_scatter(adata_copy, "Level1_P2CRC", 12.8)
plot_scatter(adata_copy, "Level2_P2CRC", 12.8)
plot_scatter(adata_copy, "cell_type_3", 12.8)
plot_scatter(adata_copy, "Level1_ct", 12.8)
plot_scatter(adata_copy, "Level2_ct", 12.8)
可以看到不同的图里面颜色都是一样的了。