Python 绘制粒子在不同分布之间运动的轨迹

# -*- coding: utf-8 -*-
"""
Created on Thu Jun 20 17:39:25 2024

@author: PC
"""

import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchdiffeq import odeint_adjoint as odeint
import torch.optim as optim
from scipy.stats import norm

# Compute device shared by the script: CUDA when torch reports it as
# available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# In[]
def mixed_gaussian_sample(mus, sigmas, mix_proportions, num_samples):
    """
    Draw samples from a one-dimensional Gaussian mixture.

    The number of samples taken from each component is fixed up front
    (proportional allocation) instead of being drawn at random, and the
    returned array is grouped by component: all samples of component 0
    first, then component 1, and so on.

    Parameters
    ----------
    mus : sequence of float
        Mean of each component.
    sigmas : sequence of float
        Standard deviation of each component.
    mix_proportions : sequence of float
        Mixing weight of each component.
    num_samples : int
        Total number of samples requested.

    Returns
    -------
    numpy.ndarray
        1-D float array. When the proportions sum to one it holds exactly
        ``num_samples`` values.
    """
    exact_counts = num_samples * np.asarray(mix_proportions, dtype=float)
    counts = np.floor(exact_counts).astype(int)

    # Rounding every share independently can miss the requested total
    # (2000 samples with three equal weights would give 3 * 667 = 2001).
    # Instead, round down and hand the leftover samples to the components
    # with the largest fractional parts (largest-remainder allocation).
    shortfall = int(round(exact_counts.sum())) - int(counts.sum())
    if shortfall > 0:
        by_fraction = np.argsort(counts - exact_counts, kind="stable")
        counts[by_fraction[:shortfall]] += 1

    # Draws use the global NumPy random state, so np.random.seed still
    # controls reproducibility.
    parts = [
        mu + sigmas[i] * np.random.randn(counts[i])
        for i, mu in enumerate(mus)
    ]
    if not parts:
        return np.array([])
    return np.concatenate(parts)

def mixed_gaussian_pdf(x, mus, sigmas, pi):
    """
    Evaluate the density of a one-dimensional Gaussian mixture.

    Parameters
    ----------
    x : float or array-like
        Point(s) at which to evaluate the density. Scalars and arrays of
        any shape are accepted.
    mus : sequence of float
        Mean of each component.
    sigmas : sequence of float
        Standard deviation of each component.
    pi : sequence of float
        Mixing weight of each component.

    Returns
    -------
    numpy.ndarray
        Weighted sum of the component densities, with the same shape as
        ``x`` (a 0-d array for scalar input).
    """
    # Converting up front lets the accumulator take the shape of x, so the
    # function is not limited to 1-D input.
    x = np.asarray(x, dtype=float)
    pdf = np.zeros(x.shape)

    for k, mu in enumerate(mus):
        # Density of the k-th component, weighted by its mixing proportion.
        pdf += pi[k] * norm.pdf(x, mu, sigmas[k])

    return pdf

def _flip(x, dim):
	indices = [slice(None)] * x.dim()
	indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, \
			device=x.device)
	return x[tuple(indices)]


class ODEfunc(nn.Module):
    """
    Time-derivative network of the flow.

    Written as an ``nn.Module`` because torchdiffeq expects the dynamics
    function in that form.
    """

    def __init__(self, hidden_dims=(5, 5), z_dim=1):
        super().__init__()
        # Layer widths from input to output: z_dim -> hidden... -> z_dim.
        widths = [z_dim, *hidden_dims, z_dim]
        # Every layer takes one extra input feature, the time t.
        self.layers = nn.ModuleList(
            [
                nn.Linear(n_in + 1, n_out)
                for n_in, n_out in zip(widths[:-1], widths[1:])
            ]
        )

    def get_z_dot(self, t, z):
        """Evaluate the network: z_dot = NN(t, z(t))."""
        # One copy of t per row of z, reused as the first column of the
        # input to every layer.
        t_column = t.expand(z.shape[0], 1)
        last = len(self.layers) - 1
        hidden = z
        for idx, linear in enumerate(self.layers):
            hidden = linear(torch.cat((t_column, hidden), dim=1))
            # Softplus after every layer except the output layer.
            if idx != last:
                hidden = F.softplus(hidden)
        return hidden

    def forward(self, t, states):
        """
        Compute the time derivative of z and of the log-density term.

        Parameters
        ----------
        t : torch.Tensor
            Current time.
        states : tuple
            Integration state; only its first entry, ``z``, is read here.

        Returns
        -------
        z_dot : torch.Tensor
            Time derivative of z.
        negative_divergence : torch.Tensor
            Minus the trace of d(z_dot)/dz, shaped ``(batch, 1)``.
        """
        z = states[0]
        n_rows = z.shape[0]

        # Gradients are switched on explicitly so the trace below can be
        # taken with autograd even when the caller runs under no_grad.
        with torch.enable_grad():
            z.requires_grad_(True)
            t.requires_grad_(True)

            velocity = self.get_z_dot(t, z)

            # Exact trace of the Jacobian: one autograd.grad call per
            # dimension of z, keeping only the diagonal entry each time.
            # This brute-force approach is fine for small dimensions; a
            # stochastic trace estimator would scale better.
            trace = sum(
                torch.autograd.grad(
                    velocity[:, d].sum(), z, create_graph=True
                )[0][:, d]
                for d in range(z.shape[1])
            )

        return velocity, -trace.view(n_rows, 1)
    
    
