Matplotlib 可视化大师系列(二):plt.scatter() - 探索变量关系的散点图

Matplotlib 可视化大师系列博客总览

本系列旨在提供一份系统、全面、深入的 Matplotlib 学习指南。以下是博客列表:

  1. 基础篇plt.plot() - 绘制折线图的利刃
  2. 分布篇plt.scatter() - 探索变量关系的散点图
  3. 比较篇plt.bar()plt.barh() - 清晰对比的柱状图
  4. 统计篇plt.hist()plt.boxplot() - 洞察数据分布
  5. 占比篇plt.pie() - 展示组成部分的饼图
  6. 高级篇plt.imshow() - 绘制矩阵与图像的强大工具
  7. 专属篇绘制误差线 (plt.errorbar())、等高线 (plt.contour()) 等特殊图表
  8. 综合篇在一张图中组合多种图表类型

Matplotlib 可视化大师系列(二):plt.scatter() - 探索变量关系的散点图

散点图是数据科学中最基本但最强大的探索性分析工具之一。它通过在二维平面上绘制数据点来展示两个连续变量之间的关系,帮助我们识别模式、趋势、异常值和相关性。Matplotlib 的 plt.scatter() 函数专门用于创建这种重要的可视化图表。

一、 散点图是什么?何时使用?

散点图使用笛卡尔坐标系中的点来显示两个变量的数值关系。每个点的位置由两个变量的值决定:一个变量确定X轴位置,另一个确定Y轴位置。

核心用途:

  • 探索相关性:判断两个变量之间是否存在正相关、负相关或无关系
  • 识别聚类:发现数据中自然形成的分组或集群
  • 检测异常值:识别远离主要数据分布的异常点
  • 可视化分布:了解两个变量联合分布的特征
  • 发现非线性关系:识别变量间的曲线关系或其他复杂模式

与折线图的区别:

  • 折线图:强调数据的连续性和趋势,点按顺序连接,适用于时间序列或有序数据
  • 散点图:强调数据的分布和关系,点不连接,适用于探索两个变量间的统计关系

二、 函数原型与核心参数

plt.scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, edgecolors=None, **kwargs)

核心参数详解:

  1. 基础参数:

    • x, y: 数据点的x坐标和y坐标数组,长度必须相同
    • s: 点的大小。可以是标量(所有点相同大小)或与x,y长度相同的数组(每个点不同大小)
    • c: 点的颜色。可以是单一颜色字符串、RGB元组,或与x,y长度相同的数组(用于颜色映射)
  2. 标记样式:

    • marker: 点的形状。常用值:'o'(圆), 's'(方), '^'(上三角), 'v'(下三角), '<'(左三角), '>'(右三角), '*'(星), '+'(加), 'x'(叉)
    • alpha: 透明度,0(完全透明)到1(完全不透明)
    • linewidths: 标记边缘线宽
    • edgecolors: 标记边缘颜色
  3. 颜色映射:

    • cmap: 颜色映射,当c是数值数组时使用。常用值:'viridis', 'plasma', 'coolwarm', 'RdYlBu', 'Spectral'
    • vmin, vmax: 颜色映射的数据范围
    • norm: 数据标准化方法

三、 从入门到精通:代码示例

示例 1:基础散点图

import matplotlib.pyplot as plt
import numpy as np

# 生成示例数据
np.random.seed(42)
x = np.random.randn(100)  # 100个随机点
y = 2 * x + np.random.randn(100) * 0.5  # 添加一些噪声

plt.figure(figsize=(10, 6))
plt.scatter(x, y, alpha=0.7)  # 使用透明度避免重叠点看不清
plt.title('Basic Scatter Plot')
plt.xlabel('X Variable')
plt.ylabel('Y Variable')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

示例 2:多维度数据可视化

散点图的真正威力在于可以通过大小和颜色编码第三个甚至第四个变量。

