插值算法
interpolation.cpp
#include <utility>
#include <stdexcept>
#include "interpolation.h"
// Pi at full double precision (shortest literal that round-trips to the same
// double); used by the sinc kernel in SineInterp::interp.
constexpr double C_PI = 3.141592653589793;
/// Treats a double as zero when its magnitude is strictly below 1e-12.
static inline bool fuzzyIsNull(double d)
{
    const double kTolerance = 1e-12;
    return std::fabs(d) < kTolerance;
}
/// float overload: treats a float as zero when its magnitude is at most 1e-5
/// (note the inclusive bound, unlike the double overload).
static inline bool fuzzyIsNull(float f)
{
    const float magnitude = std::fabs(f);
    return magnitude <= 1e-5f;
}
/// Sine ("sinc", Whittaker-Shannon) interpolation of uniformly spaced samples.
///
/// Each inserted value is the sum over all input samples of
/// data[n] * sin(pi * t) / (pi * t), where t is the signed distance, in sample
/// units, from the output position to sample n. The original samples are
/// copied through unchanged. Cost is O(data.size()^2 * interpNum).
///
/// @param data       uniformly spaced input samples (at least three).
/// @param interpNum  number of values inserted between each pair of samples.
/// @return vector of (data.size() - 1) * (interpNum + 1) + 1 values.
/// @throws std::invalid_argument if data has fewer than three samples.
Interpolation::Vector<Interpolation::data_type> SineInterp::interp(const Vector<data_type> &data, unsigned int interpNum) noexcept(false)
{
    // Following the original author's note: reproducing a sinusoid needs at
    // least three points per period; here this is enforced as a minimum total
    // number of samples.
    if(data.size() < 3)
    {
        throw std::invalid_argument("[SineInterp::interp] data.size() < 3;");
    }
    const size_type sz = data.size();
    // Output samples per input interval. Computed in size_type so that
    // interpNum + 1 cannot overflow.
    const size_type M = static_cast<size_type>(interpNum) + 1;
    Vector<data_type> ret((sz - 1) * M + 1);
    // Original samples occupy every M-th output slot.
    for(size_type i = 0; i < sz; ++i)
    {
        ret[M * i] = data[i];
    }
    for(size_type k = 0; k + 1 < sz; ++k)
    {
        for(size_type m = 1; m < M; ++m)
        {
            // Fractional offset of this output sample inside interval [k, k+1].
            // It lies strictly between 0 and 1, so x below is never zero.
            const data_type frac = static_cast<data_type>(m) / static_cast<data_type>(M);
            data_type acc = 0;
            for(size_type n = 0; n < sz; ++n)
            {
                // Convert before subtracting: k and n are unsigned.
                const data_type x = (static_cast<data_type>(k) - static_cast<data_type>(n) + frac) * C_PI;
                acc += data[n] * (std::sin(x) / x);
            }
            ret[k * M + m] = acc;
        }
    }
    // Return by value; std::move here would only block copy elision.
    return ret;
}
/// Convenience overload for uniformly sampled data when the caller does not
/// need the x coordinates.
///
/// Synthesises the abscissae x = 0, 10, 20, ... (one per sample), forwards to
/// the (x, y) overload and returns only the interpolated y values.
/// Note: the spacing is not entirely irrelevant. The end-point slope heuristic
/// of the (x, y) overload compares a slope against +/-1, and slopes scale with
/// 1/spacing, so the fixed spacing of 10 is part of this overload's behaviour.
Interpolation::Vector<Interpolation::data_type> CubicHermiteInterp::interp(const Vector<data_type> &data, unsigned int interNum) noexcept(false)
{
    Vector<data_type> abscissa;
    abscissa.reserve(data.size());
    for(size_type idx = 0; idx != data.size(); ++idx)
    {
        abscissa.push_back(static_cast<data_type>(idx * 10));
    }
    auto curve = interp(abscissa, data, interNum);
    return std::move(curve.second);
}
std::pair<Interpolation::Vector<Interpolation::data_type>, Interpolation::Vector<Interpolation::data_type> >
CubicHermiteInterp::interp(const Vector<data_type> &x,
const Vector<data_type> &y,
unsigned int interpNum) noexcept(false)
{
if(x.size() != y.size())
{
throw std::invalid_argument("[CubicHermiteInterp::interp] x.size() != y.size()");
}
const size_type sz = x.size();
// 判断x值是否递增
data_type deltaX = 0;
for(size_type i = 0; i < sz - 1; ++i)
{
deltaX = x[i+1] - x[i];
if(deltaX <= 0)
{
throw std::invalid_argument("[CubicHermiteInterp::interp] x vector not monotone increasing;");
}
}
// 求取中间每个点处的斜率
Vector<data_type> derivative(sz);
for(size_type i = 1; i < sz - 1; ++i)
{
if((y[i] > y[i+1] && y[i] > y[i-1])
|| (y[i] < y[i+1] && y[i] < y[i-1]))
{
derivative[i] = 0;
continue;
}
double deltaY1 = y[i] - y[i - 1];
double deltaY2 = y[i+1] - y[i];
double prevRatio = deltaY1 / (x[i] - x[i - 1]);
double backRatio = deltaY2 / (x[i + 1] - x[i]);
double deltaY = deltaY1 + deltaY2;
// 非极值点,导数应该起到让通过该点的曲线能够单调的平滑过渡到极值点,这里存在优化空间
if(fuzzyIsNull(deltaY1) || fuzzyIsNull(deltaY2))
{
derivative[i] = 0;
}
else
{
derivative[i] = prevRatio * (deltaY2 / deltaY) + backRatio * (deltaY1 / deltaY);
}
}
// 求第一个点处的斜率
double ratio0 = (y[0] - y[1]) / (x[0] - x[1]);
if(fuzzyIsNull(derivative[1]))
{
derivative[0] = ratio0 * 2; // 后一个点是极值点,曲线在第一段应该更陡
}
else
{
derivative[0] = ratio0 / 2; // 后一个点不是极值点,曲线在第一段应该趋于平缓
}
// 求取最后一个点处的斜率
const size_type iLast = sz - 1;
double ratioN = (y[iLast] - y[iLast - 1]) / (x[iLast] - x[iLast - 1]);
if(derivative[iLast-1] < -1 || derivative[iLast-1] > 1)
{
derivative[iLast] = ratioN / 2; // 倒数第二个点处曲线较平滑,最后一段也应该趋于平缓
}
else
{
derivative[iLast] = ratioN * 2; // 倒数第二个点处曲线较陡,最后一段也应该保持这个趋势
}
const size_type n = interpNum;
const size_type M = interpNum + 1;
const size_type retSz = (x.size() - 1) * interpNum + x.size();
Vector<data_type> retY(retSz);
Vector<data_type> retX(retSz);
for(size_type k = 0; k < sz - 1; ++k)
{
retY[k * M] = y[k];
retX[k * M] = x[k];
for(size_type j = 1; j <= n; ++j)
{
auto currX = x[k] + (x[k+1] - x[k]) * j / (n + 1);
auto div1 = (currX - x[k+1]) / (x[k] - x[k+1]);
auto div2 = (currX - x[k]) / (x[k+1] - x[k]);
auto currY = div1 * div1 * (1 + 2 * div2) * y[k]
+ div2 * div2 * (1 + 2 * div1) * y[k+1]
+ div1 * div1 * (currX - x[k]) * derivative[k]
+ div2 * div2 * (currX - x[k+1]) * derivative[k+1];
retY[k * M + j] = currY;
retX[k * M + j] = currX;
}
}
retY.back() = y.back();
return std::move(std::make_pair(std::move(retX), std::move(retY)));
}
/// Convenience overload for uniformly sampled data when the caller does not
/// need the x coordinates.
///
/// Synthesises the abscissae x = 0, 10, 20, ... (one per sample), forwards to
/// the (x, y) overload and returns only the interpolated y values. The result
/// is empty when that overload fails (e.g. three or fewer samples... see the
/// (x, y) overload: fewer than three samples yield an empty result).
Interpolation::Vector<Interpolation::data_type> SplineInterp::interp(const Vector<data_type> &data,
                                                                     unsigned int interpNum) noexcept(false)
{
    Vector<data_type> abscissa;
    abscissa.reserve(data.size());
    for(size_type idx = 0; idx != data.size(); ++idx)
    {
        abscissa.push_back(static_cast<data_type>(idx * 10));
    }
    auto curve = interp(abscissa, data, interpNum);
    return std::move(curve.second);
}
std::pair<Interpolation::Vector<Interpolation::data_type>, Interpolation::Vector<Interpolation::data_type>>
SplineInterp::interp(const Vector<data_type> &x,
const Vector<data_type> &y,
unsigned int interpNum) noexcept(false)
{
if(!setDatas(x, y))
{
return std::make_pair(Vector<data_type>(), Vector<data_type>());
}
const size_type sz = (x.size() - 1) * interpNum + x.size();
const size_type M = interpNum + 1;
Vector<data_type> xDatas(sz);
Vector<data_type> yDatas(sz);
for(size_type i = 0; i < x.size() - 1; ++i)
{
xDatas[M*i] = x[i];
yDatas[M*i] = y[i];
data_type delta = (x[i+1] - x[i]) / M;
data_type xVal = 0;
for(size_type j = 1; j <= interpNum; ++j)
{
xVal = x[i] + delta * j;
xDatas[M*i+j] = xVal;
yDatas[M*i+j] = value(xVal);
}
}
xDatas.back() = x.back();
yDatas.back() = y.back();
return std::move(std::make_pair(std::move(xDatas), std::move(yDatas)));
}
bool SplineInterp::setDatas(const Vector<data_type> &x, const Vector<data_type> &y) noexcept(false)
{
if(x.size() != y.size())
{
throw std::invalid_argument("[SplineInterp::setDatas] x.size() != y.size();");
}
if(x.size() <= 2)
{
reset();
return false;
}
const size_type sz = x.size();
m_x = x;
m_y = y;
m_a.resize(sz - 1);
m_b.resize(sz - 1);
m_c.resize(sz - 1);
bool ok = buildSpline(x, y);
if(!ok)
{
reset();
}
return ok;
}
/// Evaluates the spline at xVal. Outside [x0, xN] the first or last cubic is
/// extended (lookup clamps the segment index). Returns 0.0 if no spline has
/// been built.
Interpolation::data_type SplineInterp::value(data_type xVal) const
{
    if(m_a.empty())
    {
        return 0.0;
    }
    const size_type seg = lookup(xVal);
    const data_type t = xVal - m_x[seg];
    // Horner form of y + c*t + b*t^2 + a*t^3.
    data_type acc = m_a[seg] * t + m_b[seg];
    acc = acc * t + m_c[seg];
    return acc * t + m_y[seg];
}
/// True once coefficients have been built; setDatas clears them on failure.
bool SplineInterp::isValid() const
{
    return !m_a.empty();
}
/// Discards the spline: empties the control points and all coefficients.
void SplineInterp::reset()
{
    m_x.clear();
    m_y.clear();
    m_a.clear();
    m_b.clear();
    m_c.clear();
}
/// Builds natural cubic spline coefficients for the control points (x, y).
///
/// Solves a symmetric tridiagonal system for the second derivatives s[i] at
/// the interior points, with s[0] = s[sz-1] = 0 (natural end conditions). It
/// then fills m_a / m_b / m_c so that on segment i
///     S(t) = m_a[i]*t^3 + m_b[i]*t^2 + m_c[i]*t + y[i],   t = xVal - x[i].
///
/// Preconditions (ensured by setDatas): x.size() == y.size() >= 3, and
/// m_a, m_b, m_c are already resized to x.size() - 1.
/// @return false if x is not strictly increasing, true otherwise.
bool SplineInterp::buildSpline(const Vector<data_type> &x, const Vector<data_type> &y)
{
const size_type sz = x.size();
const data_type *px = x.data();
const data_type *py = y.data();
// Raw views of the coefficient storage. During the solve they double as
// scratch space for the matrix (a: diagonal, b/c: off-diagonals) and are
// overwritten with the final coefficients at the end.
data_type *a = m_a.data();
data_type *b = m_b.data();
data_type *c = m_c.data();
// Interval widths h[i] = x[i+1] - x[i]; reject x that is not strictly increasing.
Vector<data_type> h(sz - 1);
for(size_type i = 0; i < sz - 1; ++i)
{
h[i] = px[i+1] - px[i];
if(h[i] <= 0)
{
return false;
}
}
// Assemble rows 1..sz-2: diagonal a[i] = 2*(h[i-1] + h[i]) and off-diagonals
// b[i] = c[i] = h[i]. d[i] = 6*(left secant slope - right secant slope), which
// is the NEGATED usual right-hand side; the sign is restored in the backward
// elimination below. d[0] is unused.
Vector<data_type> d(sz - 1);
data_type dy1 = (py[1] - py[0]) / h[0];
for(size_type i = 1; i < sz - 1; ++i)
{
b[i] = c[i] = h[i];
a[i] = 2.0 * (h[i-1] + h[i]);
const data_type dy2 = (py[i+1] - py[i]) / h[i];
d[i] = 6.0 * (dy1 - dy2);
dy1 = dy2;
}
//
// Solve the tridiagonal system (LU decomposition without pivoting).
//
// L-U factorization: c[i] becomes the multiplier h[i] / a[i], and a[i+1]
// becomes the updated pivot.
for(size_type i = 1; i < sz - 2; ++i)
{
c[i] /= a[i];
a[i+1] -= b[i] * c[i];
}
// Forward elimination: solve L * s = d, storing the result in s.
Vector<data_type> s(sz);
s[1] = d[1];
for(size_type i = 2; i < sz - 1; ++i)
{
s[i] = d[i] - c[i-1] * s[i-1];
}
// Backward elimination: solve U * s = -s. The leading minus undoes the negated
// right-hand side. i runs from sz-3 down to 1; for sz == 3 the loop does not run.
s[sz - 2] = - s[sz - 2] / a[sz - 2];
for(size_type i = sz - 3; i > 0; --i)
{
s[i] = -(s[i] + b[i] * s[i+1]) / a[i];
}
// Natural end conditions: zero second derivative at both ends.
s[sz - 1] = s[0] = 0.0;
//
// Finally, determine the spline coefficients of each segment i
// (a: cubic, b: quadratic, c: linear term).
//
for(size_type i = 0; i < sz - 1; ++i)
{
a[i] = (s[i+1] - s[i]) / (6.0 * h[i]);
b[i] = 0.5 * s[i];
c[i] = (py[i+1] - py[i]) / h[i] - (s[i+1] + 2.0 * s[i]) * h[i] / 6.0;
}
return true;
}
/// Returns the index of the spline segment to use for xVal.
///
/// The result is the largest i with m_x[i] <= xVal, clamped to
/// [0, m_x.size() - 2]:
/// - values left of the first point map to segment 0;
/// - values at or beyond m_x[size - 2] map to the last segment.
/// Assumes m_x is strictly increasing (buildSpline enforces this).
Interpolation::size_type SplineInterp::lookup(data_type xVal) const
{
    const size_type sz = m_x.size();
    // With fewer than two control points there is no segment, and m_x[0] or
    // m_x[sz - 2] below would be out of range.
    if(sz < 2)
    {
        return 0;
    }
    if(xVal <= m_x[0])
    {
        return 0;
    }
    if(xVal >= m_x[sz - 2])
    {
        return sz - 2;
    }
    // Binary search maintaining m_x[lo] <= xVal < m_x[hi].
    size_type lo = 0;
    size_type hi = sz - 2;
    while(hi - lo > 1)
    {
        const size_type mid = lo + ((hi - lo) >> 1);
        if(m_x[mid] > xVal)
        {
            hi = mid;
        }
        else
        {
            lo = mid;
        }
    }
    return lo;
}
interpolation.h
#ifndef INTERPOLATION_H
#define INTERPOLATION_H
#include <cmath>
#include <utility>   // std::pair, used in the interp(x, y, ...) overloads
#include <vector>
// Require a C++11 compiler. MSVC reports __cplusplus as 199711L unless
// /Zc:__cplusplus is given, so MSVC is checked via _MSC_VER instead. This code
// uses noexcept, alias templates and defaulted functions; MSVC supports all of
// them only from Visual Studio 2015 (_MSC_VER 1900) onward.
#ifdef _MSC_FULL_VER
#if _MSC_VER < 1900
#error "should use C++11 implementation (Visual Studio 2015 or newer)"
#else
// Ask MSVC to encode narrow string literals as UTF-8.
#pragma execution_character_set("utf-8")
#endif
#elif __cplusplus < 201103L
#error "should use C++11 implementation"
#endif
/// Abstract base for 1-D interpolators that densify a sequence of samples.
class Interpolation
{
public:
    /// Scalar type of every sample.
    using data_type = double;
    /// Container alias used throughout the interface.
    template<class T> using Vector = std::vector<T>;
    /// Index and size type of Vector<data_type>.
    using size_type = Vector<data_type>::size_type;

