DLX
常规做法是 DFS 加一些剪枝。但由于这是数独，可以把它转化为精确覆盖问题，用 DLX（舞蹈链）求出数独的所有解，再取分数最大的那个即可；若无解则输出 -1。
不熟悉 DLX 的小伙伴可以先阅读 Knuth 的论文《Dancing Links》了解该算法。
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
// Exact-cover matrix size for a 9x9 sudoku:
//   N = 81 cells * 9 digits candidate rows,
//   M = 4 * 81 constraint columns (row/digit, column/digit, box/digit, cell filled).
// Parenthesised so the macros expand safely inside any expression (e.g. x/N).
#define N (81*9)
#define M (81*4)
using namespace std;
// One cell of the Dancing Links (DLX) circular matrix.
// Column headers live in c[1..M] with c[0] serving as the root header h;
// the matrix 1-entries live in a[1..nd]; s is a sentinel pointed to by null.
struct node{
    node *l,*r,*u,*d; // left/right/up/down neighbours in the circular lists
    int x,y;          // x: matrix row (0 for headers), y: matrix column
};
// build() creates exactly four 1-entries for each of the at most N candidate
// rows (row, column, box and cell constraints), so N*4 nodes suffice; the
// previous N*M bound reserved ~236k nodes (~9 MB) that were never used.
node c[M+5],a[N*4+5],s,*null=&s,*h;
// cnt[j]: number of 1-entries currently linked into column j (read by find()).
int cnt[M+5];
// mp: the input board (0 = empty cell); ans: the board being filled by dfs().
int mp[10][10],ans[10][10];
// id[i][j]: index into a[] of the node for matrix entry (i,j).
int id[N+5][M+5];
// n, m: matrix dimensions; nd: nodes allocated so far; res: best score found.
int n=N,m=M,nd,res;
// f[i][j]: true when candidate row i covers constraint column j.
bool f[N+5][M+5];
// Score weight of each board cell, indexed 1..9 (row/column 0 are padding).
// The outermost ring weighs 6 and each ring inward adds one, up to 10 at
// the centre cell (5,5).
int w[10][10]={
    /* pad */ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    /* r1  */ {0, 6, 6, 6, 6, 6, 6, 6, 6, 6},
    /* r2  */ {0, 6, 7, 7, 7, 7, 7, 7, 7, 6},
    /* r3  */ {0, 6, 7, 8, 8, 8, 8, 8, 7, 6},
    /* r4  */ {0, 6, 7, 8, 9, 9, 9, 8, 7, 6},
    /* r5  */ {0, 6, 7, 8, 9,10, 9, 8, 7, 6},
    /* r6  */ {0, 6, 7, 8, 9, 9, 9, 8, 7, 6},
    /* r7  */ {0, 6, 7, 8, 8, 8, 8, 8, 7, 6},
    /* r8  */ {0, 6, 7, 7, 7, 7, 7, 7, 7, 6},
    /* r9  */ {0, 6, 6, 6, 6, 6, 6, 6, 6, 6},
};
// Mark the matrix row for "cell (x,y) holds digit z" (x, y, z in 1..9).
// That row has a 1 in four constraint columns:
//   1..81    row x contains z
//   82..162  column y contains z
//   163..243 the 3x3 box containing (x,y) contains z
//   244..324 cell (x,y) is filled
void nsrt(int x,int y,int z){
    const int cell=(x-1)*9+y;          // 1..81
    const int row=(cell-1)*9+z;        // 1..729
    const int box=(x-1)/3*3+(y-1)/3;   // 0..8
    bool *line=f[row];
    line[(x-1)*9+z]=true;
    line[81+(y-1)*9+z]=true;
    line[162+box*9+z]=true;
    line[243+cell]=true;
}
void build(){
for (int i=1;i<=9;i++)
for (int j=1;j<=9;j++)
if (!mp[i][j]) for (int k=1;k<=9;k++) nsrt(i,j,k);
else nsrt(i,j,mp[i][j]);
h=&c[0],h->d=h->u=h->l=h->r=h,h->x=h->y=0; node *pre=h;
for (int i=1;i<=m;i++){
node *p=&c[i]; p->u=p->d=p,p->x=0,p->y=i;
p->r=pre->r,p->l=pre,pre->r->l=p,pre->r=p,pre=p;
}
for (int i=1;i<=n;i++)
for (int j=1;j<=m;j++)
if (f[i][j]){
a[id[i][j]=++nd].x=i,a[nd].y=j;
a[nd].l=a[nd].r=a[nd].u=a[nd].d=&a[nd];
}
for (int j=1;j<=m;j++){
node *pre=&c[j];
for (int i=1;i<=n;i++)
if (f[i][j]){
cnt[j]++; node *p=&a[id[i][j]];
p->d=pre->d,p->u=pre,pre->d->u=p,pre->d=p,pre=p;
}
}
for (int i=1;i<=n;i++){
node *pre=null;
for (int j=1;j<=m;j++)
if (f[i][j])
if (pre==null) pre=&a[id[i][j]];
else{
node *p=&a[id[i][j]]; p->r=pre->r;
p->l=pre,pre->r->l=p,pre->r=p,pre=p;
}
}
}
void rmv(int x){
node *p=&c[x],*p1=p->d;
p->l->r=p->r,p->r->l=p->l;
while (p1!=p){
node *p2=p1->r;
while (p2!=p1)
p2->d->u=p2->u,p2->u->d=p2->d,cnt[p2->y]--,p2=p2->r;
p1=p1->d;
}
}
void rsm(int x){
node *p=&c[x],*p1=p->d;
p->l->r=p->r->l=p;
while (p1!=p){
node *p2=p1->r;
while (p2!=p1)
p2->d->u=p2->u->d=p2,cnt[p2->y]++,p2=p2->r;
p1=p1->d;
}
}
// Pick the uncovered column with the fewest remaining 1s; ties go to the
// leftmost such column. If no column is left, this returns h->r (which is
// h itself); dfs() checks for that case before calling.
node *find(){
    node *best=h->r;
    int fewest=0x7fffffff;
    for (node *col=h->r;col!=h;col=col->r){
        if (cnt[col->y]<fewest){
            fewest=cnt[col->y];
            best=col;
        }
    }
    return best;
}
// Everything above is the standard DLX template code.
void calc(){//计算分数
int ret=0;
for (int i=1;i<=9;i++)
for (int j=1;j<=9;j++)
ret+=w[i][j]*ans[i][j];
res=max(res,ret);
}
void dfs(){
if (h->r==h) return calc();
node *p=find(),*p1=p->d;
if (p1==p) return; rmv(p->y);
while (p1!=p){
ans[(p1->x-1)/81+1][((p1->x-1)/9)%9+1]=(p1->x-1)%9+1;
node *p2=p1->r;
while (p2!=p1) rmv(p2->y),p2=p2->r;
dfs(),p2=p1->l;
while (p2!=p1) rsm(p2->y),p2=p2->l;
p1=p1->d;
}
return rsm(p->y);
}
// Read the 9x9 board (0 = empty, 1..9 = given digit), enumerate all
// completions with DLX and print the best weighted score, or -1 if the
// board has no completion.
int main(){
    for (int i=1;i<=9;i++)
        for (int j=1;j<=9;j++){
            // Reject truncated or out-of-range input. A missing value would
            // otherwise stay 0 (an empty cell), and a value outside 0..9 makes
            // nsrt() write into other constraints' columns or use negative
            // indices.
            if (scanf("%d",&mp[i][j])!=1||mp[i][j]<0||mp[i][j]>9){
                fprintf(stderr,"invalid input at row %d, column %d\n",i,j);
                return 1;
            }
        }
    build(),res=-1,dfs();
    printf("%d\n",res);
    return 0;
}