仓鼠的DP课 学习笔记

Part1 一点杂题

agc034_e Complete Compress

题目链接

枚举最终这些棋子被移到了哪个节点,把这个终点拿出来作为根\(root\)

我们一次操作一定是把两个棋子各向根移动一步,这需要这两个棋子不是“祖先-后代”的关系。则一个节点\(u\)需要操作的次数是\(dis(u,root)\)。我们把每个初始时有棋子的节点\(u\)看做\(dis(u,root)\)个小石子。考虑根的每个“儿子的子树”。我们一次操作可以选择两个不同的子树,然后把它们的总石子数各减少\(1\)(这样可以保证选择的两个石子所在的节点不是“祖先-后代”关系)。目标是要让每个儿子的总石子数都为\(0\)

这就涉及到一个经典的模型。有\(n\)堆石子,每堆石子有\(a_i\)个。每次可以选择两个不同的堆,同时取走一枚石子。问能否通过若干次操作取完所有石子(即让每堆的石子总数都变为\(0\))。设所有堆的石子总数为\(sum\),最大的那一堆的石子数为\(max\)

  • \(max>sum-max\)时,显然无论怎么操作都无法取完所有石子。因为最大的那一堆一定会剩下\(max-(sum-max)\)个石子。
  • \(max\leq sum-max\)时,有方法可以取完所有石子(或者在\(sum\)为奇数时让石子只剩\(1\)个),构造如下:把所有石子排成一排,同一堆内的石子放在连续的一段。我们把石子按如下方法两两配对:\((1,1+\lfloor\frac{sum}{2}\rfloor),(2,2+\lfloor\frac{sum}{2}\rfloor),\dots,(\lfloor\frac{sum}{2}\rfloor,\lfloor\frac{sum}{2}\rfloor+\lfloor\frac{sum}{2}\rfloor)\)。显然,因为没有一堆石子的数量超过\(\lfloor\frac{sum}{2}\rfloor\),所以每一对石子都来自不同的堆。

在本题中,取一个石子就相当于向根走一步,因此,在第二种情况下,一定能把所有石子都移动到根。

考虑第一种情况,此时\(max\)的子树内,石子太多了,我们要让它内部消化掉一些。于是我们可以递归考虑\(max\)这个节点的所有儿子的子树。重复上面所描述的判断。

具体地,我们记一个\(f(u)\)表示在\(u\)的子树内,最多能进行多少次“把两个棋子同时向上移”的操作,也即最多能消掉多少对石子。显然,实际上我们可以选择进行\([0,f(u)]\)之间的任意次操作。

在以一个节点\(u\)为根的子树内,求出每个儿子的石子数。

  • \(max>sum-max\)时,设最大的那个儿子为\(v\),则\(f(u)=sum-max+\min(f(v),2max-sum)\)
  • \(max\leq sum-max\)时,\(f(u)=\lfloor\frac{sum}{2}\rfloor\)

如果根节点处的\(sum\)是偶数,且\(f(root)=\frac{sum}{2}\),则有解,否则以\(root\)为根时无解。有解时,答案对\(f(root)\)\(\min\)

这个DP是\(O(n)\)的。因为要枚举根。故总时间复杂度\(O(n^2)\)

参考代码:

//problem:agc034_e
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

/*  ------  by:duyi  ------  */ // myt天下第一
const int MAXN=2000,INF=1e9;
int n,s[MAXN+5],num[MAXN+5],f[MAXN+5];
char str[MAXN+5];
vector<int>G[MAXN+5];
void dfs(int u,int fa){
	int son=0;
	s[u]=0;num[u]=(str[u]=='1');
	for(int i=0;i<(int)G[u].size();++i){
		int v=G[u][i];
		if(v==fa)continue;
		dfs(v,u);
		num[u]+=num[v];
		s[v]+=num[v];
		s[u]+=s[v];
		if(!son||s[v]>s[son]){
			son=v;
		}
	}
	if(!son){f[u]=0;return;}
	
	if(s[son]<=s[u]-s[son]){
		f[u]=s[u]/2;
	}
	else{
		f[u]=s[u]-s[son]+min(f[son],(s[u]-(s[u]-s[son]))/2);
	}
}
int main() {
	scanf("%d%s",&n,str+1);
	for(int i=1,u,v;i<n;++i)scanf("%d%d",&u,&v),G[u].pb(v),G[v].pb(u);
	int ans=INF;
	for(int rt=1;rt<=n;++rt){
		dfs(rt,0);
		if(s[rt]&1)continue;
		if(f[rt]>=s[rt]/2){
			ans=min(ans,s[rt]/2);
		}
	}
	if(ans==INF)puts("-1");
	else printf("%d\n",ans);
	return 0;
}

CF908G New Year and Original Order

题目链接

考虑每个数码\(d(1\leq d\leq9)\)对答案的贡献。

考虑朴素的数位DP。设\(dp[i][j][k][0/1]\)表示从高到低考虑了前\(i\)位,有\(j\)个等于\(d\)的数位,有\(k\)个大于\(d\)的数位,当前数是否\(=X\),这些条件下的方案数。转移时枚举下一位填什么数。DP的复杂度为\(O(n^3\cdot 10)\),考虑优化。

\(c(d)\)\(d\)对答案贡献的系数,则\(c(d)\)形式上应该是\(\sum10^i\)。则\(ans=\sum_{d=1}^{9}d\cdot c(d)=\sum_{d=1}^{9}\sum_{i=d}^{9}c(i)\)。这是因为考虑每个\(c(i)\)会被计算\(i\)次,即在\(\leq i\)的每个\(d\)的位置都被计算到一次。

考虑枚举\(d\),求\(\sum_{i=d}^{9}c(i)\)。这相当于求所有\(\geq d\)的数位的\(10^i\)之和。而按题目要求把数位排序后,\(\geq d\)的数位一定是从最低位开始的连续的一段。于是我们只要记录\(\geq d\)的数位有多少个即可。设\(dp[i][j][0/1]\)表示考虑了前\(i\)位,\(\geq d\)的数位有\(j\)个的方案数。DP的复杂度变为\(O(n^2\cdot 10)\)。因为要枚举\(d\),故总时间复杂度\(O(n^2\cdot 10\cdot 9)\)

参考代码:

//problem:CF908G
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