    Interpolation() = default;
    virtual ~Interpolation() = default;

    /// Interpolates `data`.
    /// @param data      original samples to interpolate.
    /// @param interNum  number of values to insert between each pair of
    ///                  neighbouring samples.
    /// @return the densified sequence.
    virtual Vector<data_type> interp(const Vector<data_type> &data, unsigned int interNum) noexcept(false) = 0;
};
/// Sine interpolation: band-limited reconstruction with the sinc kernel
/// sin(x)/x over all input samples. Needs at least three samples.
class SineInterp : public Interpolation
{
public:
    /// @throws std::invalid_argument if data has fewer than three samples.
    Vector<data_type> interp(const Vector<data_type> &data, unsigned int interpNum) noexcept(false) override;
};
/// Piecewise cubic Hermite interpolation with heuristically chosen slopes.
class CubicHermiteInterp : public Interpolation
{
public:
    /// Use when x is uniformly spaced and the caller does not need it:
    /// synthesises x = 0, 10, 20, ... and returns only the interpolated y values.
    Vector<data_type> interp(const Vector<data_type> &data, unsigned int interNum) noexcept(false) override;

    /// Use when x is non-uniform, or when the x coordinates are needed too.
    /// @return pair whose first member holds the x coordinates and whose
    ///         second member holds the y coordinates.
    /// @throws std::invalid_argument if x and y differ in length or x is not
    ///         strictly increasing.
    std::pair<Vector<data_type>, Vector<data_type>>
    interp(const Vector<data_type> &x,
           const Vector<data_type> &y,
           unsigned int interpNum) noexcept(false);
};
/// Natural cubic spline interpolation.
class SplineInterp : public Interpolation
{
public:
    /// Use when x is uniformly spaced and the caller does not need it:
    /// synthesises x = 0, 10, 20, ... and returns only the interpolated y values.
    Vector<data_type> interp(const Vector<data_type> &data, unsigned int interpNum)noexcept(false) override;

