import os
import time
import pickle
import numpy as np
import jax
import jax.numpy as jnp
import quimb
from quimb.gates import * # 正确导入量子门
import matplotlib.pyplot as plt
# Enable 64-bit types in JAX; the code below explicitly requests
# float64 / complex128 arrays.
jax.config.update("jax_enable_x64", True)
# Ask JAX to use the "gpu" platform.
# NOTE(review): presumably this errors out on machines without a GPU backend
# once the first array is created -- confirm for the installed JAX version.
jax.config.update("jax_platform_name", "gpu")
# ======================
# Noise Models (Kraus)
# ======================
def single_qubit_depolarizing_kraus(p):
    """Return the Kraus operators of a single-qubit depolarizing channel.

    The state is left untouched with weight ``1 - p`` and each of the Pauli
    operators X, Y, Z is applied with weight ``p / 3``.

    Args:
        p: Total error probability. It is expected to lie in ``[0, 1]``; the
            value is not validated so the function also works when ``p`` is
            a traced JAX value.

    Returns:
        A list ``[K_I, K_X, K_Y, K_Z]`` of four 2x2 complex arrays with
        ``sum(K.conj().T @ K for K in result) == identity``.
    """
    # Every operator uses the same complex dtype so the returned list is
    # homogeneous (the identity term used to be real-valued).
    dtype = jnp.complex128
    identity = jnp.eye(2, dtype=dtype)
    pauli_x = jnp.array([[0, 1], [1, 0]], dtype=dtype)
    pauli_y = jnp.array([[0, -1j], [1j, 0]], dtype=dtype)
    pauli_z = jnp.array([[1, 0], [0, -1]], dtype=dtype)

    error_amp = jnp.sqrt(p / 3.0)
    return [
        jnp.sqrt(1 - p) * identity,
        error_amp * pauli_x,
        error_amp * pauli_y,
        error_amp * pauli_z,
    ]
def two_qubit_depolarizing_kraus(p):
    """Return the Kraus operators of a two-qubit depolarizing channel.

    The state is left untouched with weight ``1 - p``; each of the 15
    non-identity two-qubit Pauli products is applied with weight ``p / 15``.

    Args:
        p: Total error probability. It is expected to lie in ``[0, 1]``; the
            value is not validated so the function also works when ``p`` is
            a traced JAX value.

    Returns:
        A list of 16 complex 4x4 arrays. The first entry is the scaled
        identity, followed by the Pauli products in the order
        IX, IY, IZ, XI, XX, ..., ZZ.
    """
    dtype = jnp.complex128
    # Index 0 is the identity; the order I, X, Y, Z fixes the output order.
    paulis = (
        jnp.eye(2, dtype=dtype),
        jnp.array([[0, 1], [1, 0]], dtype=dtype),
        jnp.array([[0, -1j], [1j, 0]], dtype=dtype),
        jnp.array([[1, 0], [0, -1]], dtype=dtype),
    )

    error_amp = jnp.sqrt(p / 15.0)
    kraus = [jnp.sqrt(1 - p) * jnp.eye(4, dtype=dtype)]
    for i, first in enumerate(paulis):
        for j, second in enumerate(paulis):
            # I (x) I is already covered by the no-error term above. The
            # skip is decided on indices, not on array values, so there is
            # no value-dependent Python branch and the function can be
            # called while JAX is tracing (jit / grad).
            if i == 0 and j == 0:
                continue
            kraus.append(error_amp * jnp.kron(first, second))
    return kraus
def bit_flip_kraus(p):
    """Return the Kraus operators of a single-qubit bit-flip channel.

    Args:
        p: Probability of applying Pauli X. It is expected to lie in
            ``[0, 1]``; the value is not validated so the function also
            works when ``p`` is a traced JAX value.

    Returns:
        A list ``[sqrt(1 - p) * I, sqrt(p) * X]`` of two 2x2 complex arrays.
    """
    # Both operators share one complex dtype (the identity used to be real).
    dtype = jnp.complex128
    identity = jnp.eye(2, dtype=dtype)
    pauli_x = jnp.array([[0, 1], [1, 0]], dtype=dtype)
    return [jnp.sqrt(1 - p) * identity, jnp.sqrt(p) * pauli_x]
