Quadratic Form
题意
$X=(x_1, x_2, \dots, x_n)^T$, $A$ 为 $n \times n$ 的正定二次型矩阵, $b$ 为 $n \times 1$ 的列向量,
求在满足 $X^TAX \leq 1$ 的条件下, $\left(X^Tb\right)^2$ 的最大值
题解
带有不等式约束条件解极值问题, 使用拉格朗日乘子法
设拉格朗日函数 (约束关于 $X \to -X$ 对称, 因此最大化 $\left(X^Tb\right)^2$ 等价于先最大化 $X^Tb$ 再平方)
$$L\left(X, \lambda \right) = X^Tb + \lambda \left( X^TAX - 1 \right)$$
由KKT条件有
$$\begin{cases} \frac{\partial L}{\partial X} = b + 2 \lambda AX=0 \\ \lambda (X^TAX - 1) = 0 \\ X^TAX - 1 \leq 0 \\ \lambda \leq 0 \end{cases} \Rightarrow \begin{cases} b + 2 \lambda AX=0 \\ X^TAX - 1 =0 \\ \lambda \leq 0 \end{cases}$$
其中对于 $\lambda (X^TAX - 1) = 0$: 若 $\lambda=0$ 则 $b=0$ (此时答案为 $0$), 因此 $b \neq 0$ 时令 $X^TAX - 1=0$
由 $b + 2 \lambda AX=0$, 得 $X = -\frac{1}{2 \lambda} A ^ {-1} b$
则 (注意 $A$ 对称, 故 $\left(A^{-1}\right)^T = A^{-1}$)
$$\begin{array}{lcl} X^TAX & = & 1 \\ \left(-\frac{1}{2 \lambda} A ^ {-1} b\right) ^ TA\left(-\frac{1}{2 \lambda} A ^ {-1} b\right) & = & 1 \\ \frac{1}{4 \lambda ^ 2}b^TA^{-1}AA^{-1}b & = & 1 \\ b^TA^{-1}b & = & 4\lambda^2 \end{array}$$
又 (由 $b + 2 \lambda AX=0$ 得 $AX = -\frac{1}{2\lambda}b$)
$$\begin{array}{lcl} X^TAX & = & 1 \\ X^T\left(-\frac{1}{2\lambda}b\right) & = & 1 \\ X^T b & = & -2\lambda \end{array}$$
最终有
$$\left(X^Tb\right)^2=4\lambda^2=b^TA^{-1}b$$
代码
#include<bits/stdc++.h>
using namespace std;
// 64-bit signed integer type used for the modular arithmetic in this file.
typedef long long ll;
// Capacity of the matrix buffers (2e2 + 10 == 210 rows); the augmented matrix
// is declared with MAX << 1 columns.
const int MAX = 2e2 + 10;
// Modulus under which the matrix entries and the printed answer are computed.
const ll mod = 998244353;
// Computes a^b modulo `mod` by binary (square-and-multiply) exponentiation.
//   a: base; any signed 64-bit value, reduced into [0, mod) before use.
//   b: exponent; any value <= 0 yields 1.
// Returns a value congruent to a^b modulo `mod`, in the range [0, mod).
ll qpow(ll a, ll b) {
    // Reduce the base up front so every product below is smaller than mod^2,
    // which fits in a signed 64-bit integer; an unreduced large or negative
    // base could overflow in `res * a` or `a * a`.
    a %= mod;
    if (a < 0) a += mod;

    ll res = 1;
    // Test `b > 0` instead of `b`: a negative exponent stays negative under
    // an arithmetic right shift and would otherwise never end the loop.
    while (b > 0) {
        if (b & 1) res = res * a % mod;  // current bit set: multiply it in
        a = a * a % mod;                 // square for the next bit
        b >>= 1;
    }
    return res;
}
// Gauss-Jordan elimination modulo `mod` on an n x 2n augmented matrix, in place.
//   a: the matrix; columns 0..n-1 hold the matrix to invert. When the caller
//      has stored the identity in columns n..2n-1, those columns hold the
//      inverse on success and columns 0..n-1 become the identity.
//   n: dimension of the square matrix.
// Returns false if some column has no non-zero pivot modulo `mod` (the matrix
// is singular modulo `mod`); the contents of `a` are then partially reduced.
// Pivots are inverted with qpow(x, mod - 2), which relies on `mod` being prime.
bool Gauss(ll a[][MAX << 1], int n) {
    const int width = n << 1;

    // Bring every entry into [0, mod). Comparing raw signed values could
    // prefer a 0 over a negative (but non-zero) residue when choosing a pivot
    // and wrongly report the matrix as singular.
    for (int r = 0; r < n; r++) {
        for (int c = 0; c < width; c++) {
            a[r][c] %= mod;
            if (a[r][c] < 0) a[r][c] += mod;
        }
    }

    for (int i = 0; i < n; i++) {
        // Any non-zero residue is a valid pivot in modular arithmetic, so the
        // first one at or below the diagonal is enough.
        int r = i;
        while (r < n && a[r][i] == 0) r++;
        if (r == n) return false;  // no pivot in this column: not invertible
        if (r != i) std::swap(a[i], a[r]);

        const ll inv = qpow(a[i][i], mod - 2);

        // Clear column i in every other row. Columns before i of the pivot row
        // are already zero, so the update can start at column i.
        for (int k = 0; k < n; k++) {
            if (k == i) continue;
            const ll t = a[k][i] * inv % mod;
            for (int j = i; j < width; j++)
                a[k][j] = (a[k][j] - t * a[i][j] % mod + mod) % mod;
        }

        // Scale the pivot row so that the pivot itself becomes 1.
        for (int j = 0; j < width; j++) a[i][j] = a[i][j] * inv % mod;
    }
    return true;
}
// Matrix dimension of the current test case; read from input in main().
int n;
// a: augmented matrix buffer; main() stores A in columns 0..n-1 and the
//    identity in columns n..2n-1 before passing it to Gauss().
// b: the n entries of the column vector b.
ll a[MAX][MAX << 1], b[MAX];
// Reads test cases until input ends. Each case is: n, then the n x n matrix A
// in row-major order, then the n entries of b. Prints b^T A^{-1} b modulo
// `mod` for each case, one value per line.
int main() {
    // Require exactly one successful conversion: `~scanf(...)` only stops on
    // EOF and would spin forever if the next token is not an integer.
    while (scanf("%d", &n) == 1) {
        memset(a, 0, sizeof(a));
        for (int i = 0; i < n; i++) {
            a[i][i + n] = 1;  // right half of the augmented matrix: identity
            for (int j = 0; j < n; j++) {
                scanf("%lld", &a[i][j]);
                // Store the canonical residue so negative input cannot
                // confuse comparisons made during elimination.
                a[i][j] = (a[i][j] % mod + mod) % mod;
            }
        }

        // NOTE(review): the return value is ignored; presumably the judge
        // guarantees A is invertible modulo `mod` — confirm.
        Gauss(a, n);

        for (int i = 0; i < n; i++) {
            scanf("%lld", &b[i]);
            b[i] = (b[i] % mod + mod) % mod;
        }

        // Accumulate b^T A^{-1} b; after Gauss() the inverse sits in columns
        // n..2n-1 of `a`.
        ll ans = 0;
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                ans = (ans + a[i][j + n] * b[i] % mod * b[j] % mod) % mod;

        printf("%lld\n", (ans + mod) % mod);
    }
    return 0;
}