    /// Use when x is non-uniform, or when the x coordinates are needed too.
    /// @return pair (x coordinates, y coordinates); both are empty when fewer
    ///         than three points are given or x is not strictly increasing.
    /// @throws std::invalid_argument if x and y differ in length.
    std::pair<Vector<data_type>, Vector<data_type>>
    interp(const Vector<data_type> &x,
           const Vector<data_type> &y,
           unsigned int interpNum) noexcept(false);

protected:
    /// Stores the control points and builds the coefficients; false on failure.
    bool setDatas(const Vector<data_type> &x, const Vector<data_type> &y) noexcept(false);
    /// Clears the coefficients and control points.
    void reset();
    /// True once a spline has been built.
    bool isValid() const;
    /// Evaluates the spline at xVal (0.0 when no spline has been built).
    data_type value(data_type xVal) const;
    /// Solves for the coefficients; false if x is not strictly increasing.
    bool buildSpline(const Vector<data_type> &x, const Vector<data_type> &y);
    /// Index of the segment to use for xVal, clamped to the valid range.
    size_type lookup(data_type xVal) const;

private:
    // Per-segment coefficients. On segment i,
    // y(t) = m_a[i]*t^3 + m_b[i]*t^2 + m_c[i]*t + m_y[i], with t = x - m_x[i].
    Vector<data_type> m_a;
    Vector<data_type> m_b;
    Vector<data_type> m_c;
    // Control points.
    Vector<data_type> m_x;
    Vector<data_type> m_y;
};
#endif // INTERPOLATION_H
main.cpp
#include <exception>
#include <iostream>
#include <vector>
#include "interpolation.h"