# ======================
# Utility: embed ops
# ======================
def embed_single(U, n, q):
    """Embed a single-qubit operator into an ``n``-qubit operator.

    Qubit 0 is the left-most Kronecker factor, i.e. the result is
    ``I (x) ... (x) U (x) ... (x) I`` with ``U`` in position ``q``.

    Args:
        U: 2x2 operator. Anything that is not already a JAX array is
            converted with ``jnp.array``.
        n: Total number of qubits.
        q: Index of the qubit ``U`` acts on, ``0 <= q < n``.

    Returns:
        A ``(2**n, 2**n)`` JAX array with the dtype of ``U``.

    Raises:
        ValueError: If ``q`` is outside ``range(n)``. Without this check an
            invalid index would silently yield the identity.
    """
    if not 0 <= q < n:
        raise ValueError(f"Qubit index {q} is out of range for {n} qubits.")
    if not isinstance(U, jnp.ndarray):
        U = jnp.array(U)

    # The identities on each side of U are merged into one block each, so
    # only two Kronecker products are needed whatever n is.
    left = jnp.eye(2 ** q, dtype=U.dtype)
    right = jnp.eye(2 ** (n - q - 1), dtype=U.dtype)
    return jnp.kron(jnp.kron(left, U), right)
def embed_two(U, n, q1, q2):
    """Embed a two-qubit operator into an ``n``-qubit operator.

    The first tensor factor of ``U`` acts on qubit ``q1`` and the second on
    qubit ``q2``. The qubits may be given in either order and do not have to
    be adjacent. Qubit 0 is the left-most Kronecker factor.

    Args:
        U: 4x4 operator. Anything that is not already a JAX array is
            converted with ``jnp.array``.
        n: Total number of qubits.
        q1: Qubit addressed by the first factor of ``U`` (e.g. the control
            of a CNOT matrix whose control is its first qubit).
        q2: Qubit addressed by the second factor of ``U``.

    Returns:
        A ``(2**n, 2**n)`` JAX array with the dtype of ``U``.

    Raises:
        ValueError: If the two indices coincide or lie outside ``range(n)``.
    """
    if q1 == q2:
        raise ValueError("embed_two needs two distinct qubit indices.")
    if not (0 <= q1 < n and 0 <= q2 < n):
        raise ValueError(
            f"Qubit indices ({q1}, {q2}) are out of range for {n} qubits."
        )
    if not isinstance(U, jnp.ndarray):
        U = jnp.array(U)

    # Start from U acting on slots (0, 1) with identity on all other slots.
    dim = 2 ** n
    full = jnp.kron(U, jnp.eye(2 ** (n - 2), dtype=U.dtype))
    if (q1, q2) == (0, 1):
        return full

    # View the matrix as a tensor with one output and one input axis per
    # qubit, then move slot 0 to qubit q1 and slot 1 to qubit q2.
    tensor = full.reshape((2,) * (2 * n))
    slot_to_qubit = [q1, q2] + [k for k in range(n) if k != q1 and k != q2]
    # Inverse permutation: for each qubit, the slot that currently holds it.
    qubit_to_slot = sorted(range(n), key=slot_to_qubit.__getitem__)
    axes = qubit_to_slot + [n + slot for slot in qubit_to_slot]
    return jnp.transpose(tensor, axes).reshape(dim, dim)
# ======================
# Hamiltonian (Heisenberg)
# ======================
def build_heisenberg_H(n, jx=1.0, jy=1.0, jz=1.0, hz=1.0):
    """Build the dense Hamiltonian of an open Heisenberg chain in a Z field.

    ``H = sum_i (jx X_i X_{i+1} + jy Y_i Y_{i+1} + jz Z_i Z_{i+1})
    + hz * sum_i Z_i`` with nearest-neighbour couplings on ``n`` qubits and
    no coupling between the last and the first qubit.

    Args:
        n: Number of qubits.
        jx: Coupling strength of the XX terms.
        jy: Coupling strength of the YY terms.
        jz: Coupling strength of the ZZ terms.
        hz: Strength of the single-qubit Z field.

    Returns:
        A ``(2**n, 2**n)`` complex JAX array.
    """
    dtype = jnp.complex128
    X = jnp.array([[0, 1], [1, 0]], dtype=dtype)
    Y = jnp.array([[0, -1j], [1j, 0]], dtype=dtype)
    Z = jnp.array([[1, 0], [0, -1]], dtype=dtype)

    # The 4x4 bond operator is formed once on two qubits and then embedded,
    # which avoids multiplying pairs of dense 2**n x 2**n matrices per bond.
    bond = jx * jnp.kron(X, X) + jy * jnp.kron(Y, Y) + jz * jnp.kron(Z, Z)
    field = hz * Z

    H = jnp.zeros((2 ** n, 2 ** n), dtype=dtype)
    for i in range(n - 1):
        H = H + _embed_consecutive(bond, n, i, 2)
    for i in range(n):
        H = H + _embed_consecutive(field, n, i, 1)
    return H


