对称的正方形
时间限制:1000MS1000MS1000MS 内存限制:128MB128 MB128MB
问题描述
OrezOrezOrez很喜欢搜集一些神秘的数据,并经常把它们排成一个矩阵进行研究。最近,OrezOrezOrez又得到了一些数据,并已经把它们排成了一个nnn行mmm列的矩阵。通过观察,OrezOrezOrez发现这些数据蕴涵了一个奇特的数,就是矩阵中上下对称且左右对称的正方形子矩阵的个数。 OrezOrezOrez自然很想知道这个数是多少,可是矩阵太大,无法去数。只能请你编个程序来计算出这个数。
输入格式
文件的第一行为两个整数nnn和mmm。接下来nnn行每行包含mmm个正整数,表示OrezOrezOrez得到的矩阵。
输出格式
文件中仅包含一个整数answeransweranswer,表示矩阵中有answeransweranswer个上下左右对称的正方形子矩阵。
样例输入
555 555
444 222 444 444 444
333 111 444 444 333
333 555 333 333 333
333 111 555 333 333
444 222 111 222 444样例输出
272727
数据范围
对于303030%的数据 n,m≤100n,m≤100n,m≤100
对于100100100%的数据 n,m≤1000n,m≤1000n,m≤1000 ,矩阵中的数的大小≤109≤10^9≤109
解析
听说有很多神仙是用ManacherManacherManacher做的 ,然而本蒟蒻并不会 。
所以我们用一种比较简单粗暴且易于理解的算法——hashhashhash来替代。(说白了就是弱)
首先我们会发现两个很 (显然且) 有用的性质:
- 如果一个正方形子矩阵是对称的且边长>2>2>2,那么比它小一圈的正方形子矩阵也一定是对称的。
- 正方形子矩阵的对称中心至多只有O(2nm)O(2nm)O(2nm)个
那么我们可以枚举正方形子矩阵的对称中心,并二分此对称中心的最大边长。
至于判定该正方形子矩阵是否对称,我们可以通过矩阵hash来解决。
设一个矩阵a[1..i][1..j]a[1..i][1..j]a[1..i][1..j]的hashhashhash值为Σp1i∗p2j∗a[i][j]Σp_1^i*p_2^j*a[i][j]Σp1i∗p2j∗a[i][j]
我们把原矩阵、原矩阵上下翻转、原矩阵左右翻转分别做一次hashhashhash,判定时只要把对应矩阵的hashhashhash值用二维前缀和求出来并简单处理一下行差、列差对p1,p2p_1,p_2p1,p2乘方次数的影响之后判断是否相等即可。
Tips:Tips:Tips:
- 本题时限较紧,请提前预处理p1i,p2jp_1^i,p_2^jp1i,p2j
- 如果你对自己的常数不是非常自信的话请不要写双hashhashhash,写了也不要用pairpairpair
cyl大佬已经身先士卒地T了 - 别把mmm打成nnn,本蒟蒻对此已经不想说什么了
- 枚举偶数边长时不一定有答案
代码
#include <cstdio>
#define ll long long
using namespace std;
const int maxn = 1005;
const int p1 = 29;
const int p2 = 31;
const int mod = 1e9 + 7;
int n , m;
int a[maxn][maxn] , b[maxn][maxn] , c[maxn][maxn];
ll pow_x[maxn] , pow_y[maxn];
int min(int x , int y){return x < y ? x : y;}
int read()
{
char ch = getchar(); bool f = 1;
while(ch < '0' || ch > '9') f &= ch != '-' , ch = getchar();
int res = 0;
while(ch >= '0' && ch <= '9') res = (res << 3) + (res << 1) + (ch ^ 48) , ch = getchar();
return f ? res : -res;
}
void pow_init()
{
pow_x[0] = pow_y[0] = 1;
for(int i = 1;i <= n;i++) pow_x[i] = pow_x[i - 1] * p1 % mod;
for(int i = 1;i <= m;i++) pow_y[i] = pow_y[i - 1] * p2 % mod;
}
struct HASH
{
private:
ll s[maxn][maxn];
public:
void init(int (*x)[maxn])
{
for(int i = 1;i <= n;i++)
for(int j = 1;j <= m;j++)
{
ll tmp = pow_x[i] * pow_y[j] % mod * x[i][j] % mod;
s[i][j] = ((s[i - 1][j] + s[i][j - 1]) % mod - s[i - 1][j - 1] + tmp + mod) % mod;
}
}
ll sum(int bx , int by , int ex , int ey){return ((s[ex][ey] - s[bx - 1][ey] - s[ex][by - 1] + s[bx - 1][by - 1]) % mod + mod) % mod;}
}hash1 , hash2 , hash3;
bool check1(int len , int i , int j)
{
int bx1 = i - len , by1 = j - len , ex1 = i + len , ey1 = j + len;
ll res1 = hash1.sum(bx1 , by1 , ex1 , ey1);
int bx2 = bx1 , by2 = m - ey1 + 1 , ex2 = ex1 , ey2 = m - by1 + 1;
ll res2 = hash2.sum(bx2 , by2 , ex2 , ey2);
int bx3 = n - ex1 + 1 , by3 = by1 , ex3 = n - bx1 + 1 , ey3 = ey1;
ll res3 = hash3.sum(bx3 , by3 , ex3 , ey3);
ll res4 = res1;
if(by1 > by2) res2 = (res2 * pow_y[by1 - by2]) % mod;
if(by1 < by2) res1 = (res1 * pow_y[by2 - by1]) % mod;
if(bx1 > bx3) res3 = (res3 * pow_x[bx1 - bx3]) % mod;
if(bx1 < bx3) res4 = (res4 * pow_x[bx3 - bx1]) % mod;
return res1 == res2 && res4 == res3;
}
bool check2(int len , int i , int j)
{
int bx1 = i - len , by1 = j - len , ex1 = i + len + 1 , ey1 = j + len + 1;
ll res1 = hash1.sum(bx1 , by1 , ex1 , ey1);
int bx2 = bx1 , by2 = m - ey1 + 1 , ex2 = ex1 , ey2 = m - by1 + 1;
ll res2 = hash2.sum(bx2 , by2 , ex2 , ey2);
int bx3 = n - ex1 + 1 , by3 = by1 , ex3 = n - bx1 + 1 , ey3 = ey1;
ll res3 = hash3.sum(bx3 , by3 , ex3 , ey3);
ll res4 = res1;
if(by1 > by2) res2 = (res2 * pow_y[by1 - by2]) % mod;
if(by1 < by2) res1 = (res1 * pow_y[by2 - by1]) % mod;
if(bx1 > bx3) res3 = (res3 * pow_x[bx1 - bx3]) % mod;
if(bx1 < bx3) res4 = (res4 * pow_x[bx3 - bx1]) % mod;
return res1 == res2 && res4 == res3;
}
int main()
{
n = read() , m = read();
for(int i = 1;i <= n;i++)
for(int j = 1;j <= m;j++) a[i][j] = b[i][m - j + 1] = c[n - i + 1][j] = read();
pow_init();
hash1.init(a) , hash2.init(b) , hash3.init(c);
int ans = 0;
for(int i = 1;i <= n;i++)
for(int j = 1;j <= m;j++)
{
int l = 0 , r = min(min(i - 1 , n - i) , min(j - 1 , m - j));
while(l < r)
{
int mid = l + r + 1 >> 1;
if(check1(mid , i , j)) l = mid;
else r = mid - 1;
}
ans += l + 1;
}
for(int i = 1;i < n;i++)
for(int j = 1;j < m;j++)
{
int l = 0 , r = min(min(i - 1 , n - i - 1) , min(j - 1 , m - j - 1)) , res = -1;
while(l <= r)
{
int mid = l + r >> 1;
if(check2(mid , i , j)) res = mid , l = mid + 1;
else r = mid - 1;
}
ans += res + 1;
}
printf("%d\n",ans);
return 0;
}