// Demo: interpolate the same nine samples with all three algorithms, inserting
// 10 points between each pair of neighbours, and print each result as
// tab-separated values preceded by its length.
int main()
{
    const std::vector<double> y = {23, 93, 33, 46, 80, 14, 96, 42, 16};
    const unsigned int interpNum = 10;

    auto print = [](const std::vector<double> &values)
    {
        for(double v : values)
        {
            std::cout << v << '\t';
        }
        std::cout << '\n';
    };

    try
    {
        SineInterp sine;
        CubicHermiteInterp cubic;
        SplineInterp spline;
        const auto y1 = sine.interp(y, interpNum);
        const auto y2 = cubic.interp(y, interpNum);
        const auto y3 = spline.interp(y, interpNum);

        std::cout << "sine interp data is: " << y1.size() << '\n';
        print(y1);
        std::cout << "cubic interp data is: " << y2.size() << '\n';
        print(y2);
        std::cout << "spline interp data is: " << y3.size() << '\n';
        print(y3);
    }
    catch(const std::exception &e)
    {
        // The interpolators throw std::invalid_argument on bad input; report
        // it instead of letting the exception terminate the program.
        std::cerr << "interpolation failed: " << e.what() << '\n';
        return 1;
    }
    return 0;
}
输出结果
fang@fang-virtual-machine:~/桌面/2Dinterpolation-master/2Dinterpolation-master$ g++ main.cpp interpolation.cpp -o main
fang@fang-virtual-machine:~/桌面/2Dinterpolation-master/2Dinterpolation-master$ ./main
sine interp data is: 89
23 29.0578 35.7788 43.0315 50.64 58.3885 66.0278 73.2861 79.8806 85.5322 89.9806 93 94.4136 94.1072 92.0397 88.2502 82.8613 76.0776 68.1796 59.5132 50.4743 41.4906 33 25.428 19.1647 14.5422 11.8158 11.147 12.5929 16.1004 21.5057 28.5419 36.8505 46 55.508 64.8675 73.574 81.1542 87.1923 91.3542 93.4068 93.2321 90.8347 86.3424 80 72.1559 63.2431 53.7549 44.2178 35.1611 27.0881 20.4473 15.6073 12.8371 12.2907 14 17.8734 23.7023 31.1739 39.89 49.3898 59.177 68.7472 77.6162 85.346 91.5678 96 98.4613 98.877 97.2795 93.8021 88.6671 82.1699 74.6581 66.51 58.1105 49.8293 42 34.9028 28.7516 23.6853 19.7649 16.9756 15.2329 14.3939 14.2712 14.6488 15.2991 16
cubic interp data is: 89
23 35.1488 46.1405 55.9752 64.6529 72.1736 78.5372 83.7438 87.7934 90.686 92.4215 93 91.6026 87.7708 82.0458 74.9684 67.0796 58.9204 51.0316 43.9542 38.2292 34.3974 33 33.1615 33.6243 34.356 35.3242 36.4964 37.8403 39.3235 40.9134 42.5778 44.2841 46 48.205 51.2524 54.9205 58.9876 63.2318 67.4315 71.3648 74.8101 77.5455 79.3494 80 78.4628 74.2479 67.9504 60.1653 51.4876 42.5124 33.8347 26.0496 19.7521 15.5372 14 15.9098 21.1465 28.9707 38.6431 49.4245 60.5755 71.3569 81.0293 88.8535 94.0902 96 95.006 92.2431 88.04 82.7252 76.6273 70.0751 63.3971 56.922 50.9784 45.8948 42 38.855 35.8135 32.8932 30.1115 27.4861 25.0346 22.7745 20.7234 18.8989 17.3186 16
spline interp data is: 89
23 32.8706 42.5659 51.9104 60.7289 68.846 76.0864 82.2747 87.2356 90.7936 92.7736 93 91.3753 88.1125 83.5022 77.8353 71.4025 64.4943 57.4017 50.4152 43.8256 37.9236 33 29.2849 26.7664 25.3721 25.0297 25.6666 27.2105 29.5889 32.7295 36.5598 41.0075 46 51.4378 57.1127 62.7891 68.2316 73.2047 77.473 80.8009 82.953 83.6939 82.788 80 75.2117 68.7745 61.1574 52.8291 44.2584 35.9142 28.2654 21.7807 16.9291 14.1792 14 16.6952 21.9079 29.1163 37.7985 47.4326 57.4967 67.4688 76.8271 85.0497 91.6146 96 97.8274 97.2918 94.7319 90.4862 84.8932 78.2914 71.0194 63.4157 55.8189 48.5675 42 36.3831 31.6956 27.8448 24.7375 22.2809 20.3821 18.9479 17.8856 17.1022 16.5046 16
绘制图形
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
# Input data: the three result vectors printed by main.cpp (89 samples each).
# y1: sine (sinc) interpolation
y1 = [23, 29.0578, 35.7788, 43.0315, 50.64, 58.3885, 66.0278, 73.2861, 79.8806, 85.5322, 89.9806, 93,
      94.4136, 94.1072, 92.0397, 88.2502, 82.8613, 76.0776, 68.1796, 59.5132, 50.4743, 41.4906,
      33, 25.428, 19.1647, 14.5422, 11.8158, 11.147, 12.5929, 16.1004, 21.5057, 28.5419, 36.8505, 46,
      55.508, 64.8675, 73.574, 81.1542, 87.1923, 91.3542, 93.4068, 93.2321, 90.8347, 86.3424,
      80, 72.1559, 63.2431, 53.7549, 44.2178, 35.1611, 27.0881, 20.4473, 15.6073, 12.8371, 12.2907, 14,
      17.8734, 23.7023, 31.1739, 39.89, 49.3898, 59.177, 68.7472, 77.6162, 85.346, 91.5678, 96,
      98.4613, 98.877, 97.2795, 93.8021, 88.6671, 82.1699, 74.6581, 66.51, 58.1105, 49.8293, 42,
      34.9028, 28.7516, 23.6853, 19.7649, 16.9756, 15.2329, 14.3939, 14.2712, 14.6488, 15.2991, 16]