def _embed_consecutive(op, n, first, width):
    """Place ``op`` on qubits ``first .. first + width - 1`` of ``n`` qubits."""
    left = jnp.eye(2 ** first, dtype=op.dtype)
    right = jnp.eye(2 ** (n - first - width), dtype=op.dtype)
    return jnp.kron(jnp.kron(left, op), right)
# ======================
# Energy (density-matrix evolution)
# ======================
def compute_energy(params, structure, H, n_qubit, noise_param, circuit):
    """Return the energy ``Re Tr(rho H)`` of a noisy parameterised circuit.

    The circuit is simulated on a dense density matrix starting from
    ``|0...0><0...0|``. Each rotation gate is followed by single-qubit
    depolarizing noise, each two-qubit gate by two-qubit depolarizing noise,
    and after the last gate a bit-flip channel is applied to every qubit.

    The function contains no conversion of ``params`` to Python numbers, so
    it can be differentiated with ``jax.grad`` and compiled with ``jax.jit``
    as long as ``structure`` and ``n_qubit`` are concrete Python values.

    Args:
        params: 1-D array of rotation angles, consumed in the order in which
            the rotation gates appear in ``structure``.
        structure: Iterable of gates. ``("rx" | "ry" | "rz", q)`` is a
            rotation ``exp(-i * theta / 2 * P)`` on qubit ``q``;
            ``("cx" | "cz", q1, q2)`` is a CNOT / CZ passed to ``embed_two``
            with the qubits in the listed order.
        H: ``(2**n_qubit, 2**n_qubit)`` Hamiltonian.
        n_qubit: Number of qubits.
        noise_param: Mapping with the keys ``"single_qubit_depolarizing_p"``,
            ``"two_qubit_channel_depolarizing_p"`` and ``"bit_flip_p"``.
        circuit: Not used by this function; kept for interface
            compatibility.

    Returns:
        The real part of ``Tr(rho H)`` as a JAX scalar.

    Raises:
        ValueError: If ``structure`` contains an unsupported gate name.
    """
    dtype = jnp.complex128
    dim = 2 ** n_qubit
    # Initial state |0...0><0...0|.
    rho = jnp.zeros((dim, dim), dtype=dtype).at[0, 0].set(1.0)

    single_kraus = single_qubit_depolarizing_kraus(
        noise_param["single_qubit_depolarizing_p"]
    )
    two_kraus = two_qubit_depolarizing_kraus(
        noise_param["two_qubit_channel_depolarizing_p"]
    )
    bit_kraus = bit_flip_kraus(noise_param["bit_flip_p"])

    # Fixed two-qubit gates; the first qubit of the CNOT matrix is the control.
    two_qubit_gates = {
        "cx": jnp.array(
            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]],
            dtype=dtype,
        ),
        "cz": jnp.diag(jnp.array([1, 1, 1, -1], dtype=dtype)),
    }

    param_idx = 0
    for gate_info in structure:
        name = gate_info[0]
        if name in ("rx", "ry", "rz"):
            q_idx = gate_info[1]
            U = _rotation_gate(name, params[param_idx])
            param_idx += 1
            # Gate followed by noise, fused on the 2x2 level:
            # sum_K (K U) rho (K U)^dagger.
            local_ops = [K @ U for K in single_kraus]
            rho = _apply_channel(
                rho, local_ops, lambda op, q=q_idx: embed_single(op, n_qubit, q)
            )
        elif name in two_qubit_gates:
            q1_idx, q2_idx = gate_info[1], gate_info[2]
            U = two_qubit_gates[name]
            local_ops = [K @ U for K in two_kraus]
            rho = _apply_channel(
                rho,
                local_ops,
                lambda op, a=q1_idx, b=q2_idx: embed_two(op, n_qubit, a, b),
            )
        else:
            raise ValueError(f"Unsupported gate {name}")

    # Global bit flip: applied independently to every qubit after all gates.
    for q in range(n_qubit):
        rho = _apply_channel(
            rho, bit_kraus, lambda op, q=q: embed_single(op, n_qubit, q)
        )

    # Tr(rho H) = sum_ij rho_ij * H_ji; no dense matrix product needed.
    return jnp.real(jnp.sum(rho * H.T))