/*  ------  by:duyi  ------  */ // myt天下第一
const int MAXN=700,MOD=1e9+7;
inline int mod1(int x){return x<MOD?x:x-MOD;}
inline int mod2(int x){return x<0?x+MOD:x;}
inline void add(int &x,int y){x=mod1(x+y);}
inline void sub(int &x,int y){x=mod2(x-y);}

char s[MAXN+5];
int n,dp[MAXN+5][MAXN+5][2];
int solve(int d){
	memset(dp,0,sizeof(dp));
	dp[0][0][1]=1;
	for(int i=0;i<n;++i){
		// dp[i] -> dp[i+1]
		for(int j=0;j<=i;++j){
			for(int t=0;t<=1;++t)if(dp[i][j][t]){
				for(int x=0;x<=(t?s[i+1]-'0':9);++x){
					add(dp[i+1][j+(x>=d)][t&&(x==(s[i+1]-'0'))],dp[i][j][t]);
				}
			}
		}
	}
	int cur=1,sum=0,res=0;
	for(int j=1;j<=n;++j){
		add(sum,cur);
		add(res,(ll)sum*mod1(dp[n][j][0]+dp[n][j][1])%MOD);
		cur=10LL*cur%MOD;
	}
	return res;
}
int main() {
	scanf("%s",s+1);n=strlen(s+1);
	int ans=0;
	for(int i=1;i<=9;++i){
		add(ans,solve(i));
	}
	printf("%d\n",ans);
	return 0;
}

agc024_f Simple Subsequence Problem

题目链接

因为所有可能的答案只有\(2^{n+1}-1\)种,故考虑求出每个长度\(\leq n\)的01串分别是多少个给定串的子序列。

考虑识别一个串\(A\)是不是另一个串\(B\)的子序列,我们可以做一个简单的贪心。维护一个指向\(B\)的下标的“指针”\(p\),初始时是\(0\)。对于\(A\)的每一位,把\(p\)移到\(p\)后面第一个等于\(A\)的这一位的位置。

考虑用一个自动机来描述这个贪心的过程。自动机的一个节点\((S,T)\)表示当前已经匹配好的\(A\)的前缀是\(S\),剩下的部分是\(T\)的一个子序列。当接下来要匹配\(A\)的某一位时,如果这一位是1,则转移到\((S+'1',T_1)\),否则转移到\((S+'0',T_0)\)。其中\(T_c\)表示串\(T\)的第一个\(c\)后面的部分。如\(T=\)00110时,\(T_1=\)10\(T_0=\)0110

这条路径的起点为\((\emptyset,B)\),终点为\((A,\emptyset)\)\(\emptyset\)表示空串)。

我们以所有题目给出的集合中的串为起点,就可以得到一张图。因为在走的过程中\(S\)长度严格递增,\(T\)长度严格递减,所以一定不会走出环。因此这个图是一个DAG。

又因为我们匹配的过程是基于“每次走到接下来第一个能匹配的位置”的贪心,因此从某个起点出发,识别一个串\(A\)时走的路径是唯一的。也即从每个起点到点\((A,\emptyset)\),要么没有路径,要么有一条唯一的路径

因此,我们可以在DAG上做DP。求出到每个终点的路径数,就是这个串是多少个起点的子序列。

\(dp(S,T)\)表示走到DAG上\((S,T)\)这个节点的方案数。对于每个题目给出的集合中的串\(s\),初始化\(dp(\emptyset,s)=1\)。在具体实现中,我们开一个二维数组,\(dp[i][j]\),其中\(i\)表示\(T\)这个串的长度,\(j\)\(S+T\)这个串的二进制状压。因为\(S+T\)这个串可能有前导零,我们可以在最高位的下一位放一个\(1\),表示这里是最高位。

DP时先枚举\(T\)的长度\(i\)(因为一定是从长到段转移),再枚举\(S+T\)的长度\(j(j\geq i)\),再枚举\(S+T\)这个串\(k\)。则当前的状态就是\(dp[i][k|(1<<j)]\)。转移有三种:直接结束(走到节点\((S,\emptyset)\)),匹配到下一个\(1\),或匹配到下一个\(0\)

因为总状态数是\(O(2^nn)\)的。转移时找下一位\(0/1\)的复杂度为\(O(n)\),故总时间复杂度\(O(2^nn^2)\)(常数小可过)。如果预处理每个长度\(\leq n\)的二进制串的下一个\(0/1\)在哪一位,则可以优化到\(O(2^nn)\)

参考代码:

//problem:agc024_f
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

/*  ------  by:duyi  ------  */ // myt天下第一
int n,K,dp[21][1<<21];
char s[(1<<20)+5];
int main() {
	scanf("%d%d",&n,&K);
	for(int i=0;i<=n;++i){
		scanf("%s",s);
		for(int j=0;j<(1<<i);++j)if(s[j]=='1'){
			dp[i][j|(1<<i)]=1;
		}
	}
	for(int i=n;i>=1;--i){//未确定部分的长度
		for(int j=i;j<=n;++j){//整个串的长度
			int b=(1<<j);
			for(int k=0;k<b;++k){
				int tmp=dp[i][k|b],p=-1,q=-1;if(!tmp)continue;
				for(int t=i-1;t>=0;--t)if(((k>>t)&1)==1){p=t;break;}
				for(int t=i-1;t>=0;--t)if(((k>>t)&1)==0){q=t;break;}
				dp[0][(k+b)>>i]+=tmp;
				if(p!=-1)dp[p][((((k>>i)<<1)|1)<<p)|(((1<<p)-1)&k)|(1<<(j-(i-p)+1))]+=tmp;
				if(q!=-1)dp[q][(((k>>i)<<1)<<q)|(((1<<q)-1)&k)|(1<<(j-(i-q)+1))]+=tmp;
			}
		}
	}
	for(int i=n;i>=0;--i){
		for(int j=0;j<(1<<i);++j){
			if(dp[0][j|(1<<i)]>=K){
				for(int k=i-1;k>=0;--k)if((j>>k)&1)putchar('1');else putchar('0');
				puts("");
				return 0;
			}
		}
	}
	assert(0);
}

nflsoj49 【清华集训2017】某位歌姬的故事

题目链接-nflsoj49

题目链接-loj2331

对序列的每个位置\(i\),我们维护一个\(up_i\),表示这个位置最大能填几。那么一个限制\((l,r,w)\),就相当于让\(i\in[l,r]\)的每个\(up_i\)的值对\(w\)\(\min\)