class FfjordModel(torch.nn.Module):
    """Continuous normalizing flow model."""

    def __init__(self, hidden_dims=(64, 64), z_dim=1):
        super().__init__()
        self.time_deriv_func = ODEfunc(hidden_dims, z_dim)

    def save_state(self, fn='state.tar'):
        """Save the model's state dict to ``fn``."""
        torch.save(self.state_dict(), fn)

    def load_state(self, fn='state.tar'):
        """Load the model's state dict from ``fn``."""
        # Map the stored tensors onto the device the model currently lives
        # on, so a checkpoint written on a GPU still loads on a CPU-only
        # machine (and vice versa).
        target = next(self.parameters()).device
        self.load_state_dict(torch.load(fn, map_location=target))

    def forward(self, z, delta_logpz=None, integration_times=None,
                reverse=False):
        """
        Integrate the samples and the log-density term through the flow.

        Both ``z`` and ``delta_logpz`` are integrated jointly with
        torchdiffeq, using ``self.time_deriv_func`` as the dynamics.

        Parameters
        ----------
        z : torch.Tensor
            Samples, one per row.
        delta_logpz : torch.Tensor, optional
            Initial value of the accumulated log-density term. Defaults to
            zeros of shape ``(z.shape[0], 1)``.
        integration_times : torch.Tensor, optional
            Times at which to evaluate the solution. Defaults to
            ``[0.0, 1.0]``.
        reverse : bool, optional
            Whether to reverse the integration times.

        Returns
        -------
        tuple
            With exactly two integration times: ``(z, delta_logpz)`` at
            the second time. Otherwise the solver output is returned
            unchanged (one entry per state variable, evaluated at every
            requested time).
        """
        if delta_logpz is None:
            # Match z's device and dtype instead of relying on a
            # module-level global, consistent with the default times below.
            delta_logpz = z.new_zeros(z.shape[0], 1)
        if integration_times is None:
            integration_times = torch.tensor([0.0, 1.0]).to(z)
        if reverse:
            integration_times = _flip(integration_times, 0)

        # Integrate. This is the call to torchdiffeq.
        state = odeint(
            self.time_deriv_func,  # Calculates time derivatives.
            (z, delta_logpz),  # Values to update.
            integration_times,  # When to evaluate.
            method='dopri5',  # Runge-Kutta
            atol=1e-5,  # Error tolerance
            rtol=1e-5,  # Error tolerance
        )

        if len(integration_times) == 2:
            # Only the end point is of interest: drop the initial state.
            z, delta_logpz = tuple(s[1] for s in state)
            return z, delta_logpz
        return state
    

def standard_normal_logprob(z):
    """
    Log-density of a standard normal, evaluated per row.

    Each column of ``z`` is treated as an independent standard-normal
    coordinate; the per-coordinate log-densities are summed over the
    second dimension.

    Parameters
    ----------
    z : torch.Tensor
        Samples with one row per sample and one column per dimension.

    Returns
    -------
    torch.Tensor
        Log-density of each row, with the second dimension kept at size 1.
    """
    # Per-dimension normalising constant is -0.5 * log(2*pi). Using
    # -log(2*pi) here would count it twice for every dimension.
    log_norm_const = -0.5 * float(np.log(2 * np.pi))
    return (log_norm_const - 0.5 * z.pow(2)).sum(1, keepdim=True)


# In[] Parameter initialisation and settings
n = 200  # number of bins the spatial domain is divided into
m = 800  # number of points the time interval is divided into
trans_map = np.zeros((m, n))   # matrix holding the sample trajectories: one histogram row per time point
t = torch.linspace(0, 1, m).to(device)  # time grid on [0, 1]

# Spatial range covered by the histograms.
min_val, max_val = -3, 3
# Width of one histogram bin (not referenced again in the visible code).
bin_width = (max_val - min_val) / n

