import numpy as np
import bisect
import matplotlib.pyplot as plt
def silu(x):
    """Return SiLU (swish) of ``x``: ``x * sigmoid(x) = x / (1 + exp(-x))``.

    Accepts Python/NumPy scalars and NumPy arrays (evaluated element-wise).
    """
    denominator = 1 + np.exp(-x)
    return x / denominator
# Piecewise-approximation grid.
# NOTE: despite its name, N_SEGMENTS counts breakpoints; the table has
# N_SEGMENTS - 1 equal-width segments (1024 here, up from 512 previously).
N_SEGMENTS = 1025
# Approximation domain (widened from the earlier [-4, 4]).
X_MIN = -8.0
X_MAX = 8.0
# Width of a single segment.
DX = (X_MAX - X_MIN) / (N_SEGMENTS - 1)
class SiluLUT:
    """Piecewise-quadratic lookup-table (LUT) approximation of SiLU.

    The domain [X_MIN, X_MAX] is split into ``N_SEGMENTS - 1`` equal segments
    of width ``DX``. Segment ``i`` stores ``[a, b, c]`` such that
    ``silu(x) ~= a * x**2 + b * x + c`` for x in that segment. The coefficients
    are expressed in absolute ``x``, not in segment-local coordinates.

    Args:
        is_sys: When True, only segments on x >= 0 are fitted directly. The
            others are derived through the identity silu(x) = silu(-x) + x.
            NOTE(review): the original code called an undefined
            ``calculate_symmetric_coeffs``; "symmetric" is the assumed intent.
        method: How each segment's quadratic is obtained.
            ``"lstsq"`` (default) uses Gaussian-weighted least squares over 100
            samples. ``"taylor"`` uses the second-order Taylor expansion of
            SiLU about the segment midpoint.

    Raises:
        ValueError: If ``method`` is not a supported name.
    """

    def __init__(self, is_sys=False, *, method="lstsq"):
        self.x_min = X_MIN
        self.x_max = X_MAX
        self.dx = DX
        self.method = method
        if is_sys:
            self.coeffs, self.seg = self.calculate_symmetric_coeffs(method)
        else:
            self.coeffs, self.seg = self.calculate_coefficients(method)

    def func_silu(self, a, b, c, x):
        """Evaluate the quadratic ``a*x**2 + b*x + c`` (element-wise for arrays)."""
        return a * x**2 + b * x + c

    def __call__(self, x):
        """Evaluate the LUT approximation at ``x`` (scalar or array-like).

        Inside [x_min, x_max] the quadratic of the containing segment is used;
        x == x_max belongs to the last segment. Outside the table, SiLU's
        asymptotes are returned: ``x`` for x > x_max and ``0`` for x < x_min.
        Previously x was clamped, so e.g. lut(100.0) returned ~8 instead of
        ~100. NaN inputs yield NaN.

        Returns:
            A NumPy float scalar for scalar input, otherwise an ndarray with the
            same shape as ``x``.
        """
        x_arr = np.asarray(x, dtype=float)
        x_in = np.clip(x_arr, self.x_min, self.x_max)
        n_seg = len(self.coeffs)
        # Segment index of each point. NaN is sent to segment 0 only so that
        # the indexing below is valid; the polynomial at NaN is still NaN.
        pos = np.nan_to_num((x_in - self.x_min) / self.dx, nan=0.0)
        idx = np.clip(np.floor(pos), 0, n_seg - 1).astype(np.intp)
        rows = self.coeffs[idx]  # shape: x.shape + (3,)
        y = self.func_silu(rows[..., 0], rows[..., 1], rows[..., 2], x_in)
        # Beyond the table, follow SiLU's asymptotes instead of a clamped value.
        y = np.where(x_arr > self.x_max, x_arr, y)
        y = np.where(x_arr < self.x_min, 0.0, y)
        return y[()] if y.ndim == 0 else y

    def calculate_coefficients(self, method="lstsq"):
        """Fit every segment of the grid independently.

        Args:
            method: ``"lstsq"`` or ``"taylor"`` (see the class docstring).

        Returns:
            ``(coefficients, seg)``: an ``(N_SEGMENTS - 1, 3)`` array of
            ``[a, b, c]`` rows and the matching list of ``[start, end]`` bounds.
        """
        fit = self._segment_fitter(method)
        n_seg = N_SEGMENTS - 1
        coefficients = np.zeros((n_seg, 3))
        seg = []
        for i in range(n_seg):
            start = X_MIN + i * DX
            end = start + DX
            seg.append([start, end])
            coefficients[i] = fit(start, end)
        return coefficients, seg

    def calculate_symmetric_coeffs(self, method="lstsq"):
        """Build the table while fitting SiLU only on x >= 0.

        Uses the identity silu(x) - silu(-x) = x. A segment whose midpoint is
        negative is fitted on its mirror image [-end, -start] as
        p(t) = a*t**2 + b*t + c ~= silu(t). Then, for x in the segment,
        silu(x) = silu(-x) + x ~= a*x**2 + (1 - b)*x + c.

        Returns the same ``(coefficients, seg)`` pair as
        ``calculate_coefficients``.
        """
        fit = self._segment_fitter(method)
        n_seg = N_SEGMENTS - 1
        coefficients = np.zeros((n_seg, 3))
        seg = []
        for i in range(n_seg):
            start = X_MIN + i * DX
            end = start + DX
            seg.append([start, end])
            if start + end >= 0:  # midpoint is non-negative: fit directly
                coefficients[i] = fit(start, end)
            else:
                a, b, c = fit(-end, -start)
                coefficients[i] = (a, 1.0 - b, c)
        return coefficients, seg

    def _segment_fitter(self, method):
        """Return the per-segment fitting function for ``method``."""
        fitters = {"lstsq": self._fit_lstsq, "taylor": self._fit_taylor}
        try:
            return fitters[method]
        except KeyError:
            raise ValueError(
                f"unknown method {method!r}; expected 'lstsq' or 'taylor'"
            ) from None

    @staticmethod
    def _fit_lstsq(start, end):
        """Fit a*x**2 + b*x + c to silu on [start, end] by weighted least squares."""
        x_samples = np.linspace(start, end, 100)
        y_samples = silu(x_samples)
        # Gaussian weights centred on the segment midpoint (sigma = DX/4).
        # Rows are scaled by w, so squared residuals are effectively weighted
        # by w**2.
        weights = np.exp(-0.5 * ((x_samples - (start + end) / 2) / (DX / 4)) ** 2)
        A = np.vstack([x_samples**2, x_samples, np.ones_like(x_samples)]).T
        coef, _, _, _ = np.linalg.lstsq(
            A * weights[:, np.newaxis], y_samples * weights, rcond=None
        )
        return coef

    @staticmethod
    def _fit_taylor(start, end):
        """Second-order Taylor expansion of silu about the midpoint of [start, end].

        With s = sigmoid(m):
            silu(m)   = m * s
            silu'(m)  = s * (1 + m * (1 - s))
            silu''(m) = s * (1 - s) * (2 + m * (1 - 2 * s))
        The expansion f0 + f1*(x - m) + f2/2*(x - m)**2 is re-expressed as
        a*x**2 + b*x + c. The truncation error is bounded by
        max|silu'''| / 6 * (DX / 2)**3.
        """
        m = (start + end) / 2
        s = 1.0 / (1.0 + np.exp(-m))
        f0 = m * s
        f1 = s * (1.0 + m * (1.0 - s))
        f2 = s * (1.0 - s) * (2.0 + m * (1.0 - 2.0 * s))
        a = 0.5 * f2
        b = f1 - f2 * m
        c = f0 - f1 * m + a * m * m
        return np.array([a, b, c])
