[省选联考 2020 A/B 卷] 信号传递
链接
分析
大毒瘤状压。。。
首先注意对于原序列上一个 \(x\rightarrow y\) 的贡献可以拆到 \(x\) 和 \(y\) 上,也就是说
\(\left\{\begin{matrix}
g[x]+=k,g[y]+=k \ (x>y)\\
g[x]-=1,g[y]+=1 \ (x<y)
\end{matrix}\right.\)
最后把每个数的序号乘上 \(g\) 再求和就是答案。
我们发现这样每个数的贡献只和在它前面的数有哪些有关,所以我们设 \(g[x][S]\) 表示在 \(x\) 之前的数集为 \(S\) 时 \(x\) 的上面的 \(g\)。
我们先把传递序列拆开,记 \(e[i][j]\) 表示从 \(i\) 直接走到 \(j\) 的边数。
于是我们可以轻松做到 \(O(m^2 2^m)\) 求出 \(g\)。
点击查看代码
for(int i=0;i<=S;i++)
for(int j=1;j<=m;j++){
if(i&(1<<(j-1)))continue;
for(int k=1;k<=m;k++){
if(j==k)continue;
if((i>>(k-1))&1)g[j][i]+=e[k][j]+K*e[j][k];
else g[j][i]+=-e[j][k]+K*e[k][j];
}
}
于是我们有了 \(g\),只需要确定数的序号就可以求答案。
受到 \(g\) 的启发,因为一个数的序号和 \(g\) 的值只和它前面有哪些数有关,不需要知道前面的顺序。
所以我们设 \(f[S]\) 表示已经确定了 \(S\) 数集的顺序。转移时发现 \(S\) 刚好和 \(g\) 的第二维相同,于是我们也可以轻松地做到 \(O(m 2^m)\) 求出 \(f\)。
点击查看代码
memset(f,127,sizeof(f));f[0]=0,num[0]=1;
for(int i=0;i<S;i++)
for(int j=1;j<=m;j++){
if(i&(1<<(j-1)))continue;
f[i|(1<<(j-1))]=min(f[i|(1<<(j-1))],f[i]+num[i]*g[j][i]),
num[i|(1<<(j-1))]=num[i]+1;
}
现在我们有了一个较为成熟的 \(O(m^2 2^m)\) 的做法,此时你能得到 \(60\) 分的好成绩。注意 \(g[x][S]\) 要避开 \(S\) 包含 \(x\) 的不合法情况。
时间复杂度优化
发现我们时间复杂度的瓶颈 \(g\) 的求解还有明显的可优化空间,因为 \(g\) 其实是可以由之前的 \(S\) 继承来的。只需从 \(S\) 中随便找一个数 \(j\) 去掉,那么根据 \(g\) 原本的求法我们可以得到 \(g[i][S]=g[i][S-(1<<j)]+(e[j][i]+K*e[i][j])-(-e[i][j]+K*e[j][i])\)。
这里随便找的 \(j\) 可以用 lowbit 去找。
于是我们有了时间复杂度 \(O(m 2^m)\) 的做法,得到了 \(80\) 分的好成绩。
code
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define in read()
inline int read(){
int p=0,f=1;
char c=getchar();
while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){p=p*10+c-'0';c=getchar();}
return p*f;
}
const int N=1e5+5;
const int T=(1<<23);
inline int lowbit(int x){return x&(-x);}
int n,m,K,S;
int e[25][25],a[N];
int g[25][T];
int f[T],num[T];
signed main(){
n=in,m=in,K=in,S=(1<<m)-1;
for(int i=1;i<=n;i++)a[i]=in;
for(int i=1;i<n;i++)e[a[i]][a[i+1]]++;
for(int i=1;i<=m;i++)
for(int j=1;j<=m;j++)
if(i!=j)g[i][0]+=-e[i][j]+K*e[j][i];
for(int i=1,t=lowbit(i),k=0;i<=S;i++,t=lowbit(i),k=0){
while((1<<k)!=t)k++;k++;
for(int j=1;j<=m;j++)
if(!(i&(1<<(j-1))))
g[j][i]=g[j][i-t]+(1-K)*e[k][j]+(1+K)*e[j][k];
}
memset(f,127,sizeof(f));f[0]=0,num[0]=1;
for(int i=0;i<S;i++)
for(int j=1;j<=m;j++){
if(i&(1<<(j-1)))continue;
f[i|(1<<(j-1))]=min(f[i|(1<<(j-1))],f[i]+num[i]*g[j][i]),
num[i|(1<<(j-1))]=num[i]+1;
}
cout<<f[S];
return 0;
}
空间复杂度优化
我们惊讶的发现空间居然爆了,显然我们空间复杂度的瓶颈也在 \(g\) 数组上,该怎么从 \(g\) 身上榨出一些空间出来呢。
我们注意到 \(g[x][S]\) 这样的 \(x\) 有一些 \(S\) 是无用的,就是 \(S\) 中包含 \(x\) 的情况。这些情况去掉后不会对求解产生影响。
那么我们可以尝试把含有 \(x\) 的 \(S\) 去掉,这样 \(g\) 的大小就从 \(23\times 2^{23}\) 的约 \(736MB\) 减小到 \(23\times 2^{22}\) 约 \(368MB\) 可以获得开O2 \(100\) 分的好成绩。
别看说起来挺麻烦的,其实实现非常简单,只需要在用到在原本 \(80\) 分代码的基础上把出现 \(g\) 的地方微调一下第二维,把 \(x\) 的前后拼接起来。
code
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define in read()
inline int read(){
int p=0,f=1;
char c=getchar();
while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){p=p*10+c-'0';c=getchar();}
return p*f;
}
const int N=1e5+5;
const int T=(1<<23);
inline int lowbit(int x){return x&(-x);}
inline int gety(int x,int y){return ((y>>x)<<(x-1))+y%(1<<(x-1));}
int n,m,K,S;
int e[25][25],a[N];
int g[25][T>>1];
int f[T],num[T];
#define g(x,y) g[x][gety(x,y)]
signed main(){
n=in,m=in,K=in,S=(1<<m)-1;
for(int i=1;i<=n;i++)a[i]=in;
for(int i=1;i<n;i++)e[a[i]][a[i+1]]++;
for(int i=1;i<=m;i++)
for(int j=1;j<=m;j++)
if(i!=j)g(i,0)+=-e[i][j]+K*e[j][i];
for(int i=1,t=lowbit(i),k=0;i<=S;i++,t=lowbit(i),k=0){
while((1<<k)!=t)k++;k++;
for(int j=1;j<=m;j++)
if(!(i&(1<<(j-1))))
g(j,i)=g(j,i-t)+(1-K)*e[k][j]+(1+K)*e[j][k];
}
memset(f,127,sizeof(f));f[0]=0,num[0]=1;
for(int i=0;i<S;i++)
for(int j=1;j<=m;j++){
if(i&(1<<(j-1)))continue;
f[i|(1<<(j-1))]=min(f[i|(1<<(j-1))],f[i]+num[i]*g(j,i)),
num[i|(1<<(j-1))]=num[i]+1;
}
cout<<f[S];
return 0;
}
卡常
最后要想不开 O2 通过需要一定的卡常,这里把我卡完的代码贴出来,有一些常见的卡常,比如 fread,交换数组两维,register int,预处理2的幂,define,还有一个重要的是求解 \(g\) 和求 \(f\) 的两段可以合到一起。(其实仍然没卡过因为原题时限是2s)
code
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define in read()
#define lowbit(x) (x&(-x))
#define g(x,y) g[((((y)>>(x))<<(x-1))+(y)%(1<<(x-1)))][x]
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int read(){
char ch=nc();int sum=0;
while(!(ch>='0'&&ch<='9'))ch=nc();
while(ch>='0'&&ch<='9')sum=sum*10+ch-48,ch=nc();
return sum;
}
const int N=1e5+5;
const int T=(1<<23);
int n,m,K,S;
int e[25][25],a[N];
int g[T>>1][25];
int f[T],num[T];
int lg2[T];
signed main(){
n=in,m=in,K=in,S=(1<<m)-1,num[0]=1;
for(register int i=2;i<=S;i++)lg2[i]=lg2[i/2]+1;
for(register int i=1;i<=n;i++)a[i]=in;
for(register int i=1;i<n;i++)e[a[i]][a[i+1]]++;
for(register int i=1;i<=S;i++)f[i]=1000000000;
for(register int i=1;i<=m;i++)
for(int j=1;j<=m;j++)
if(i!=j)g(i,0)+=-e[i][j]+K*e[j][i];
for(int i=1;i<=m;i++)
f[1<<(i-1)]=g(i,0);
for(register int i=1,t=lowbit(i),k=lg2[t]+1;i<=S;i++,t=lowbit(i),k=lg2[t]+1){
num[i]=num[i-t]+1;
for(int j=1;j<=m;j++)
if(!(i&(1<<(j-1))))
g(j,i)=g(j,i-t)+(1-K)*e[k][j]+(1+K)*e[j][k],
f[i|(1<<(j-1))]=min(f[i|(1<<(j-1))],f[i]+num[i]*g(j,i));
}
cout<<f[S];
return 0;
}
要我说做到 \(80\) 分就已经差不多了。