# 生成更多维度的数据
np.random.seed(123)
n_points = 50
x = np.random.rand(n_points) * 10
y = np.random.rand(n_points) * 10
sizes = np.random.rand(n_points) * 300  # 点的大小数组
colors = np.random.rand(n_points)  # 点的颜色值数组
categories = np.random.choice(['A', 'B', 'C'], n_points)  # 分类变量

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

# 1. 用大小表示第三个连续变量
sc1 = ax1.scatter(x, y, s=sizes, alpha=0.6, edgecolors='black', linewidth=0.5)
ax1.set_title('Size Encodes Third Variable')
ax1.set_xlabel('X')
ax1.set_ylabel('Y')

# 手动创建图例表示大小含义
for size in [100, 200, 300]:
    ax1.scatter([], [], s=size, alpha=0.6, label=str(size), edgecolors='black', linewidth=0.5)
ax1.legend(title='Size', scatterpoints=1, frameon=True)

# 2. 用颜色表示第三个连续变量(颜色映射)
sc2 = ax2.scatter(x, y, c=colors, s=100, cmap='viridis', alpha=0.7)
ax2.set_title('Color Encodes Third Variable (Continuous)')
ax2.set_xlabel('X')
ax2.set_ylabel('Y')
plt.colorbar(sc2, ax=ax2, label='Color Value')

# 3. 用颜色表示分类变量
category_colors = {'A': 'red', 'B': 'blue', 'C': 'green'}
color_list = [category_colors[cat] for cat in categories]

sc3 = ax3.scatter(x, y, c=color_list, s=100, alpha=0.7)
ax3.set_title('Color Encodes Categorical Variable')
ax3.set_xlabel('X')
ax3.set_ylabel('Y')

# 为分类数据创建图例
from matplotlib.lines import Line2D
legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor=color, 
                          markersize=10, label=cat) for cat, color in category_colors.items()]
ax3.legend(handles=legend_elements)

plt.tight_layout()
plt.show()

示例 3:高级应用与样式定制

from matplotlib.colors import LogNorm
from sklearn.datasets import make_blobs

# 创建更复杂的数据
X, y = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)
density = np.exp(-0.1 * (X[:, 0]**2 + X[:, 1]**2))  # 模拟密度数据

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# 1. 聚类可视化与密度表示
sc1 = ax1.scatter(X[:, 0], X[:, 1], c=y, s=100, cmap='Spectral', 
                 alpha=0.7, edgecolor='black', linewidth=0.5)
ax1.set_title('Cluster Visualization with Colormap')
ax1.set_xlabel('Feature 1')
ax1.set_ylabel('Feature 2')
plt.colorbar(sc1, ax=ax1, label='Cluster ID')

# 2. 使用对数标准化和自定义颜色映射
sc2 = ax2.scatter(X[:, 0], X[:, 1], c=density, s=70, cmap='plasma', 
                 norm=LogNorm(vmin=0.01, vmax=1), alpha=0.8,
                 edgecolors='white', linewidth=0.3)
ax2.set_title('Density Visualization with Log Color Scale')
ax2.set_xlabel('Feature 1')
ax2.set_ylabel('Feature 2')
cbar = plt.colorbar(sc2, ax=ax2)
cbar.set_label('Density (log scale)')

# 添加趋势线
z = np.polyfit(X[:, 0], X[:, 1], 1)
p = np.poly1d(z)
ax2.plot(X[:, 0], p(X[:, 0]), "r--", alpha=0.8, linewidth=2, label='Trend line')
ax2.legend()

plt.tight_layout()
plt.show()

示例 4:处理过度绘制与大数据集

当数据点很多时,散点图会出现过度绘制问题,可以使用以下技巧解决:

# 创建大型数据集
np.random.seed(42)
x_large = np.random.randn(10000)
y_large = 0.5 * x_large + np.random.randn(10000) * 0.5

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

# 1. 普通散点图(过度绘制严重)
ax1.scatter(x_large, y_large, alpha=0.1, s=10)
ax1.set_title('Standard Scatter (Overplotting)')
ax1.set_xlabel('X')