# y2: cubic Hermite interpolation
y2 = [23, 35.1488, 46.1405, 55.9752, 64.6529, 72.1736, 78.5372, 83.7438, 87.7934, 90.686, 92.4215,
      93, 91.6026, 87.7708, 82.0458, 74.9684, 67.0796, 58.9204, 51.0316, 43.9542, 38.2292, 34.3974,
      33, 33.1615, 33.6243, 34.356, 35.3242, 36.4964, 37.8403, 39.3235, 40.9134, 42.5778, 44.2841,
      46, 48.205, 51.2524, 54.9205, 58.9876, 63.2318, 67.4315, 71.3648, 74.8101, 77.5455, 79.3494,
      80, 78.4628, 74.2479, 67.9504, 60.1653, 51.4876, 42.5124, 33.8347, 26.0496, 19.7521, 15.5372,
      14, 15.9098, 21.1465, 28.9707, 38.6431, 49.4245, 60.5755, 71.3569, 81.0293, 88.8535, 94.0902,
      96, 95.006, 92.2431, 88.04, 82.7252, 76.6273, 70.0751, 63.3971, 56.922, 50.9784, 45.8948, 42,
      38.855, 35.8135, 32.8932, 30.1115, 27.4861, 25.0346, 22.7745, 20.7234, 18.8989, 17.3186, 16]