当然,还可能有一些位置从始至终未被任何一个限制覆盖到。记这样的位置有\(cnt\)个,则我们最后把答案乘以\(A^{cnt}\)即可。

题目里的每个限制\((l,r,w)\),相当于如下两条要求:

  • \(\forall i \in[l,r]\ a_i\leq w\)

  • \(\exist i\in[l,r]\ a_i=w\)

如果让所有\(a_i\)做到\(a_i\leq up_i\),则第一条限制就已经满足了。考虑如何满足第二条限制。

考虑每个\(w\)。我们发现\(\exist i\in[l,r]\ a_i=w\)只能由\(up_i=w\)\(i\)来实现。因为\(up_i<w\)\(i\)显然无法实现(否则与第一条要求矛盾);而\(up_i>w\)\(i\)说明这个位置根本没被\([l,r]\)覆盖到。

于是,我们枚举每个\(w\),把\(up_i=w\)的这些点单独拿出来做DP。设\(dp[i][j]\)表示考虑了前\(i\)\(up_i=w\)的位置,最后一个满足\(a_i=w\)(取到了这个上界)的位置在\(j\)时的方案数。对于每个位置\(i\),我们可以维护一个\(L[i]\)表示它的上一个被覆盖的位置最远可以在哪。则从\(dp[i][j]\)转移到\(dp[i+1]\)时,转移有两种:

  • \(j\geq L[i+1]\),则可以从\(dp[i][j]\)转移到\(dp[i+1][j]\)
  • 在任何情况下,我们都能从\(dp[i][j]\)转移到\(dp[i+1][i+1]\)

\(L[i]\):我们先初始化所有\(L[i]\)\(0\)。然后对于每个限制\((l,r,w)\),让\(L[r]\)\(l\)\(\max\)即可。

将所有位置都离散化后,一次DP的时间复杂度为\(O(Q^2)\)。故总时间复杂度\(O(Q^3)\)

参考代码:

//problem:nflsoj49
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
	if(S==T){
		T=(S=buf)+fread(buf,1,MAXN,stdin);
		if(S==T)return EOF;
	}
	return *S++;
}
}//namespace Fread
#ifdef ONLINE_JUDGE
	#define getchar Fread::getchar
#endif
inline int read(){
	int f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline ll readll(){
	ll f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
/*  ------  by:duyi  ------  */ // myt天下第一
const int MAXN=500,INF=0x3f3f3f3f,MOD=998244353;
inline int mod1(int x){return x<MOD?x:x-MOD;}
inline int mod2(int x){return x<0?x+MOD:x;}
inline void add(int &x,int y){x=mod1(x+y);}
inline void sub(int &x,int y){x=mod2(x-y);}
inline int pow_mod(int x,int i){int y=1;while(i){if(i&1)y=(ll)y*x%MOD;x=(ll)x*x%MOD;i>>=1;}return y;}

int len,n,A,a[MAXN*2+5],pos[MAXN*2+5],val[MAXN*2+5],cnt,cnt_w;
struct Limits{int l,r,m;}q[MAXN+5];
bool fail;
int solve(int w){
	static int p[MAXN*2+5],L[MAXN*2+5],dp[MAXN*2+5][MAXN*2+5];
	int cnt_p=0;
	for(int i=1;i<cnt;++i)if(a[i]==w)p[++cnt_p]=i,L[cnt_p]=0;
	for(int i=1;i<=n;++i)if(q[i].m==w){
		int l=lob(p+1,p+cnt_p+1,q[i].l)-p,
			r=lob(p+1,p+cnt_p+1,q[i].r)-p-1;
		if(l>r){fail=1;return 0;}
		L[r]=max(L[r],l);
	}
	memset(dp,0,sizeof(dp));
	dp[0][0]=1;
	for(int i=1;i<=cnt_p;++i){
		int v1=pow_mod(w,pos[p[i]+1]-pos[p[i]]),v2=pow_mod(w-1,pos[p[i]+1]-pos[p[i]]);
		for(int j=0;j<i;++j){
			if(j>=L[i])add(dp[i][j],(ll)dp[i-1][j]*v2%MOD);
			add(dp[i][i],(ll)dp[i-1][j]*mod2(v1-v2)%MOD);
		}
	}
	int res=0;
	for(int j=0;j<=cnt_p;++j)add(res,dp[cnt_p][j]);
	return res;
}
int main() {
	int Testcases=read();while(Testcases--){
		len=read();n=read();A=read();
		cnt=cnt_w=0;
		memset(a,0x3f,sizeof(a));
		for(int i=1;i<=n;++i){
			q[i].l=read(),q[i].r=read()+1,q[i].m=read();
			pos[++cnt]=q[i].l,pos[++cnt]=q[i].r,val[++cnt_w]=q[i].m;
		}
		pos[++cnt]=1;pos[++cnt]=len+1;
		sort(pos+1,pos+cnt+1);
		cnt=unique(pos+1,pos+cnt+1)-(pos+1);
		
		sort(val+1,val+cnt_w+1);
		cnt_w=unique(val+1,val+cnt_w+1)-(val+1);
		
		for(int i=1;i<=n;++i){
			q[i].l=lob(pos+1,pos+cnt+1,q[i].l)-pos;
			q[i].r=lob(pos+1,pos+cnt+1,q[i].r)-pos;
			for(int j=q[i].l;j<q[i].r;++j)a[j]=min(a[j],q[i].m);
		}
		fail=0;
		int ans=1;
		for(int i=1;i<=cnt_w;++i){
			ans=(ll)ans*solve(val[i])%MOD;
			if(fail)break;
		}
		if(fail){puts("0");continue;}
		for(int i=1;i<cnt;++i){
			if(a[i]==INF)ans=(ll)ans*pow_mod(A,pos[i+1]-pos[i])%MOD;
		}
		printf("%d\n",ans);
	}
	return 0;
}

Part2 笛卡尔树DP

loj2688 「POI2015」洗车 Car washes

题目链接

先把权值离散化。

考虑原序列的笛卡尔树(以最小值为根)。笛卡尔树的每个子树对应原序列的一个区间。本题中,我们用一个区间来表示一棵子树,在写法上类似于区间DP。对于笛卡尔树上一个子树\([l,r]\),设它的最小值所在位置为\(p\)(也就是子树\([l,r]\)的根节点为\(p\)),最小值为\(v\),则我们在当前子树上统计所有完全包含在本区间内,且经过根节点的车产生的贡献。即\(l\leq a_i\leq p\leq b_i\leq r\)的这些车。把这个贡献记为\(cost(l,r,p,v)\)

\(dp[l][r][v]\)表示整个\([l,r]\)子树内,最小值为\(v\)时,产生的贡献的最大值。转移时枚举最小值的位置\(p\),则:

\[dp[l][r][v]=\max_{p=l}^{r}\{dp[l][p-1][x\geq v]+dp[p+1][r][y\geq v]+cost(l,r,p,v)\} \]

可以发现,\(cost(l,r,p,v)\)就等于满足\(l\leq a_i\leq p\leq b_i\leq r\)\(c_i\geq v\)的车的数量,乘以\(v\)。而这个数量可以做三维前缀和\(O(1)\)查询。我们求完每个子树\(dp[l][r][\dots]\)后,预处理出DP数组关于最后一维的后缀最大值,就可以对每个 \(p\) \(O(1)\) 计算了。

时间复杂度\(O(n^3m)\)

参考代码:

//problem:loj2688
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
	if(S==T){
		T=(S=buf)+fread(buf,1,MAXN,stdin);
		if(S==T)return EOF;
	}
	return *S++;
}
}//namespace Fread
#ifdef ONLINE_JUDGE
	#define getchar Fread::getchar