# 2. 使用六边形分箱图
hb = ax2.hexbin(x_large, y_large, gridsize=50, cmap='viridis', mincnt=1)
ax2.set_title('Hexbin Plot for Large Data')
ax2.set_xlabel('X')
plt.colorbar(hb, ax=ax2, label='Count')

# 3. 使用二维直方图
from matplotlib.colors import LogNorm
h = ax3.hist2d(x_large, y_large, bins=50, cmap='plasma', norm=LogNorm())
ax3.set_title('2D Histogram for Large Data')
ax3.set_xlabel('X')
plt.colorbar(h[3], ax=ax3, label='Count')

for ax in [ax1, ax2, ax3]:
    ax.set_ylabel('Y')

plt.tight_layout()
plt.show()

四、 最佳实践与常见陷阱

  1. 最佳实践:

    • 使用透明度:设置alpha=0.5-0.7来解决过度绘制问题
    • 选择合适的点大小:点不应该太大(导致过度重叠)或太小(难以看到)
    • 有意义地使用颜色:使用颜色表示重要的第三个变量,而不是随意装饰
    • 添加趋势线:对于明显的关系,添加趋势线或回归线帮助解释
    • 考虑数据变换:对于非线性关系,考虑对x或y轴使用对数尺度
  2. 常见陷阱:

    • 过度绘制:数据点太多导致无法看清模式(解决方案:使用透明度、分箱图或采样)
    • 误导性的比例:不适当的轴范围可能扭曲数据关系(确保轴从0开始或明确标注)
    • 忽略异常值:异常值可能包含重要信息,不应该随意删除
    • 相关不等于因果:散点图显示关系但不能证明因果关系
    • 过度解释:避免从少量数据点或弱关系中得出强烈结论
  3. 进阶技巧:

    • 边际分布图:结合直方图或箱线图显示每个变量的单独分布
    • 分组散点图:使用不同颜色/形状区分数据子集
    • 动态散点图:使用Plotly或Bokeh创建交互式散点图
    • 高维数据:使用PCA或t-SNE等降维技术后再可视化

五、 总结

plt.scatter() 是探索数据关系的入门工具,但也是最重要的工具之一:

  • 核心功能:展示两个连续变量间的关系,识别模式、聚类和异常值
  • 关键参数s(大小), c(颜色), marker(形状), alpha(透明度), cmap(颜色映射)
  • 高级应用:通过大小和颜色编码第三、第四个变量,处理大数据集
  • 最佳实践:使用透明度避免过度绘制,有意义地使用颜色,添加趋势线

掌握散点图意味着你拥有了探索数据关系的第一件强大武器。它是任何数据分析项目的起点,能够为你提供初步的洞察,指导后续的深入分析。在下一篇文章中,我们将学习如何创建清晰的比较图表——柱状图。