def _rotation_gate(name, theta):
    """Return ``exp(-i * theta / 2 * P)`` with P = X, Y, Z for rx, ry, rz."""
    paulis = {
        "rx": [[0, 1], [1, 0]],
        "ry": [[0, -1j], [1j, 0]],
        "rz": [[1, 0], [0, -1]],
    }
    pauli = jnp.array(paulis[name], dtype=jnp.complex128)
    half = theta / 2.0
    identity = jnp.eye(2, dtype=jnp.complex128)
    return jnp.cos(half) * identity - 1j * jnp.sin(half) * pauli


def _apply_channel(rho, local_ops, embed):
    """Return ``sum_K E(K) rho E(K)^dagger`` where ``E = embed``."""
    out = jnp.zeros_like(rho)
    # One embedded operator is built at a time to limit peak memory.
    for op in local_ops:
        full_op = embed(op)
        out = out + full_op @ rho @ full_op.conj().T
    return out
# ======================
# Adam optimizer
# ======================
def adam_update(params, grads, m, v, t, lr=0.05, b1=0.9, b2=0.999, eps=1e-8):
    """Perform one Adam step and return the updated state.

    Args:
        params: Current parameter array.
        grads: Gradient of the objective with respect to ``params``.
        m: Running first-moment estimate (same shape as ``params``).
        v: Running second-moment estimate (same shape as ``params``).
        t: 1-based step counter used for bias correction; ``t = 0`` would
            divide by zero.
        lr: Learning rate.
        b1: Decay rate of the first-moment estimate.
        b2: Decay rate of the second-moment estimate.
        eps: Constant added to the denominator for numerical stability.

    Returns:
        A tuple ``(new_params, new_m, new_v)``.
    """
    m = b1 * m + (1 - b1) * grads
    v = b2 * v + (1 - b2) * (grads ** 2)
    # Bias correction compensates for the zero initialisation of m and v.
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    params = params - lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return params, m, v
# ======================
# Single-circuit training
# ======================
def train_single_circuit(sample, H, n_qubit, noise_param, max_iter=200, lr=0.05):
    """Optimise the rotation angles of one circuit with Adam.

    Args:
        sample: Circuit description. Either a dict with a ``"structure"``
            entry, or a list/tuple that is the gate list itself or whose
            first element is the gate list.
        H: Hamiltonian passed on to ``compute_energy``.
        n_qubit: Number of qubits.
        noise_param: Noise settings passed on to ``compute_energy``.
        max_iter: Number of Adam steps; must be at least 1.
        lr: Adam learning rate.

    Returns:
        A tuple ``(final_energy, energy_history)``. ``energy_history`` is a
        NumPy array holding the energy evaluated before each update, and
        ``final_energy`` is its last entry.

    Raises:
        TypeError: If ``sample`` is neither a dict nor a list/tuple.
        ValueError: If no gate list is found or ``max_iter < 1``.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1.")

    if isinstance(sample, dict):
        structure = sample.get("structure", None)
    elif isinstance(sample, (list, tuple)):
        # A nested list means the gate list is the first element; an empty
        # sample falls through to the "not found" error below.
        nested = (
            len(sample) > 0
            and isinstance(sample[0], list)
            and len(sample[0]) > 0
            and isinstance(sample[0][0], (list, tuple))
        )
        structure = sample[0] if nested else sample
    else:
        raise TypeError(f"Unknown sample type: {type(sample)}")
    if not structure:
        raise ValueError("Circuit structure not found in sample.")

    # Freeze the gate list so it cannot change between iterations.
    structure = tuple(tuple(gate) for gate in structure)
    n_param = sum(1 for g in structure if g[0] in ("rx", "ry", "rz"))
    params = jnp.array(np.random.normal(0, 0.1, n_param), dtype=jnp.float64)
    m = jnp.zeros_like(params)
    v = jnp.zeros_like(params)

    # Only the angles and the Hamiltonian are traced. The gate list, qubit
    # count and noise settings drive Python control flow and array shapes,
    # so they are captured by the closure instead of being passed through
    # static_argnums (which would require them to be hashable).
    def energy_fn(theta, hamiltonian):
        return compute_energy(
            theta, structure, hamiltonian, n_qubit, noise_param, sample
        )

    # Energy and gradient come out of a single compiled evaluation.
    energy_and_grad = jax.jit(jax.value_and_grad(energy_fn, argnums=0))

    energy_history = []
    for it in range(1, max_iter + 1):
        e, grads = energy_and_grad(params, H)
        params, m, v = adam_update(params, grads, m, v, it, lr)
        energy_history.append(float(e))
        if it % 20 == 0:
            print(f" Iter {it:3d} | Energy: {float(e):.6f}")
    return float(energy_history[-1]), np.array(energy_history)
# ======================
# Batch training
# ======================
def batch_train(device_name, task_name, run_id, n_qubit, noise_param, max_iter=200):
    """Train the first ten sampled circuits of a run and save the results.

    Samples are read from
    ``result/cir_sample/{device_name}_{task_name}/training/run_{run_id}/samples.pkl``.
    The energy curves and two plots are written to the ``quimb_results/``
    sub-directory of that run. A circuit whose training raises is reported
    and skipped.

    Args:
        device_name: Device part of the run directory name.
        task_name: Task part of the run directory name.
        run_id: Run number.
        n_qubit: Number of qubits.
        noise_param: Noise settings forwarded to ``train_single_circuit``.
        max_iter: Number of optimisation steps per circuit.

    Returns:
        None. If the sample file is missing or cannot be loaded, a message
        is printed and nothing is written.
    """
    base_dir = f"result/cir_sample/{device_name}_{task_name}/training/run_{run_id}/"
    sample_file_path = os.path.join(base_dir, "samples.pkl")
    if not os.path.exists(sample_file_path):
        print(f"错误:未找到样本文件: {sample_file_path}")
        # The hint names the run that was actually requested.
        print(
            "请检查 `result/cir_sample/` 路径下是否存在 "
            f"`{device_name}_{task_name}/training/run_{run_id}/samples.pkl` 文件。"
        )
        return

    # NOTE: unpickling can execute arbitrary code; only load sample files
    # that come from a trusted source.
    try:
        with open(sample_file_path, "rb") as f:
            samples = pickle.load(f)[:10]
        print(f"成功加载 {len(samples)} 个样本。")
    except Exception as e:
        print(f"加载样本文件失败: {e}")
        return

    H = build_heisenberg_H(n_qubit)
    result_dir = base_dir + "quimb_results/"
    os.makedirs(result_dir, exist_ok=True)

    # (circuit index, energy curve) for every circuit that trained.
    trained = []
    total_start = time.time()
    for i, sample in enumerate(samples):
        print(f"\nTraining circuit {i} ...")
        try:
            _, energy_curve = train_single_circuit(
                sample, H, n_qubit, noise_param, max_iter
            )
        except Exception as e:
            # Best effort: one failing circuit must not abort the batch.
            print(f"❌ Circuit {i} failed: {e}")
            continue
        trained.append((i, energy_curve))

    if trained:
        np.save(result_dir + "energy_curves.npy", [curve for _, curve in trained])

        plot_path = result_dir + "vqe_quimb_energy.png"
        _save_energy_plot(
            trained,
            f"VQE Energy Convergence for {n_qubit} Qubits ({device_name}_{task_name})",
            plot_path,
        )
        print(f"\n训练结果图已保存至: {plot_path}")

        combined_path = result_dir + "vqe_quimb_energy_combined.png"
        _save_energy_plot(
            trained,
            f"VQE Energy Convergence for {n_qubit} Qubits ({device_name}_{task_name}) - All Circuits",
            combined_path,
            figsize=(10, 6),
        )
        print(f"\n组合能量曲线图已保存至: {combined_path}")
    else:
        print("\n没有电路成功训练,未生成能量曲线图。")
    print(f"\n总GPU训练时间: {time.time() - total_start:.2f} s")


def _save_energy_plot(indexed_curves, title, path, figsize=None):
    """Plot the ``(circuit index, curve)`` pairs and save the figure."""
    fig = plt.figure(figsize=figsize)
    for index, curve in indexed_curves:
        # Curves are labelled with the original circuit index, so labels
        # stay correct when some circuits failed.
        plt.plot(curve, label=f"Circuit {index}")
    plt.xlabel("Iteration")
    plt.ylabel("Energy")
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
# ======================
# Main
# ======================
if __name__ == "__main__":
    # Run selection: samples are read from
    # result/cir_sample/{device_name}_{task_name}/training/run_{run_id}/.
    device_name = "grid_16q"
    task_name = "Heisenberg_12"
    n_qubit = 12
    run_id = 0
    # Error probabilities of the three noise channels used by the simulation.
    noise_param = {
        "two_qubit_channel_depolarizing_p": 0.01,
        "single_qubit_depolarizing_p": 0.001,
        "bit_flip_p": 0.01
    }
    batch_train(device_name, task_name, run_id, n_qubit, noise_param, max_iter=300)
# 逐字逐句分析这段代码,挑错。
# 最新发布