num_samples = 2000  # total number of samples

# Gaussian mixture settings
mus = [-2, 0, 2]  # means
sigmas = [0.5, 0.5, 0.5]  # standard deviations

# Mixing proportions; the three components are weighted equally here
mix_proportions = [1/3, 1/3, 1/3]

# Neural network settings
input_dim = 1
hidden_dim = 10

num_epochs = 10000

# In[] Gaussian mixture
# Draw samples from the Gaussian mixture
samples_mixnormal = mixed_gaussian_sample(mus, sigmas, mix_proportions, num_samples)

# Evaluate the mixture's probability density at the drawn samples
pdf_mixnormal = mixed_gaussian_pdf(samples_mixnormal, mus, sigmas, mix_proportions)

# In[]
# Train the neural network
# Samples as a column tensor; .to(t) casts them to the dtype and device of
# the time grid t.
z_t1 = torch.tensor(samples_mixnormal,dtype=torch.double).reshape(-1,1).to(t)
model = FfjordModel(hidden_dims=(hidden_dim, hidden_dim),z_dim=input_dim).to(device)
# Disabled switch: change to True to resume from a previously saved state.
if False:
	model.load_state()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()
    # Push the full sample set through the flow (full-batch training).
    z_t0, delta_logpz = model(z_t1)

    # Log likelihood of the base distribution samples.
    logpz_t0 = standard_normal_logprob(z_t0)

    # Subtract the correction term (log determinant of the Jacobian). Note
    # that we integrated from t_1 to t_0 and integrated a negative trace
    # term, so the signs align with Eq. 3.
    logpz_t1 = logpz_t0 - delta_logpz
    # Negative mean log-likelihood of the samples.
    loss = -torch.mean(logpz_t1)

    # Report progress every 100 epochs.
    if epoch % 100 == 0:
        print(f'Epoch {epoch}/{num_epochs}, Loss: {loss.item()}')

    loss.backward()
    optimizer.step()

# Persist the trained weights (default file name of save_state).
model.save_state()

# In[] Collect the results
# Draw a larger sample set for the trajectory histograms.
num_samples = 40000
samples_mixnormal = mixed_gaussian_sample(mus, sigmas, mix_proportions, num_samples)
z_t1 = torch.tensor(samples_mixnormal,dtype=torch.double).reshape(-1,1).to(t)

with torch.no_grad():
    # Rebuild the model and restore the weights saved after training.
    model = FfjordModel(hidden_dims=(hidden_dim, hidden_dim),z_dim=input_dim).to(device)
    model.load_state()

    # Evaluate the flow at every time in t; the log-density term is not needed.
    z_t, _ = model(z_t1, integration_times=t)
    z_t = z_t.squeeze()   # [m, batch_size]

# Process each row (one time point per row)
for row in range(m):
    # Histogram of the sample positions at this time point
    counts = torch.histc(z_t[row, :], bins=n, min=min_val, max=max_val)

    # Store the histogram in the result array
    trans_map[row, :] = counts.cpu()

# In[] Plot the results
plt.rcParams['figure.figsize'] = (6.0, 6.0)
plt.rcParams['savefig.dpi'] = 1000 # resolution of saved images (dpi)
plt.rcParams['figure.dpi'] = 1000 # resolution of displayed figures (dpi)
cmap = plt.colormaps["plasma"]
# Use the colormap's lowest colour for "bad" (masked/invalid) values.
cmap = cmap.with_extremes(bad=cmap(0))

# Figure 1: histogram counts over time (rows) and position (columns).
plt.figure(figsize=(6.0, 8.0))
plt.gca().set_facecolor('none')
plt.imshow(trans_map, aspect='auto', cmap=cmap)  # aspect='auto' lets the image stretch to fill the axes
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
plt.axis('off') # hide the axes and their frame
plt.show()

# plt.savefig(save_path, bbox_inches='tight',pad_inches=0)

# Figure 2: density of the Gaussian mixture on [min_val, max_val].
plt.figure(figsize=(6.0, 6.0))
x = np.linspace(min_val, max_val, 1000)
pdf = mixed_gaussian_pdf(x, mus, sigmas, mix_proportions)
plt.plot(x, pdf, color='black', linewidth=2)
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
plt.axis('off') # hide the axes and their frame
plt.show()

# Figure 3: density of the standard normal on the same grid.
plt.figure(figsize=(6.0, 6.0))
pdf2 = norm.pdf(x, 0, 1)
plt.plot(x, pdf2, color='black', linewidth=2)
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
plt.axis('off') # hide the axes and their frame
plt.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

三只佩奇不结义

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值