题意:
给你一个 m*n 的母矩阵,再给你一个 mm*nn 的子矩阵,问你母矩阵里有多少个不同位置的子矩阵。
题解:
ac自动机在二维上的运用。
与一维不同的地方主要是:1,val数组变成了vector。2,查询时的操作结合dp数组的状态的一系列操作。
1)思考一下为什么val要由数组变成vector:
假设这么一个样例:
3 3
B B B
B B B
B B B
2 2
B B
B B
那么在向trie树(字典树)里插入时分别执行了insert(“BB”,1)和insert(“BB”,2)【此处的1,2代表行的信息】
,于是你发现,val[2]里的信息先是变成1,然后又被覆盖成2,于是我们就丢了第一行的信息,所以~
2)dp[i][j][k] 表示母串的(i,j)位置是第k个子串的末尾。
找到相同的串后,判断条件很简单,判断这个串是不是子串的第一行或者它的上一行是不是一样。
#include <iostream>
using namespace std;
#include <vector>
#include <string.h>
#include <cstdio>
#include <queue>
#define maxnode 100010
#define sigma 128
char s[1005][1005];
char sub[105][105];
bool dp[1005][1005][105];// 表示母串的(i,j)位置是第k个子串的末尾
int m,n,mm,nn;
struct ac_automation{
int ch[maxnode][sigma];// maxnode 一般设置为 模式串数量*模式串长度
vector<int> val[maxnode];
int last[maxnode];
int f[maxnode]; // fail指针
int sz; // the num of the trie
int ans; // answer
void clear(){
sz = 1;
ans = 0;
memset(ch[0],0,sizeof(ch[0]));
val[0].clear();
}
int idx(char c){
return (int)c;
/*
if (c >= 'a' && c <= 'z') return c - 'a';
if (c >= 'A' && c <= 'Z') return c - 'A' + 26;
if (c >= '0' && c <= '9') return c - '0' + 52;
*/
}
void insert(char s[],int v){
int u = 0;
for(int i = 0; s[i];i++){
int c = idx( s[i] );
if(!ch[u][c]){
memset(ch[sz],0,sizeof(ch[sz]));
ch[u][c] = sz++;
val[ch[u][c]].clear();
}
u = ch[u][c];
}
val[u].push_back(v);
}
void build(){
queue<int> q;
f[0] = 0;
for(int i = 0;i < sigma;i++){
if(ch[0][i]){
f[ch[0][i]] = 0;
q.push(ch[0][i]);
last[ch[0][i]] = 0;
}
}
while(!q.empty()){
int now = q.front();
q.pop();
for(int i = 0;i < sigma ;i++){
int son = ch[now][i];
if(!son){
ch[now][i] = ch[f[now]][i];
continue;
}
q.push(son);
f[son] = ch[f[now]][i];
last[son] = val[f[son]].size() ? f[son] : last[f[son]];
}
}
}
void find(char *s,int m){ // m 代表第几行
int u = 0;
for(int i = 0;s[i];i++){
int c = idx(s[i]);
u = ch[u][c];
if(val[u].size()){
process(u,m,i+1);
}
else{
process(last[u],m,i+1);
}
}
}
void process(int u,int m,int j){ // m行 j列
if(u){
for(int i = 0;i < val[u].size();i++){
int k = val[u][i];
if(val[u][i] == 1 || dp[m-1][j][k-1] == true){
dp[m][j][k] = true;
if(k == mm)ans++;
}
}
}
}
}ac;
int main()
{
int T;
cin>>T;
while(T--){
scanf("%d%d",&m,&n);
for(int i = 1;i <= m;i++){
scanf("%s",s+i);
}
ac.clear();
memset(dp,0,sizeof(dp));
scanf("%d%d",&mm,&nn);
for(int i = 1;i <= mm;i++){
scanf("%s",sub+i);
ac.insert(sub[i],i);
}
ac.build();
for(int i = 1;i <= m;i++){
ac.find(s[i],i);
}
printf("%d\n",ac.ans);
}
return 0;
}