from matplotlib import pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
# Plot a piecewise surface over x in [0, 2], y in [0, 4]:
#   z = x**2 + (2 - y)**2   for y in [0, 2]
#   z = x**2                for y in [2, 4]
# Both pieces reduce to x**2 along the seam y = 2, so they are drawn as two
# adjoining surface patches on the same 3D axes.

N_SAMPLES = 40  # grid points along each axis of each patch

fig = plt.figure()
# Axes3D(fig) stopped adding itself to the figure in Matplotlib 3.6
# (deprecated since 3.4), which leaves the window blank; requesting the 3D
# projection through add_subplot works on both old and new versions.
ax = fig.add_subplot(111, projection='3d')

# Both patches share the same x samples; only their y ranges differ.
x = np.linspace(0, 2, N_SAMPLES)
x1, y1 = np.meshgrid(x, np.linspace(0, 2, N_SAMPLES))
x2, y2 = np.meshgrid(x, np.linspace(2, 4, N_SAMPLES))
z1 = x1**2 + (2 - y1)**2
z2 = x2**2

# Without explicit bounds each plot_surface call scales the colormap to its
# own z-range (0..8 vs 0..4), so the same height would get different colours
# on either side of the seam. Use one colour scale for both patches.
z_min = min(z1.min(), z2.min())
z_max = max(z1.max(), z2.max())
ax.plot_surface(x1, y1, z1, rstride=1, cstride=1, cmap='rainbow',
                vmin=z_min, vmax=z_max)
ax.plot_surface(x2, y2, z2, rstride=1, cstride=1, cmap='rainbow',
                vmin=z_min, vmax=z_max)
plt.show()
# Reference: https://blog.csdn.net/Eddy_zheng/article/details/48713449