<think>嗯,用户想用matplotlibplt.scatter绘制三维散点图,但可能不太清楚具体怎么做。首先,我得回忆一下matplotlib中3D图的基本使用方法。记得普通的散点图scatter没问题,但三维的话需要引入Axes3D模块,或者使用投影参数。然后,可能需要创建一个三维子图,比如用fig.add_subplot(111, projection='3d')。 接着,用户可能的数据结构是什么样的?通常三维数据有三个坐标轴,所以需要准备x、y、z三个数组。例如示例中的x = [1,2,3,4,5],y = [6,7,2,4,5],再加上一个z的数组。然后,绘制的时候应该用scatter方法,但是这里可能需要注意,在三维坐标系下,scatter的参数是否有所不同? 突然想到,matplotlib中的3D散点图其实是使用scatter函数,但需要指定三个坐标轴的数据。可能需要查阅一下文档,确认正确的调用方式。比如,在创建了3D坐标轴后,调用ax.scatter(x, y, z, ...),其中x、y、z是数据点。 另外,颜色、大小、标记等参数的设置应该和维类似,比如color、s、marker等参数。但是透明度alpha可能也适用。用户可能还关心如何调整坐标轴标签、标题,以及如何显示图例,或者调整视角。 还有,用户提到的plt.scatter是否可以直接用于3D?可能不行,因为plt.scatter通常是维的,所以必须通过3D轴对象来调用scatter方法。这一步可能会让用户感到困惑,所以需要明确指出创建3D轴的重要性。 然后,示例代码的结构应该是:导入必要的库(matplotlib.pyplot,mpl_toolkits.mplot3d中的Axes3D),生成数据,创建图形和3D轴,调用scatter方法,设置标签和标题,最后显示图形。 可能需要检查代码的正确性,比如是否确实需要从mpl_toolkits.mplot3d导入Axes3D,尽管有些版本的matplotlib可能已经自动处理了。不过为了兼容性,最好明确导入。 最后,用户可能的问题还包括如何调整点的颜色、大小,或者添加颜色条。比如,使用c参数和cmap参数来设置颜色映射,或者用s参数调整点的大小。这些可以作为扩展内容在示例中展示,或者作为相关问题提出。 总结一下,回答应该包括导入库、创建3D轴、调用scatter方法、设置图表属性,并提供完整示例代码。同时,考虑到用户可能的需求,生成相关问题来引导深入学习,比如如何设置颜色、调整视角、处理大数据集等。</think>### 如何使用 `plt.scatter` 绘制三维散点图Matplotlib 中绘制三维散点图需要借助 `mpl_toolkits.mplot3d` 模块。以下是详细步骤和示例代码: --- #### **1. 导入必要库** 首先导入 Matplotlib 的 `pyplot` 和 3D 绘图工具: ```python import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D ``` --- #### **2. 准备数据** 生成三维坐标数据(假设 `x`, `y`, `z` 是三个一维数组): ```python x = [1, 2, 3, 4, 5] y = [6, 7, 2, 4, 5] z = [8, 3, 5, 9, 1] ``` --- #### **3. 创建三维坐标系** 通过 `figure` 和 `add_subplot` 创建 3D 画布: ```python fig = plt.figure(figsize=(8, 6)) # 设置画布尺寸 ax = fig.add_subplot(111, projection='3d') # 启用三维投影 ``` --- #### **4. 绘制三维散点图** 使用 `scatter` 方法并传入 `x`, `y`, `z` 数据: ```python ax.scatter( x, y, z, s=50, # 点的大小(屏幕像素或数据单位) c='red', # 颜色(支持颜色名称或 RGB 元组) marker='o', # 标记形状(如 'o', '^', 's') alpha=0.8, # 透明度(0-1) edgecolors='k' # 边缘颜色 ) ``` --- #### **5. 设置图表属性** 添加坐标轴标签和标题: ```python ax.set_xlabel('X Label') ax.set_ylabel('Y Label') ax.set_zlabel('Z Label') ax.set_title('3D Scatter Plot') ``` --- #### **6. 显示图形** ```python plt.show() ``` --- #### **完整代码示例** ```python import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D x = [1, 2, 3, 4, 5] y = [6, 7, 2, 4, 5] z = [8, 3, 5, 9, 1] fig = plt.figure(figsize=(8, 6)) ax = fig.add_subplot(111, projection='3d') ax.scatter(x, y, z, s=50, c='red', marker='o', alpha=0.8, edgecolors='k') ax.set_xlabel('X Label') ax.set_ylabel('Y Label') ax.set_zlabel('Z Label') ax.set_title('3D Scatter Plot') plt.show() ``` --- #### **关键参数说明** - `s`:点的大小,支持标量或与数据长度相同的数组[^1] - `c`:颜色,支持颜色名称(如 `'blue'`)、十六进制值(如 `'#1f77b4'`)或 RGB 元组(如 `(0.1, 0.2, 0.5)`) - `cmap`:颜色映射(需与 `c` 参数配合使用,如 `c=z, cmap='viridis'`) - `depthshade`:是否根据深度调整颜色深浅(默认为 `True`) ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值