def _predict_all(lut, xs):
    """Evaluate ``lut`` on every element of ``xs``, vectorized when lut supports arrays."""
    try:
        ys = np.asarray(lut(xs), dtype=float)
    except (TypeError, ValueError):
        # Scalar-only callables (e.g. ones using int()/math on the input) land here.
        ys = None
    if ys is None or ys.shape != xs.shape:
        ys = np.array([lut(x) for x in xs], dtype=float)
    return ys


def _save_error_plot(x, abs_error, threshold, path):
    """Save a log-scale scatter plot of the absolute error with a threshold line."""
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.scatter(x, abs_error, s=1, alpha=0.5)
        ax.axhline(threshold, color='r', linestyle='--', label='Threshold')
        ax.set_title('Absolute Error Distribution')
        ax.set_xlabel('Input Value')
        ax.set_ylabel('Absolute Error')
        ax.set_yscale('log')  # log scale makes small errors readable
        ax.grid(True)
        ax.legend()
        fig.savefig(path)
    finally:
        plt.close(fig)


def _save_comparison_plot(x, y_true, y_pred, path):
    """Save an overlay plot of the true function and its approximation."""
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(x, y_true, label='True Silu', linewidth=0.5)
        ax.plot(x, y_pred, label='LUT Approximation', linewidth=0.5)
        ax.set_title('Silu Function vs. LUT Approximation')
        ax.set_xlabel('Input')
        ax.set_ylabel('Output')
        ax.legend()
        ax.grid(True)
        fig.savefig(path)
    finally:
        plt.close(fig)


