洛谷 4106 / bzoj 3614 [HEOI2014]逻辑翻译——思路+类似FWT
题目:https://www.luogu.org/problemnew/show/P4106
https://www.lydsy.com/JudgeOnline/problem.php?id=3614
可以先把给出的东西排序成这样:
-1 -1 -1
-1 -1 1
-1 1 -1
-1 1 1
1 -1 -1
1 -1 1
1 1 -1
1 1 1
就是后面看成低位、前面看成高位,1看成1、-1看成0的二进制的顺序。
发现把第1行和第2行相加再除以2,得到的就是与 x3 无关的所有系数 a 在 x1 = -1 , x2 = -1 的情况下的值;
第2行减第1行再除以2,得到的就是与 x3 有关的所有系数 a 在 x3 = 1 , x1 = -1 , x2 = -1 的情况下的值;
把所有行两个一组相加的答案放在一起考虑,就是所有与 x3 无关的系数在 x1 , x2 取 -1 , -1 ; -1 , 1 ; 1 , -1 ; 1 , 1 的情况下的值;相减的话就是 x1 , x2 取各种值,x3的值都是1的情况;这就是一个子问题了。
所有把所有行两个一组相加除以2的值放在前半部分,两个一组相减(下面减上面)除以2的值放在后半部分,大概就能做了。
最后第1行就是和所有 x 都无关的那个 a ,也就是常数项;第2行仔细考虑一下,是 x1 的系数。即,算出来的值在第 i 行的就是取 x 方案为 i ( i 就像状压了取哪些x乘起来的那一项)的项的系数。排序输出即可。
注意分数中途爆 int 。
#include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define ll long long using namespace std; const int N=(1<<20)+5,M=25; int n,bin[M],r[N]; char ch[M]; int gcd(int a,int b){return b?gcd(b,a%b):a;} struct Node{ int x,y; void yf() { if(!y||!x)return; if(y<0)y=-y,x=-x; int g=gcd(x>0?x:-x,y);x/=g;y/=g; } Node operator+ (const Node &b)const { Node c; c.y=(ll)y*b.y/gcd(y,b.y); c.x=(ll)x*c.y/y+(ll)b.x*c.y/b.y; c.yf(); return c; } Node operator- (const Node &b)const { Node u,v;u.x=x;u.y=y; v.x=b.x;v.y=-b.y; return u+v; } }a[2][N]; bool cmp(int a,int b) { for(int i=0;i<n;i++) { if(!a)return true; if(!b)return false; if((a&bin[i])&&(b&bin[i])) { a-=bin[i];b-=bin[i];continue; } if(!(a&bin[i])&&!(b&bin[i]))continue; if(a&bin[i])return true; if(b&bin[i])return false; } } void solve(int len) { for(int i=0;i<len;i++) if(i<r[i])swap(a[0][i],a[0][r[i]]); for(int R=len,u=1,v=0;R>1;R>>=1,u=!u,v=!v) { int tot=-1;//here for(int i=0;i<len;i+=R) { for(int j=0;j<R;j+=2) { a[u][++tot]=a[v][i+j]+a[v][i+j+1]; a[u][tot].y*=2; a[u][tot].yf(); } for(int j=0;j<R;j+=2) { a[u][++tot]=a[v][i+j+1]-a[v][i+j]; a[u][tot].y*=2; a[u][tot].yf(); } } } } int main() { scanf("%d",&n); bin[0]=1;for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1; for(int i=0;i<bin[n];i++)r[i]=(r[i>>1]>>1)+((i&1)?bin[n-1]:0); for(int i=0;i<bin[n];i++) { scanf("%s",ch); long long d=0; for(int j=0;j<n;j++) d|=(ch[j]=='+'?bin[j]:0); double tmp; scanf("%lf",&tmp); a[0][d].x=(int)(tmp*100+(tmp>0?0.5:-0.5)); a[0][d].y=100;///round! a[0][d].yf(); } solve(bin[n]);int fx=n&1; for(int i=0;i<bin[n];i++)r[i]=i; sort(r,r+bin[n],cmp); for(int i=0,u=r[0];i<bin[n];i++,u=r[i]) if(a[fx][u].x) { printf("%d",a[fx][u].x); if(a[fx][u].y>1)printf("/%d",a[fx][u].y); if(u) { putchar(' '); for(int j=0;j<n;j++) if(u&bin[j])printf("x%d",j+1); } puts(""); } return 0; }
但这样空间在洛谷上能过, bzoj 上过不了。本来算下来就很大。
考虑不要把 a[ ] 开成滚动数组了。比如不要把两个一组相加的值放在前一半、相减的值放在后一半,而是把两行 i 和 i+1 相加的值放在第 i 行,相减的值放在第 i+1 行;这样就和 FWT 的模板长得更像,只开一个 a[ ] 而不用滚动也能应付过来。
考虑这样算了一次之后,下一次是哪里相加、相减。其实就相当于是原来的前一半,其间穿插上后一半的值;所以原来是前一半相邻两行再相加,现在就是隔一行相加;即原来是分治到前半部分和后半部分,现在是分治到奇数项和偶数项;这样下去就是 i 和 i+4 匹配、i 和 i+8 匹配……套上 FWT 的那个循环就行了。
仔细想一想,发现这样算出来,第1行是常数项,第2行是只和 x3 有关的项……也就是角标的二进制最低位是1表示有 x3 ,最高位是1表示有 x1 ……
如果一开始的排序是:
-1 -1 -1
1 -1 -1
-1 1 -1
1 1 -1
……
这样的话算出来的结果就是角标二进制最低位是1表示有x1……这样的。
输出可以写 dfs ,先搜这一位填1的,再搜这一位填0的;搜下一位之前输出一下,即每个方案在它填完最高位的1之后输出,如果是填了0就不输出,因为这个方案在最靠近的1被填了之后曾经输出过。
#include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define ll long long using namespace std; const int N=(1<<20)+5,M=25; int n,bin[M]; char ch[M]; int gcd(int a,int b){return b?gcd(b,a%b):a;} struct Node{ int x,y; void yf() { if(!y||!x)return; if(y<0)y=-y,x=-x; int g=gcd(x>0?x:-x,y);x/=g;y/=g; } Node operator+ (const Node &b)const { Node c; c.y=(ll)y*b.y/gcd(y,b.y);//ll c.x=(ll)x*c.y/y+(ll)b.x*c.y/b.y; // return c; c.yf(); return c; } Node operator- (const Node &b)const { Node u,v;u.x=x;u.y=y; v.x=b.x;v.y=-b.y; return u+v; } void print(int id) { if(!x)return; yf(); if(y>1)printf("%d/%d",x,y); else printf("%d",x); if(id) { putchar(' '); for(int i=0;i<n;i++) if(id&bin[i])printf("x%d",i+1); } puts(""); } }a[N]; void dfs(int cr,int ml) { if(!cr||ml&bin[cr-1]) a[ml].print(ml); if(cr==n)return; dfs(cr+1,ml|bin[cr]); dfs(cr+1,ml); } int main() { scanf("%d",&n); int len=(1<<n); bin[0]=1;for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1; for(int i=0;i<len;i++) { scanf("%s",ch); int d=0; for(int j=0;j<n;j++) d|=(ch[j]=='+'?bin[j]:0); double tmp; scanf("%lf",&tmp); a[d].x=round(tmp*100); a[d].y=100; a[d].yf(); } for(int R=2;R<=len;R<<=1) { for(int i=0,m=R>>1;i<len;i+=R) for(int j=0;j<m;j++) { Node x=a[i+j],y=a[i+m+j]; a[i+j]=x+y; a[i+m+j]=y-x; a[i+j].y*=2; a[i+j].yf(); a[i+m+j].y*=2; a[i+m+j].yf(); } } dfs(0,0); return 0; }