#endif
inline int read(){
	int f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline ll readll(){
	ll f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
/*  ------  by:duyi  ------  */ // myt天下第一
const int MAXN=50,MAXM=4000;
int n,m,a[MAXM+5],b[MAXM+5],c[MAXM+5],val[MAXM+5],cnt,s[MAXN+5][MAXN+5][MAXM+5],f[MAXN+5][MAXN+5][MAXM+5];
pii g[MAXN+5][MAXN+5][MAXM+5];
inline int cost(int i,int p,int j,int v){
	/*
	i<=a<=p<=b<=j
	c>=v
	*/
	//int res=0;for(int t=1;t<=m;++t)if(i<=a[t] && a[t]<=p && p<=b[t] && b[t]<=j && c[t]>=v)res++;return res;
	return (s[p][j][cnt]-s[p][p-1][cnt]-s[i-1][j][cnt]+s[i-1][p-1][cnt])
			-(s[p][j][v-1]-s[p][p-1][v-1]-s[i-1][j][v-1]+s[i-1][p-1][v-1]);
}
struct node{
	int lv,rv,p;
	node(){}
	node(int _lv,int _rv,int _p){lv=_lv,rv=_rv,p=_p;}
}tr[MAXN+5][MAXN+5][MAXM+5];
int res[MAXN+5];
void get_ans(int l,int r,int v,int p){
	//cout<<l<<" "<<r<<" "<<v<<" "<<p<<endl;
	//assert(v);assert(p);
	res[p]=v;
	if(p!=l){
		get_ans(l,p-1,tr[l][r][v].lv,tr[l][p-1][tr[l][r][v].lv].p);
	}
	if(p!=r){
		get_ans(p+1,r,tr[l][r][v].rv,tr[p+1][r][tr[l][r][v].rv].p);
	}
}
int main() {
	//freopen("1.in","r",stdin);
	n=read();m=read();
	for(int i=1;i<=m;++i)a[i]=read(),b[i]=read(),c[i]=read(),val[i]=c[i];
	sort(val+1,val+m+1);cnt=unique(val+1,val+m+1)-(val+1);
	for(int i=1;i<=m;++i)c[i]=lob(val+1,val+cnt+1,c[i])-val,s[a[i]][b[i]][c[i]]++;//离散化
	
	for(int i=1;i<=n;++i)for(int j=1;j<=n;++j)for(int k=1;k<=cnt;++k)s[i][j][k]+=s[i][j][k-1];
	for(int i=1;i<=n;++i)for(int j=1;j<=n;++j)for(int k=1;k<=cnt;++k)s[i][j][k]+=s[i][j-1][k];
	for(int i=1;i<=n;++i)for(int j=1;j<=n;++j)for(int k=1;k<=cnt;++k)s[i][j][k]+=s[i-1][j][k];//三维前缀和
	
	for(int i=1;i<=n;++i){
		for(int j=1;j<=cnt;++j)f[i][i][j]=cost(i,i,i,j)*val[j],tr[i][i][j].p=i;
		for(int j=cnt;j>=1;--j)g[i][i][j]=max(g[i][i][j+1],mk(f[i][i][j],j));
	}
	for(int len=2;len<=n;++len){
		for(int i=1;i+len-1<=n;++i){
			int j=i+len-1;
			for(int p=i;p<=j;++p){
				for(int v=1;v<=cnt;++v){
					int l=(p==i?0:g[i][p-1][v].fst);
					int r=(p==j?0:g[p+1][j][v].fst);
					if(l+r+cost(i,p,j,v)*val[v]>=f[i][j][v]){
						f[i][j][v]=l+r+cost(i,p,j,v)*val[v];
						int _l=(p==i?0:g[i][p-1][v].scd);
						int _r=(p==j?0:g[p+1][j][v].scd);
						tr[i][j][v]=node(_l,_r,p);
					}
					//f[i][j][v]=max(f[i][j][v],l+r+cost(i,p,j,v)*val[v]);
				}
			}
			for(int v=cnt;v>=1;--v){
				g[i][j][v]=max(g[i][j][v+1],mk(f[i][j][v],v));
			}
		}
	}
	int ans=0,v=0;
	for(int i=1;i<=cnt;++i)if(f[1][n][i]>=ans)ans=f[1][n][i],v=i;
	printf("%d\n",ans);
	get_ans(1,n,v,tr[1][n][v].p);
	for(int i=1;i<=n;++i)printf("%d ",val[res[i]]);puts("");
	return 0;
}

bzoj2616 SPOJ PERIODNI

题目链接

定义柱状图的一个子图是柱状图横坐标的一个区间,截掉了下面的一定高度得到的柱状图,要求子图中每列高度至少为\(1\)。按定义,我们从完整的原图开始递归,每次找出当前区间中高度的最小值,把高度最小的这些列去掉之后,分出若干个更小的区间,将它们的下面(等于最小列高度的部分)砍掉,继续递归这些小子图。递归的边界是如果当前区间内所有列高度相同,则不再继续递归。按此方法,显然可以划分出\(O(n)\)个子图,这些子图之间形成了树状的结构。我们可以把这个结构理解为一种广义的“笛卡尔树”,虽然它并不是二叉树。

注意到:原序列的一个区间、柱状图的一个子图、树上的一个节点这三个概念现在是等价的。

我们在这个树形结构上DP。用区间\([l,r]\)来表示树上的一个节点(也就是一个子图)。则这个子图可以被划分为下方的一个极大完整矩形,和上面的若干小子图(也就是它的儿子)。设\(dp[l,r][k]\)表示考虑了\([l,r]\)这个子图,在里面放置了\(k\)个车的方案数。

先递归所有儿子,显然这些儿子子图之间互不影响,所以把它们用背包合并起来。

然后考虑下方的极大完整矩形。假设这个矩形大小为\(x\times y\)。如果我们在所有儿子中一共放了\(k\)个车,则会有\(k\)列是不能再放的了。那么在下面再放\(i\)个车方案数就是\({x\choose i}{y-k\choose i}i!\),我们枚举\(i,k\),用类似于卷积的方法暴力合并即可。

本题的关键思想是:先算上方小区间(儿子),再算下方大区间(父亲),因为大区间里一定包括了小区间的所有列,这样计算时直接减去这\(k\)个用过的列即可。而不用关心具体是那些列被用过了,这就大大提高了效率。

背包暴力合并是\(O(n^2)\)的,因为共有\(O(n)\)个子图,故总时间复杂度\(O(n^3)\)

参考代码:

//problem:bzoj2616
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
	if(S==T){
		T=(S=buf)+fread(buf,1,MAXN,stdin);
		if(S==T)return EOF;
	}
	return *S++;
}
}//namespace Fread
#ifdef ONLINE_JUDGE
	#define getchar Fread::getchar
