D. Magic Numbers
看到计数取模就知道用dp了。。
首先把问题转换一下,变成求
1
~
最后就是分别算
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int mod = 1e9+7;
int m,d;
int len;
char stra[2010];
char strb[2010];
int a[2010];
int b[2010];
int dp[2010][2010][2];
int solve(int *num){
memset(dp,0,sizeof(dp));
for(int i=1;i<=num[0];i++){
if(i==d)continue;
if(i<num[0]){
dp[0][i%m][1]++;
}else{
dp[0][i%m][0]++;
}
}
for(int i=1;i<len;i++){
for(int j=0;j<m;j++){
if(i&1){
dp[i][ (j*10+d)%m ][1] += dp[i-1][j][1];
dp[i][ (j*10+d)%m ][1] %= mod;
if(d<num[i]){
dp[i][ (j*10+d)%m ][1] += dp[i-1][j][0];
dp[i][ (j*10+d)%m ][1] %= mod;
}else if(d==num[i]){
dp[i][ (j*10+d)%m ][0] += dp[i-1][j][0];
dp[i][ (j*10+d)%m ][0] %= mod;
}
}else{
for(int k=0;k<10;k++){
if(k==d)continue;
dp[i][ (j*10+k)%m ][1] += dp[i-1][j][1];
dp[i][ (j*10+k)%m ][1] %= mod;
if(k<num[i]){
dp[i][ (j*10+k)%m ][1] += dp[i-1][j][0];
dp[i][ (j*10+k)%m ][1] %= mod;
}else if(k==num[i]){
dp[i][ (j*10+k)%m ][0] += dp[i-1][j][0];
dp[i][ (j*10+k)%m ][0] %= mod;
}
}
}
}
}
int ans = dp[len-1][0][0]+dp[len-1][0][1];
ans %= mod;
return ans;
}
bool checkA(){
ll cur = 0;
for(int i=0;i<len;i++){
cur*=10;
cur+=a[i];
cur%=m;
if(i&1){
if(a[i]!=d)return 0;
}else{
if(a[i]==d)return 0;
}
}
return cur==0;
}
int main(){
cin>>m>>d;
scanf("%s",stra);
scanf("%s",strb);
len = strlen(stra);
for(int i=0;i<len;i++){
a[i]=stra[i]-'0';
b[i]=strb[i]-'0';
}
int l=solve(a);
int r=solve(b);
ll ans = (r-l + checkA() +mod)%mod;
cout<<ans<<endl;
return 0;
}
E. Zbazi in Zeydabad
首先预处理出每个位置 (i,j) 开始往左,往右,往左下有多少个连续的z,分别记为 l(i,j) , r(i,j) , ld(i,j) 。然后按 i+j 递增(右上-左下对角线)的顺序枚举每个位置作为大z的右上角的情况。对于每个位置 (i,j) 而言,它能形成的大z至多为 min(l(i,j),ld(i,j)) ,再来考察z的下面那一“横”,根据 r(i,j) ,将向右足够长的行添加到BIT中维护,然后查询在 min(l(i,j),ld(i,j)) 范围内有多少行满足。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
char z[3010][3010];
int r[3010][3010];
int l[3010][3010];
int ld[3010][3010];
int c[6010];
inline int lowbit(int x){
return x&(-x);
}
void update(int x){
while(x<=6000){
c[x]++;
x+=lowbit(x);
}
}
int query(int l,int r){
int resl=0;
l--;
while(l){
resl+=c[l];
l-=lowbit(l);
}
int resr=0;
while(r){
resr+=c[r];
r-=lowbit(r);
}
return resr-resl;
}
struct node{
int row;
int rr;
node(){
}
node(int row,int rr):row(row),rr(rr){
}
bool operator<(const node &other)const{
return rr>other.rr;
}
};
int main(){
int n,m;
cin>>n>>m;
for(int i=1;i<=n;i++){
scanf("%s",z[i]+1);
}
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
if(z[i][j]=='z'){
l[i][j]=l[i][j-1]+1;
}
}
for(int j=m;j>=1;j--){
if(z[i][j]=='z'){
r[i][j]=r[i][j+1]+1;
}
}
}
ll ans = 0;
for(int sum=2;sum<=n+m;sum++){
memset(c,0,sizeof(c));
int i=min(sum-1,n);
int j=sum-i;
vector<node> vec;
while(j<=m && i>=1){
if(r[i][j])vec.push_back(node(i,j+r[i][j]-1));
if(z[i][j]=='z'){
ld[i][j] = ld[i+1][j-1]+1;
}
i--;
j++;
}
sort(vec.begin(),vec.end());
//
j=min(sum-1,m);
i=sum-j;
int k = 0;
while(i<=n && j>=1){
while(k<vec.size() && vec[k].rr>=j){
update(vec[k].row);
k++;
}
int t = min(l[i][j],ld[i][j]);
ans+=query(i,i+t-1);
i++;
j--;
}
}
cout<<ans<<endl;
return 0;
}