对于每一行的 −1-1−1,显然只会填一种数字,因此可以得出一种比较朴素的 DP。设 fi,jf_{i,j}fi,j 表示考虑前 iii 行,第 iii 行的 −1-1−1 变为 jjj 时的最大值,再设 di,jd_{i,j}di,j 表示第 iii 行数字 jjj 的个数,则可列出转移方程:
fi,j=maxo=1k{fi−1,o+di−1,−1×di,o+di−1,−1×di,−1}+di−1,j×di,j+di−1,j×di,−1 f_{i,j} = \max \limits _{o = 1} ^ k\{f_{i - 1,o} + d_{i - 1,-1} \times d_{i,o} + d_{i - 1,-1} \times d_{i,-1}\} + d_{i - 1,j} \times d_{i,j} + d_{i - 1,j} \times d_{i,-1} fi,j=o=1maxk{fi−1,o+di−1,−1×di,o+di−1,−1×di,−1}+di−1,j×di,j+di−1,j×di,−1
由于 ddd 数组可以预处理获得,于是就得到了 O(nk)O(nk)O(nk) 的算法。
进一步观察这个方程,发现方程中与 jjj 无关的量可以直接用线段树区间加处理,与 jjj 相关的直接单点加(每一行最多只需要 O(m)O(m)O(m) 次)。于是可以用有区间加,单点加,单点取最大值,最后维护区间最大值的数据结构实现。这里用两个懒标记的线段树去实现,每一次标记下传的时候,加法的懒标记直接累加,而最大值的懒标记需要先作加法后再取最值。最后我们就能得到 O(nmlog(k))O(nm \log (k))O(nmlog(k)) 的算法:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <map>
#include <vector>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
using namespace std;
const int MAX = 6e5 + 5;
const int MOD = 1e9 + 7;
inline int read ();
map <ll,ll> mp[2];
ll tree[MAX << 2],tmp[MAX << 2],mx[MAX << 2];
int t,n,m,k;
void modify (int cur,int l,int r,int x,int y,pair <ll,ll> p);
void pushdown (int cur);
int main ()
{
//freopen (".in","r",stdin);
//freopen (".out","w",stdout);
t = read ();
while (t--)
{
n = read ();m = read ();k = read ();mp[0].clear ();
vector <vector <ll> > a (n + 5, vector <ll> (m + 5));
for (int i = 1;i <= n;++i)
for (int j = 1;j <= m;++j) a[i][j] = read ();
for (int i = 1;i <= n;++i)
{
for (int j = 1;j <= m;++j) ++mp[1][a[i][j]];
for (auto v : mp[1])
if (v.first != -1) modify (1,1,k,v.first,v.first,{mp[0][-1] * v.second,0});
int mx = tree[1];
modify (1,1,k,1,k,{mp[0][-1] * mp[1][-1],0});
modify (1,1,k,1,k,{0,mx});
for (auto v : mp[1])
if (v.first != -1) modify (1,1,k,1,k,{mp[0][v.first] * mp[1][v.first],0});
for (auto v : mp[0])
if (v.first != -1) modify (1,1,k,v.first,v.first,{mp[1][-1] * v.second,0});
mp[0].clear ();
for (auto v : mp[1]) mp[0][v.first] = v.second;
mp[1].clear ();
}
printf ("%lld\n",tree[1]);
for (int i = 0;i <= 4 * k;++i) tree[i] = tmp[i] = mx[i] = 0;
}
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}
void pushdown (int cur) // 注意是先作加法后取 max
{
tree[cur << 1] = max (tree[cur << 1] + tmp[cur],mx[cur]);
tree[cur << 1 | 1] = max (tree[cur << 1 | 1] + tmp[cur],mx[cur]);
tmp[cur << 1] += tmp[cur],tmp[cur << 1 | 1] += tmp[cur];
mx[cur << 1] = max (mx[cur << 1] + tmp[cur],mx[cur]);mx[cur << 1 | 1] = max (mx[cur << 1 | 1] + tmp[cur],mx[cur]);
tmp[cur] = mx[cur] = 0;
}
void modify (int cur,int l,int r,int x,int y,pair <ll,ll> p)
{
if (x <= l && y >= r)
{
tmp[cur] += p.first;mx[cur] = max (mx[cur] + p.first,p.second);
tree[cur] = max (tree[cur] + p.first,p.second);
return ;
}
int mid = (l + r) >> 1;
pushdown (cur);
if (x <= mid) modify (cur << 1,l,mid,x,y,p);
if (y > mid) modify (cur << 1 | 1,mid + 1,r,x,y,p);
tree[cur] = max (tree[cur << 1],tree[cur << 1 | 1]);
}

被折叠的 条评论
为什么被折叠?