#endif
inline int read(){
	int f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline ll readll(){
	ll f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
/*  ------  by:duyi  ------  */ // myt天下第一
const int MAXN=500,MOD=1000000007,MAXH=1000000;
inline int mod1(int x){return x<MOD?x:x-MOD;}
inline int mod2(int x){return x<0?x+MOD:x;}
inline void add(int &x,int y){x=mod1(x+y);}
inline void sub(int &x,int y){x=mod2(x-y);}
inline int pow_mod(int x,int i){int y=1;while(i){if(i&1)y=(ll)y*x%MOD;x=(ll)x*x%MOD;i>>=1;}return y;}
int fac[MAXH+5],invf[MAXH+5];
inline int comb(int n,int k){
	if(n<k)return 0;
	return (ll)fac[n]*invf[k]%MOD*invf[n-k]%MOD;
}

int n,K,a[MAXN+5],dp[MAXN+5][MAXN+5],id[MAXN+5][MAXN+5],cnt,tmp[MAXN+5];
void solve(int l,int r,int h){
	//cout<<l<<" "<<r<<endl;
	int x;assert(!id[l][r]);x=id[l][r]=++cnt;
	int mn=MAXH+1;
	vector<int>pos;
	for(int i=l;i<=r;++i){
		if(a[i]<mn)mn=a[i],pos.clear();
		if(a[i]==mn)pos.pb(i);
	}
	assert(mn>h);
	if(pos.size()==r-l+1){
		for(int i=0;i<=min(mn-h,r-l+1);++i)dp[x][i]=(ll)comb(mn-h,i)*comb(r-l+1,i)%MOD*fac[i]%MOD;
		return;
	}
	pos.pb(r+1);
	int lst=l-1;
	dp[x][0]=1;
	for(int i=0;i<(int)pos.size();++i){
		if(pos[i]>lst+1){
			solve(lst+1,pos[i]-1,mn);
			int y=id[lst+1][pos[i]-1];
			for(int j=0;j<=r-l+1;++j){
				tmp[j]=0;
				for(int k=0;k<=j;++k){
					add(tmp[j],(ll)dp[x][k]*dp[y][j-k]%MOD);
				}
			}
			for(int j=0;j<=r-l+1;++j)dp[x][j]=tmp[j];
		}
		lst=pos[i];
	}
	memset(tmp,0,sizeof(tmp));
	for(int i=0;i<=min(mn-h,r-l+1);++i){
		for(int j=0;j+i<=r-l+1;++j){
			add(tmp[i+j],(ll)dp[x][j]*comb(r-l+1-j,i)%MOD*comb(mn-h,i)%MOD*fac[i]%MOD);
		}
	}
	for(int i=0;i<=r-l+1;++i)dp[x][i]=tmp[i];
}
int main() {
	fac[0]=1;
	for(int i=1;i<=MAXH;++i)fac[i]=(ll)fac[i-1]*i%MOD;
	invf[MAXH]=pow_mod(fac[MAXH],MOD-2);
	for(int i=MAXH-1;i>=0;--i)invf[i]=(ll)invf[i+1]*(i+1)%MOD;
	
	n=read();K=read();if(K>n){cout<<0<<endl;return 0;}
	for(int i=1;i<=n;++i)a[i]=read();
	solve(1,n,0);
//	while(1){
//		int l=read(),r=read(),x=read();
//		cout<<dp[id[l][r]][x]<<endl;
//	}
	cout<<dp[id[1][n]][K]<<endl;
	return 0;
}

agc026_d Histogram Coloring

题目链接

考虑一个完整矩形的情况。如果第一行是01交替出现的,则下一行的每个位置可以和第一行相同,也可以不同;否则每个位置必须与前一行不同

和上一题类似地,我们定义柱状图的子图。以及这些子图间构成了树形结构。这里不再赘述。

\(dp1[H]\)表示子图\(H\)的最下面一行按01间隔染色时整个子图的合法染色方案数,\(dp2[H]\)表示子图\(H\)最下面一行没有限制时整张图的合法染色方案数。设\(H\)下方的极大完整矩形宽度为\(x\)(也就是\(H\)里最低的一列高度为\(x\)),\(H\)中有\(w\)列高度为\(x\)(并列最低),将\(H\)从下面截取\(x\)的高度后得到若干个子图(也就是\(H\)的儿子)为\(c_1,c_2,\dots,c_k\)

\(dp1[H]=2^x\prod_{i=1}^{k}dp1[c_i]\),因为第一行交替染色时图中某一列可以任意染色,染好这一列后其他位置的颜色都可以确定。

\(dp2[H]=2^w\prod_{i=1}^{k}(dp1[c_i]+dp2[c_i])+(2^x-2)\prod_{i=1}^{k}dp1[c_i]\)。后半部分是钦定每一列都不交替染色,则一列有\((2^x-2)\)种染法,此时第一行必须01交替,因此截下的子图\(c_i\)的第一行也是交替的(\(dp1\))。前半部分让每一列都01交替,则要么0开头,要么1开头,有\(2\)种染法。因此不在\(c_i\)下方的部分方案数为\(2^w\)。如果\(c_i\)第一行01交替的话\(c_i\)正下方的矩形部分会有\(2\)种染色方法,所以要\(dp1[c_i]\)要被计算\(2\)次(其中有一次已经包含在\(dp2[ci]\)中了)。

时间复杂度\(O(n^2)\)

参考代码(写的略显繁琐。但本题实现不难,读者可以尝试自己实现):

//problem:agc026_d
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

inline int read(){
	int f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline ll readll(){
	ll f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

const int MAXN=1005,MOD=1e9+7;
inline int pow_mod(int x,int i){int y=1;while(i){if(i&1)y=(ll)y*x%MOD;x=(ll)x*x%MOD;i>>=1;}return y;}
int n,h[MAXN];
vector<int>vec,w[MAXN];
int solve1(int l,int r,int k,int low){
	//cout<<"dp1 "<<l<<" "<<r<<" "<<k<<endl;
	int mx=0,mn=2e9;
	for(int i=l;i<=r;++i)mx=max(mx,h[i]),mn=min(mn,h[i]);
	if(mn>k)return solve1(l,r,mn,low);
	if(mx==k&&mx==mn)return pow_mod(2,vec[k]-low);
	
	int ans=1,st=lob(w[k].begin(),w[k].end(),l)-w[k].begin(),ed=lob(w[k].begin(),w[k].end(),r)-w[k].begin();
	ed=min(ed,(int)w[k].size()-1);
	if(w[k][ed]>r)ed--;
	int lst=l;
	for(int i=st;i<=ed;++i){
		assert(w[k][i]>=l&&w[k][i]<=r);
		if(lst<=w[k][i]-1)ans=(ll)ans*solve1(lst,w[k][i]-1,k+1,vec[k])%MOD;
		lst=w[k][i]+1;
	}
	if(lst<=r)ans=(ll)ans*solve1(lst,r,k+1,vec[k])%MOD;
	ans=(ll)ans*pow_mod(2,vec[k]-low)%MOD;
	return ans;
}
int solve2(int l,int r,int k,int low){
	//cout<<l<<" "<<r<<" "<<k<<endl;
	int mx=0,mn=2e9;
	for(int i=l;i<=r;++i)mx=max(mx,h[i]),mn=min(mn,h[i]);
	if(mn>k)return solve2(l,r,mn,low);
	if(mx==k&&mx==mn)return (pow_mod(2,r-l+1)-2+pow_mod(2,vec[k]-low))%MOD;
	
	int st=lob(w[k].begin(),w[k].end(),l)-w[k].begin(),ed=lob(w[k].begin(),w[k].end(),r)-w[k].begin();
	ed=min(ed,(int)w[k].size()-1);
	if(w[k][ed]>r)ed--;
	vector<int>dp1,dp2;
	int lst=l,t=0;
	for(int i=st;i<=ed;++i){
		assert(w[k][i]>=l&&w[k][i]<=r);
		t++;
		if(lst<=w[k][i]-1)dp1.pb(solve1(lst,w[k][i]-1,k+1,vec[k])),dp2.pb(solve2(lst,w[k][i]-1,k+1,vec[k]));
		lst=w[k][i]+1;
	}
	if(lst<=r)dp1.pb(solve1(lst,r,k+1,vec[k])),dp2.pb(solve2(lst,r,k+1,vec[k]));
	int ans1=1,ans2=1;
	for(int i=0;i<(int)dp1.size();++i)ans1=(ll)ans1*(dp1[i]+dp2[i])%MOD,ans2=(ll)ans2*dp1[i]%MOD;
	ans1=(ll)ans1*pow_mod(2,t)%MOD;
	ans2=(ll)ans2*(pow_mod(2,vec[k]-low)-2)%MOD;
	return (ans1+ans2)%MOD;
}
int main() {
//	freopen("data.txt","r",stdin);
	n=read();
	for(int i=1;i<=n;++i)h[i]=read(),vec.pb(h[i]);
	vec.pb(0);
	sort(vec.begin(),vec.end());
	vec.erase(unique(vec.begin(),vec.end()),vec.end());
	for(int i=1;i<=n;++i)h[i]=lob(vec.begin(),vec.end(),h[i])-vec.begin(),w[h[i]].pb(i);
	cout<<solve2(1,n,1,0)<<endl;
	return 0;
}

loj2743 「JOI Open 2016」摩天大楼

题目链接

考虑把值按从大到小的顺序插入序列中。那么每插入一个元素,其所在的连续段就是以其为根的笛卡尔树上的子树。

\(a\)序列按从大到小排序(现在 \(a_i \geq a_{i + 1}\))。我们把绝对值这个贡献做差分,然后摊到每次加入的数上。比方说加入\(a_i\)后当前序列中共有\(j\)个连续段,则差分的贡献就是\(2\cdot j\cdot(a_i - a_{i + 1})\),因为每个连续段两侧会有两个必定\(\leq a_{i+1}\)的数。特别地,当有某个连续段位于最左边或最右边时,要特判。

于是可以设计DP,令\(dp[i][j][k][x\in\{0,1\}][y\in\{0,1\}]\)表示考虑了前\(i\)大的数,当前序列中共有\(j\)个连续段,当前的差分总贡献是\(k\),序列的左、右端点有没有被加入。根据这些信息可以计算出,被拉大差值的两个相邻元素的数目,也就是对 \(k\) 新增的差分贡献。转移时,考虑新加入的元素,分“单独成为一段”、“紧贴着某个原来的段”、“合并了两个原来的段”三种情况讨论。

时间复杂度\(O(n^2L)\)

参考代码:

//problem:loj2743
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
	if(S==T){
		T=(S=buf)+fread(buf,1,MAXN,stdin);
		if(S==T)return EOF;
	}
	return *S++;
}
}//namespace Fread
#ifdef ONLINE_JUDGE
	#define getchar Fread::getchar
#endif
inline int read(){
	int f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline ll readll(){
	ll f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
/*  ------  by:duyi  ------  */ // myt天下第一
const int MAXN=100,MAXL=1000,MOD=1e9+7;
inline int mod1(int x){return x<MOD?x:x-MOD;}
inline int mod2(int x){return x<0?x+MOD:x;}
inline void add(int &x,int y){x=mod1(x+y);}
inline void sub(int &x,int y){x=mod2(x-y);}
inline int pow_mod(int x,int i){int y=1;while(i){if(i&1)y=(ll)y*x%MOD;x=(ll)x*x%MOD;i>>=1;}return y;}

int n,L,a[MAXN+5],dp[MAXN+5][MAXN+5][MAXL+5][2][2];

int main() {
	n=read();L=read();
	for(int i=1;i<=n;++i)a[i]=read();
	sort(a+1,a+n+1);reverse(a+1,a+n+1);
	if(n==1){cout<<1<<endl;return 0;}
	dp[0][0][0][0][0]=1;
	for(int i=0;i<n;++i){
		for(int j=0;j<=i;++j){
			for(int k=0;k<=L;++k){
				for(int tl=0;tl<=1;++tl){
					for(int tr=0;tr<=1;++tr)if(dp[i][j][k][tl][tr]){
						//把第i+1个数放进来
						int v=dp[i][j][k][tl][tr];
						//case1:单独一段
						if(k+((j+1)*2-tl-tr)*(a[i+1]-a[i+2])<=L)if(j+(!tl)+(!tr)+i+1<=n)
							add(dp[i+1][j+1][k+((j+1)*2-tl-tr)*(a[i+1]-a[i+2])][tl][tr],(ll)v*(j-1+(!tl)+(!tr))%MOD);
						if(!tl)if(k+((j+1)*2-1-tr)*(a[i+1]-a[i+2])<=L)if(j+(!tr)+i+1<=n)
							add(dp[i+1][j+1][k+((j+1)*2-1-tr)*(a[i+1]-a[i+2])][1][tr],v);
						if(!tr)if(k+((j+1)*2-tl-1)*(a[i+1]-a[i+2])<=L)if(j+(!tl)+i+1<=n)
							add(dp[i+1][j+1][k+((j+1)*2-tl-1)*(a[i+1]-a[i+2])][tl][1],v);
						//case2:紧贴着一段
						if(j&&k+(j*2-tl-tr)*(a[i+1]-a[i+2])<=L)if(j-1+(!tl)+(!tr)+i+1<=n)
							add(dp[i+1][j][k+(j*2-tl-tr)*(a[i+1]-a[i+2])][tl][tr],(ll)v*(j*2-tl-tr)%MOD);
						if(!tl)if(j&&k+(j*2-1-tr)*(a[i+1]-a[i+2])<=L)if(j-1+(!tr)+i+1<=n)
							add(dp[i+1][j][k+(j*2-1-tr)*(a[i+1]-a[i+2])][1][tr],v);
						if(!tr)if(j&&k+(j*2-tl-1)*(a[i+1]-a[i+2])<=L)if(j-1+(!tl)+i+1<=n)
							add(dp[i+1][j][k+(j*2-tl-1)*(a[i+1]-a[i+2])][tl][1],v);
						//case3:合并两段
						if(j>=2&&k+((j-1)*2-tl-tr)*(a[i+1]-a[i+2])<=L)if(j-2+(!tl)+(!tr)+i+1<=n)
							add(dp[i+1][j-1][k+((j-1)*2-tl-tr)*(a[i+1]-a[i+2])][tl][tr],(ll)v*(j-1)%MOD);
					}
				}
			}
		}
	}
	int ans=0;
	for(int i=0;i<=L;++i)add(ans,dp[n][1][i][1][1]);
	cout<<ans<<endl;
	return 0;
}

loj3228 「USACO 2019.12 Platinum」Tree Depth

题目链接

考虑一个经典模型:求\(1\dots n\)的恰有\(K\)个逆序对的排列数。通常有两种做法(它们本质相同):

  • 做法一:从小到大枚举每个数,把当前数\(i\)插入排列中。此时比\(i\)小的所有数的相对位置关系已经确定,而\(i\)可以插到它们之间的任意位置。所以插入\(i\)新增的逆序对数可以为\([0,i-1]\)中任意整数。
  • 做法二:从左到右考虑每个位置。对于当前考虑的位置\(i\),前\(i-1\)个位置上值的相对大小关系已经确定。而当前位置上的值在前\(i\)个值中的排名可以任意指定,故新增的逆序对数也是\([0,i-1]\)中任意整数。

本题中,题目定义的构造过程就是构造一个笛卡尔树。容易知道,对于两个位置\(i,j\)\(j\)\(i\)在笛卡尔树上的祖先,当且仅当\(\min_{k=\min(i,j)}^{\max(i,j)}\{p_k\}=p_j\)

根据期望的线性性,我们枚举\(i,j\),求出\(j\)\(i\)的祖先的合法方案数后累加到\(i\)的答案中。

考虑上述经典模型的第二种做法,但不是从左到右插入,而把插入的过程分为两部分。第一部分先从\(i\)开始,向\(j\)的方向,一直走到序列的某一端,依次决定每个位置上元素的相对大小关系;第二部分从\(i\)开始,向\(j\)反方向,一直走到序列的某一端,依次确定每个位置上元素的相对大小关系。具体地:

  • \(j<i\)时,我们先从\(i\)\(1\);再从\(i+1\)\(n\)
  • \(j>i\)时,我们先从\(i\)\(n\);再从\(i-1\)\(1\)

这样,每个新位置所新增的逆序对数是:\([0,0],[0,1],[0,2],\dots,[0,n-1]\),除了\(j\)这个位置比较特殊,它需要保证是\([\min(i,j),\max(i,j)]\)这段区间内最小的。因此当\(j<i\)时,位置\(j\)新增的逆序对数为\(0\)\(j>i\)时,位置\(j\)新增的逆序对数为\(j-i\)

朴素地做一次这样的DP,复杂度是\(O(n^2K)\)的。又因为要枚举位置\(i,j\),总复杂度为\(O(n^4K)\)

发现DP的转移相当于是乘以了一个形如\(x^0+x^1+\dots+x^d\)的生成函数。因为生成函数的最高次项是\(O(K)\)的,且所有项系数都为\(1\),故乘法可以\(O(K)\)实现。(也可以理解为是前缀和优化DP)。此时DP的复杂度降为\(O(nK)\)

我们可以先\(O(nK)\)预处理好所有生成函数的积。然后枚举\(i,j\),用预处理好的积除以一个最高次项为\(j-i\)的生成函数,再乘以\(x^0\)\(x^{j-i}\)。和乘法同理,这个除法也可以\(O(K)\)实现。又发现我们要考虑的只有\(j-i\)的差值,因此不用枚举\(i,j\),只需要枚举\(j-i\)的差即可。

时间复杂度\(O(nK)\)

参考代码:

//problem:loj3228
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

/*  ------  by:duyi  ------  */ // myt天下第一
const int MAXN=300;
int n,K,MOD;
inline int mod1(int x){return x<MOD?x:x-MOD;}
inline int mod2(int x){return x<0?x+MOD:x;}
inline void add(int &x,int y){x=mod1(x+y);}
inline void sub(int &x,int y){x=mod2(x-y);}
inline int pow_mod(int x,int i){int y=1;while(i){if(i&1)y=(ll)y*x%MOD;x=(ll)x*x%MOD;i>>=1;}return y;}

int f[MAXN*MAXN+5],D,ans[MAXN+5];
void mul(int d){
	//乘以 (x^0+x^1+...+x^{d-1})
	for(int i=D;i<D+d-1;++i)f[i]=0;
	D+=d-1;
	static int s[MAXN*MAXN+5];
	s[0]=f[0];
	for(int i=1;i<D;++i)s[i]=mod1(s[i-1]+f[i]);
	for(int i=0;i<D;++i)f[i]=mod2(s[i]-(i-(d-1)>0?s[i-(d-1)-1]:0));
}
void div(int d){
	//除以 (x^0+x^1+...+x^{d-1})
	static int g[MAXN*MAXN+5];
	assert(D>=d);
	for(int i=D-d,s=0;i>=0;--i){
		if(i+d<=D-d)sub(s,g[i+d]);
		g[i]=mod2(f[i+(d-1)]-s);
		add(s,g[i]);
	}
	for(int i=D-(d-1);i<D;++i)f[i]=0;
	D-=(d-1);
	for(int i=0;i<D;++i)f[i]=g[i];
}
int main() {
	scanf("%d%d%d",&n,&K,&MOD);
	f[0]=1;D=1;
	for(int i=2;i<=n;++i)mul(i);
	//for(int i=0;i<D;++i)cout<<f[i]<<" ";cout<<endl;
	for(int d=1;d<n;++d){//(i,j)的距离 abs(i-j)=d
		div(d+1);
		for(int i=d+1;i<=n;++i)add(ans[i],f[K]);
		if(K>=d)for(int i=1;i+d<=n;++i)add(ans[i],f[K-d]);
		mul(d+1);
	}
	for(int i=1;i<=n;++i)printf("%d ",mod1(ans[i]+f[K]));puts("");
	return 0;
}

Part3 DP套DP

DP套DP是给定一个DP问题A,用另一个DP去计算一种可能的A的输入,使得A的DP结果为x。

说白了就是,外层的DP的状态是另一个DP的结果

这样的问题,往往需要深入挖掘内层DP的性质,有时候还要对状态数有一个合理的估计甚至是大胆的猜想。

loj6274 数字

题目链接

考虑这样一个问题:给定两个数字\(P,Q\),求是否存在\(x,y\),满足\(L_x\leq x\leq R_x,L_y\leq y\leq R_y\),使得\(x\operatorname{OR}y=P,x\operatorname{AND}y=Q\)

这个问题可以用数位DP解决,令\(f[i][a][b][c][d]\ (a,b,c,d\in\{0,1\})\)表示从高到低考虑到第\(i\)位,\(x\)是否大于\(L_x\),是否小于\(R_x\)\(y\)是否大于\(L_y\),是否小于\(R_y\)。转移时枚举\(x,y\)的第\(i\)位分别是什么即可。且因为我们只需要判断\(x,y\)的存在性,所以\(f\)数组的取值为\(\{0,1\}\)

回到原问题。我们令\(F[i][s]\)表示从高到低考虑到第\(i\)位,小\(f\)数组是\(s\)时,共有多少个\(Q\)能达到此状态。至于我们怎么用一个整数\(s\)来描述小\(f\)数组,因为发现小\(f\)数组下标共\(2^4\)种,取值共\(2\)种,所以把每个下标对应的取值压到一起就好了,共\(2^{2^4}=2^{16}\)种状态。

转移时枚举\(Q\)的第\(i\)位是什么即可。

时间复杂度\(O(60\cdot2^{16}\cdot\text{一大堆常数})\)

参考代码:

//problem:loj6274
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

/*  ------  by:duyi  ------  */ // myt天下第一
ull P,Lx,Rx,Ly,Ry,dp[61][1<<16];
int state(int a,int b,int c,int d){
	return a+(b<<1)+(c<<2)+(d<<3);
}
#define forbit(i) for(int i=0;i<=1;++i)
int main() {
	cin>>P>>Lx>>Rx>>Ly>>Ry;
	dp[60][1<<state(0,0,0,0)]=1;
	for(int i=59;i>=0;--i)for(int s=0;s<(1<<16);++s)if(dp[i+1][s]){
		int p=((P>>i)&1ull);
		int lx=((Lx>>i)&1ull),rx=((Rx>>i)&1ull),ly=((Ly>>i)&1ull),ry=((Ry>>i)&1ull);
		forbit(q){
			int new_s=0;
			forbit(a)forbit(b)forbit(c)forbit(d)if(s&(1<<state(a,b,c,d))){
				forbit(x)forbit(y){
					if((x|y)!=p||(x&y)!=q)continue;
					if(!a&&x<lx)continue;
					if(!b&&x>rx)continue;
					if(!c&&y<ly)continue;
					if(!d&&y>ry)continue;
					new_s|=(1<<state(a||x>lx,b||x<rx,c||y>ly,d||y<ry));
				}
			}
			dp[i][new_s]+=dp[i+1][s];
		}
	}
	ull ans=0;
	for(int s=1;s<(1<<16);++s)ans+=dp[0][s];
	cout<<ans<<endl;
	return 0;
}
posted @ 2020-03-19 15:20  duyiblue  阅读(983)  评论(9编辑  收藏  举报