# y3: cubic spline interpolation
y3 = [23, 32.8706, 42.5659, 51.9104, 60.7289, 68.846, 76.0864, 82.2747, 87.2356,
      90.7936, 92.7736, 93, 91.3753, 88.1125, 83.5022, 77.8353, 71.4025, 64.4943,
      57.4017, 50.4152, 43.8256, 37.9236, 33, 29.2849, 26.7664, 25.3721, 25.0297,
      25.6666, 27.2105, 29.5889, 32.7295, 36.5598, 41.0075, 46, 51.4378, 57.1127,
      62.7891, 68.2316, 73.2047, 77.473, 80.8009, 82.953, 83.6939, 82.788,
      80, 75.2117, 68.7745, 61.1574, 52.8291, 44.2584, 35.9142, 28.2654, 21.7807,
      16.9291, 14.1792, 14, 16.6952, 21.9079, 29.1163, 37.7985, 47.4326, 57.4967,
      67.4688, 76.8271, 85.0497, 91.6146, 96, 97.8274, 97.2918, 94.7319, 90.4862,
      84.8932, 78.2914, 71.0194, 63.4157, 55.8189, 48.5675, 42, 36.3831, 31.6956,
      27.8448, 24.7375, 22.2809, 20.3821, 18.9479, 17.8856, 17.1022, 16.5046, 16]
# x axis: 1-based sample index, derived from the data instead of typed out.
x = list(range(1, len(y1) + 1))
assert len(y2) == len(x) and len(y3) == len(x), "all series must have the same length"

# Use a font with CJK glyphs so that the Chinese title renders. SimHei is a
# sans-serif "Hei" face (not the Song-style SimSun). It must be installed;
# otherwise matplotlib falls back to another font and the title may show as
# boxes. On Linux, a face such as 'Noto Sans CJK SC' may be what is available.
font = FontProperties(family='SimHei')

# Line chart of the three interpolation results.
plt.plot(x, y1, marker='o', linestyle='-', color='blue', label='sine')
plt.plot(x, y2, marker='s', linestyle='--', color='green', label='cubic')
plt.plot(x, y3, marker='^', linestyle='-.', color='red', label='spline')

# Title ("comparison of the three algorithms") and optional axis labels.
plt.title('三种算法对比', fontproperties=font)
# plt.xlabel('sample index')
# plt.ylabel('value')

plt.legend()
plt.show()
结果图