def evaluate_accuracy(lut, n_samples=100000, *, threshold=1e-5, plot=True):
    """Measure how well ``lut`` approximates SiLU on [X_MIN, X_MAX].

    ``lut`` is evaluated on ``n_samples`` evenly spaced points and compared
    with ``silu``. When ``plot`` is true, two figures are written to the current
    working directory: 'error_distribution_optimized.png' (absolute error, log
    scale) and 'function_comparison_optimized.png' (SiLU vs. approximation).

    Args:
        lut: Callable approximation. It is called once on the whole sample
            array if it supports that, otherwise once per sample.
        n_samples: Number of test points; must be >= 1.
        threshold: Absolute-error limit used for ``over_threshold`` and drawn
            on the error plot. The default of 1e-5 is tightened from an
            earlier 1e-4.
        plot: Whether to write the two figures.

    Returns:
        ``(max_error, rmse, over_threshold)``, where ``over_threshold`` is the
        fraction (0..1) of samples whose absolute error exceeds ``threshold``.

    Raises:
        ValueError: If ``n_samples`` < 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    x_test = np.linspace(X_MIN, X_MAX, n_samples)
    y_true = silu(x_test)
    y_pred = _predict_all(lut, x_test)

    abs_error = np.abs(y_true - y_pred)
    max_error = np.max(abs_error)
    rmse = np.sqrt(np.mean(abs_error**2))
    over_threshold = np.sum(abs_error > threshold) / n_samples

    if plot:
        _save_error_plot(x_test, abs_error, threshold, 'error_distribution_optimized.png')
        _save_comparison_plot(x_test, y_true, y_pred, 'function_comparison_optimized.png')
    return max_error, rmse, over_threshold
if __name__ == "__main__":
    # Build the default LUT, write the diagnostic plots and report its accuracy.
    # Guarded so that importing this module (e.g. to reuse SiluLUT) does not
    # run the 100k-point evaluation or write PNG files as a side effect.
    optimized_lut = SiluLUT()
    max_err, rmse, over_th = evaluate_accuracy(optimized_lut)
    print(f"优化实现: MaxError={max_err:.2e}, RMSE={rmse:.2e}, >1e-5比例={over_th*100:.4f}%")
# 我就想用多项式泰勒级数展开，能否帮我优化好？