复习
普通平衡树
需要注意的点:
1.哨兵节点提前插入
2.父亲节点注意一下
3.细心一点
#include <bits/stdc++.h> #define N 300009 #define lson s[x].ch[0] #define rson s[x].ch[1] #define setIO(s) freopen(s".in","r",stdin) using namespace std; const int inf=1000000009; struct data { int v,si,ch[2],f; data() { v=si=ch[0]=ch[1]=f=0; } }s[N]; int tot,root; inline void pushup(int x) { s[x].si=s[lson].si+s[rson].si+1; } inline int get(int x) { return s[s[x].f].ch[1]==x; } inline void rotate(int x) { int old=s[x].f,fold=s[old].f,which=get(x); if(fold) { s[fold].ch[s[fold].ch[1]==old]=x; } s[old].ch[which]=s[x].ch[which^1]; if(s[old].ch[which]) { s[s[old].ch[which]].f=old; } s[x].ch[which^1]=old,s[old].f=x,s[x].f=fold; pushup(old),pushup(x); } void splay(int x,int &tar) { int u=s[tar].f; for(int fa;(fa=s[x].f)!=u;rotate(x)) { if(s[fa].f!=u) { rotate(get(fa)==get(x)?fa:x); } } tar=x; } void ins(int &x,int ff,int v) { if(!x) { x=++tot; s[x].f=ff,s[x].v=v,s[x].si=1; return; } ins(s[x].ch[v>s[x].v],x,v); pushup(x); } int get_pre(int x,int v) { if(!x) return -1; if(s[x].v<v) { int det=get_pre(rson,v); return det==-1?x:det; } else { return get_pre(lson,v); } } int get_aft(int x,int v) { if(!x) return -1; if(s[x].v>v) { int det=get_aft(lson,v); return det==-1?x:det; } else return get_aft(rson,v); } int get_rank(int v) { int x=get_pre(root,v); splay(x,root); return s[s[x].ch[0]].si; } int get_kth(int x,int kth) { if(s[lson].si+1==kth) return x; else if(s[lson].si>=kth) return get_kth(lson,kth); else return get_kth(rson,kth-s[lson].si-1); } int find(int x,int v) { if(s[x].v==v) return x; if(s[x].v<v) return find(rson,v); else return find(lson,v); } void del(int v) { int x=find(root,v); splay(x,root); int l=s[x].ch[0],r=s[x].ch[1]; while(s[l].ch[1]) l=s[l].ch[1]; splay(l,s[x].ch[0]); s[l].f=0,s[r].f=l,s[l].ch[1]=r,pushup(l); s[x].ch[0]=s[x].ch[1]=s[x].f=0; root=l; } int main() { srand(time(NULL)); int m,x,y,z; scanf("%d",&m); ins(root,0,inf); ins(root,0,-inf); for(int i=1;i<=m;++i) { int op; scanf("%d",&op); ++op; if(op==1) { scanf("%d",&x); ins(root,0,x); splay(tot,root); } if(op==2) { scanf("%d",&x); del(x); } if(op==3) { scanf("%d",&x); int p=get_kth(root,x+1); splay(p,root); printf("%d\n",s[p].v); } if(op==4) { scanf("%d",&x); printf("%d\n",get_rank(x)); } if(op==5) { scanf("%d",&x); int p=get_pre(root,x); splay(p,root); if(s[p].v==-inf) printf("-1\n"); else printf("%d\n",s[p].v); } if(op==6) { scanf("%d",&x); int p=get_aft(root,x); splay(p,root); if(s[p].v==inf) printf("-1\n"); else printf("%d\n",s[p].v); } } return 0; }
矩阵乘法
需要注意的点:
1. 矩阵的初始化
2. 注意新矩阵的 $n,m$ 以及 3 重循环中上界
#include <bits/stdc++.h> #define ll long long #define mod 1000000007 #define setIO(s) freopen(s".in","r",stdin) using namespace std; inline int ADD(int x,int y) { return (x+y)>=mod?x+y-mod:x+y; } struct M { int c[501][501],n,m; M() { memset(c,0,sizeof(c));} int *operator[](int x) { return c[x]; } M operator*(const M b) const { M an; an.n=n; an.m=b.m; for(int i=0;i<n;++i) { for(int j=0;j<b.m;++j) { for(int k=0;k<m;++k) { an.c[i][j]=ADD(an.c[i][j],(ll)c[i][k]*b.c[k][j]%mod); } } } return an; } }A,B; int main() { // setIO("input"); int n,p,m; scanf("%d%d%d",&n,&p,&m); A.n=n,A.m=p; for(int i=0;i<n;++i) { for(int j=0;j<p;++j) scanf("%d",&A[i][j]),A[i][j]=ADD(A[i][j],mod); } for(int i=0;i<p;++i) { for(int j=0;j<m;++j) scanf("%d",&B[i][j]),B[i][j]=ADD(B[i][j],mod); } B.n=p,B.m=m; A=A*B; for(int i=0;i<A.n;++i) { for(int j=0;j<A.m;++j) printf("%d ",A[i][j]); printf("\n"); } return 0; }
多项式
快速傅里叶变换
使用 FFT 的场合比较少,一般都是要结合 MTT 之类的.
对于复数 $(x,y)$,有 3 种运算:
$(x,y)+(x',y')=(x+x',y+y')$
$(x,y)-(x',y')=(x-x',y-y')$
$(x,y)*(x',y')=(x*x'-y*y',x*y'+y*x')$
#include <cstdio> #include <vector> #include <cmath> #include <cstring> #include <algorithm> #define ll long long #define db long double #define pb push_back #define N 1000007 #define setIO(s) freopen(s".in","r",stdin) using namespace std; const db pi=acos(-1.0); struct cp { db x,y; cp(db a=0,db b=0) { x=a,y=b; } cp operator+(const cp &b) const { return cp(x+b.x,y+b.y); } cp operator-(const cp &b) const { return cp(x-b.x,y-b.y); } cp operator*(const cp &b) const { return cp(x*b.x-y*b.y,x*b.y+y*b.x); } }A[N<<2],B[N<<2]; void FFT(cp *a,int len,int op) { for(int i=0,k=0;i<len;++i) { if(i>k) swap(a[i],a[k]); for(int j=len>>1;(k^=j)<j;j>>=1); } for(int l=1;l<len;l<<=1) { cp wn(cos(pi/l),op*sin(pi/l)),x,y; for(int i=0;i<len;i+=l<<1) { cp w(1,0); for(int j=0;j<l;++j) { x=a[i+j],y=w*a[i+j+l]; a[i+j]=x+y; a[i+j+l]=x-y; w=w*wn; } } } if(op==-1) { for(int i=0;i<len;++i) a[i].x/=len; } } int main() { // setIO("input"); int n,m,lim,x; scanf("%d%d",&n,&m); for(lim=1;lim<(n+m+1);lim<<=1); for(int i=0;i<=n;++i) { scanf("%d",&x),A[i].x=(db)x; } for(int i=0;i<=m;++i) { scanf("%d",&x),B[i].x=(db)x; } FFT(A,lim,1),FFT(B,lim,1); for(int i=0;i<lim;++i) A[i]=A[i]*B[i]; FFT(A,lim,-1); for(int i=0;i<=n+m;++i) { printf("%d ",(int)(A[i].x+0.5)); } return 0; }
任意模数NTT (MTT)
当模数不能写成 $a \times 2^k+1$ 的时候就需要用到拆系数 FFT (MTT)了.
令 $f(x)=wf_{0}(x)+f_{1}(x)$,$g(x)$ 同理.
然后 $f*g=(wf_{0}+f_{1})(wg_{0}+g_{1})=(f_{0}g_{0})w^2+(f_{0}g_{1}+f_{1}g_{0})w+f_{1}g_{1}$.
做 7 次 FFT 即可,这个 w 选 $2^{15}$ 就好了.
#include <cstdio> #include <vector> #include <cmath> #include <cstring> #include <algorithm> #define ll long long #define db long double #define pb push_back #define N 100007 #define setIO(s) freopen(s".in","r",stdin) using namespace std; const db pi=acos(-1.0); struct cp { db x,y; cp(db a=0,db b=0) { x=a,y=b; } cp operator+(const cp &b) const { return cp(x+b.x,y+b.y); } cp operator-(const cp &b) const { return cp(x-b.x,y-b.y); } cp operator*(const cp &b) const { return cp(x*b.x-y*b.y,x*b.y+y*b.x); } }f[2][N<<2],g[2][N<<2],ans[3][N<<2]; int A[N],B[N]; int lim; ll C[N]; void FFT(cp *a,int len,int op) { for(int i=0,k=0;i<len;++i) { if(i>k) swap(a[i],a[k]); for(int j=len>>1;(k^=j)<j;j>>=1); } for(int l=1;l<len;l<<=1) { cp wn(cos(pi/l),op*sin(pi/l)),x,y; for(int i=0;i<len;i+=l<<1) { cp w(1,0); for(int j=0;j<l;++j) { x=a[i+j],y=w*a[i+j+l]; a[i+j]=x+y; a[i+j+l]=x-y; w=w*wn; } } } } ll nor(db x,ll mod) { return (ll)((ll)(x/lim+0.5)%mod+mod)%mod; } void MTT(int *a,int n,int *b,int m,ll mod,ll *c) { for(lim=1;lim<=(n+m);lim<<=1); for(int i=0;i<=n;++i) { f[0][i].x=a[i]>>15; f[1][i].x=a[i]&0x7fff; } for(int i=0;i<=m;++i) { g[0][i].x=b[i]>>15; g[1][i].x=b[i]&0x7fff; } FFT(f[0],lim,1),FFT(f[1],lim,1); FFT(g[0],lim,1),FFT(g[1],lim,1); for(int i=0;i<lim;++i) { ans[0][i]=f[0][i]*g[0][i]; ans[1][i]=f[0][i]*g[1][i]+f[1][i]*g[0][i]; ans[2][i]=f[1][i]*g[1][i]; } FFT(ans[0],lim,-1); FFT(ans[1],lim,-1); FFT(ans[2],lim,-1); for(int i=0;i<=n+m;++i) { ll x=(nor(ans[0][i].x,mod)<<30ll)%mod; ll y=(nor(ans[1][i].x,mod)<<15ll)%mod; ll z=nor(ans[2][i].x,mod)%mod; c[i]=((x+y)%mod+z)%mod; } } int main() { //setIO("input"); int n,m; ll mod; scanf("%d%d%lld",&n,&m,&mod); for(int i=0;i<=n;++i) scanf("%d",&A[i]); for(int i=0;i<=m;++i) scanf("%d",&B[i]); MTT(A,n,B,m,mod,C); for(int i=0;i<=n+m;++i) printf("%lld ",C[i]); return 0; }
多项式求逆
公式 $B=2B'-AB'^2$
这里注意复制 $A$ 数组的时候不要复制多了,否则会让前面的 B 多算.
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define N 100008 #define ll long long #define pb push_back #define mod 998244353 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int A[N<<2],B[N<<2],f[N<<1],g[N<<1]; int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) { if(y&1) tmp=(ll)tmp*x%mod; } return tmp; } int get_inv(int x) { return qpow(x,mod-2); } void NTT(int *a,int len,int op) { for(int i=0,k=0;i<len;++i) { if(i>k) swap(a[i],a[k]); for(int j=len>>1;(k^=j)<j;j>>=1); } for(int l=1;l<len;l<<=1) { int wn=qpow(3,(mod-1)/(l<<1)); if(op==-1) { wn=get_inv(wn); } for(int i=0;i<len;i+=l<<1) { int w=1,x,y; for(int j=0;j<l;++j) { x=a[i+j],y=(ll)w*a[i+j+l]%mod; a[i+j]=(ll)(x+y)%mod; a[i+j+l]=(ll)(x-y+mod)%mod; w=(ll)w*wn%mod; } } } if(op==-1) { int iv=get_inv(len); for(int i=0;i<len;++i) { a[i]=(ll)a[i]*iv%mod; } } } void get_inv(int *a,int *b,int len,int la) { if(len==1) { b[0]=get_inv(a[0]); return; } get_inv(a,b,len>>1,la); int l=len<<1; for(int i=0;i<min(len,la);++i) A[i]=a[i]; for(int i=0;i<len>>1;++i) B[i]=b[i]; for(int i=min(len,la);i<l;++i) A[i]=0; for(int i=len>>1;i<l;++i) B[i]=0; NTT(A,l,1),NTT(B,l,1); for(int i=0;i<l;++i) { A[i]=(ll)A[i]*B[i]%mod*B[i]%mod; } NTT(A,l,-1); for(int i=0;i<len;++i) { b[i]=(ll)((ll)(b[i]<<1)%mod-A[i]+mod)%mod; } } int main() { // setIO("input"); int n,lim; scanf("%d",&n); for(int i=0;i<n;++i) { scanf("%d",&g[i]); } for(lim=1;lim<n;lim<<=1); get_inv(g,f,lim,n); for(int i=0;i<n;++i) { printf("%d ",f[i]); } return 0; }
分治NTT
#include <cstdio> #include <cstring> #include <algorithm> #define N 100008 #define ll long long #define mod 998244353 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int A[N<<2],B[N<<2],f[N],g[N]; int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) { if(y&1) tmp=(ll)tmp*x%mod; } return tmp; } int get_inv(int x) { return qpow(x,mod-2); } void NTT(int *a,int len,int op) { for(int i=0,k=0;i<len;++i) { if(i>k) swap(a[i],a[k]); for(int j=len>>1;(k^=j)<j;j>>=1); } for(int l=1;l<len;l<<=1) { int wn=qpow(3,(mod-1)/(l<<1)),x,y,w; if(op==-1) { wn=get_inv(wn); } for(int i=0;i<len;i+=l<<1) { w=1; for(int j=0;j<l;++j) { x=a[i+j],y=(ll)a[i+j+l]*w%mod; a[i+j]=(ll)(x+y)%mod; a[i+j+l]=(ll)(x-y+mod)%mod; w=(ll)w*wn%mod; } } } if(op==-1) { int iv=get_inv(len); for(int i=0;i<len;++i) { a[i]=(ll)a[i]*iv%mod; } } } void solve(int l,int r) { if(l==r) { return; } int mid=(l+r)>>1,lim,s1=0,s2=0; solve(l,mid); for(int i=l;i<=mid;++i) A[s1++]=f[i]; for(int i=0;i<=r-l;++i) B[s2++]=g[i]; for(lim=1;lim<(s1+s2);lim<<=1); for(int i=s1;i<lim;++i) A[i]=0; for(int i=s2;i<lim;++i) B[i]=0; NTT(A,lim,1),NTT(B,lim,1); for(int i=0;i<lim;++i) A[i]=(ll)A[i]*B[i]%mod; NTT(A,lim,-1); for(int i=mid+1;i<=r;++i) { (f[i]+=A[i-l])%=mod; } solve(mid+1,r); } int main() { // setIO("input"); int n; scanf("%d",&n); for(int i=1;i<n;++i) { scanf("%d",&g[i]); } f[0]=1; solve(0,n-1); for(int i=0;i<n;++i) { printf("%d ",f[i]); } return 0; }