题目链接
https://atcoder.jp/contests/agc035/tasks/agc035_d
题解
想了两小时憋出来一个状压DP,发现人家怎么空间才十几MB,原来暴力就行了。。。
考虑原序列那个操作,我们可以建一个图,一开始有\(n\)个点没有边,每次选一个点向其左右未被选的点加两条边,一个点的贡献次数就是它到左右两个终点的路径数。
那么我们可以枚举最后选的点,把原序列分裂成两个区间,因此使用区间DP. 现在的问题是我们需要方便地统计一个点对总和的贡献。我们观察到,从\([l,r]\)这一层到\([1,n]\)最底层,每次要么拓展左边(左端点\(l\)连向新的\(l'\),\(r\)连向\(l'\)),要么拓展右边,这样总共的过程可以表达为一个长度不超过\((n-1)\)的01
串,且从\(l\)或\(r\)到达\(1\)或\(n\)的方案数可由这个串得到(具体地,维护两个变量\(x=1,y=1\), 拓展左边时\((x,y)\rightarrow (x+y,y)\), 拓展右边时\((x,y)\rightarrow (x,x+y)\))。那么预处理每个01
串对应的系数,就可以快速计算了。最终的DP状态是,\(f[l][r][k][S]\)表示区间\([l,r]\),01
串长度为\(k\),串本身为\(S\). 转移枚举图上这个区间里位置最低的点(也就是最后操作的点)就行了。
总共状态数是\(\sum^n_{i=1}2^i(n-i+1)=O(2^nn)\)的,转移需要\(O(n)\)的复杂度(其实远远不满),总时间复杂度\(O(2^nn^2)\).
但是如果我们设\(f[l][r][k][S]\)记忆化的话空间复杂度就变成\(O(2^nn^3)\)了,怎么办?智障的我就把后两维压到一起了,卡着空间限制过了……
然而事实是重复遍历一个状态的情况很少,不记忆化不仅能过,而且速度还快一倍。(如果不记忆化的话不需要预处理数组,把\(S\)串直接改成左右端点分别被算几次就行了)
实测效果如下:
(上:不记忆化 下:记忆化)
枯了
代码
记忆化:
#include<bits/stdc++.h>
#define llong long long
#define mkpr make_pair
#define riterator reverse_iterator
#define U ((1<<n)-1)
using namespace std;
inline int read()
{
int x = 0,f = 1; char ch = getchar();
for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
return x*f;
}
const int N = 18;
const llong INF = 1e17;
llong f[N+1][N+1][(1<<N)+3];
llong val[(1<<N)+3];
llong a[N+3];
int n;
void updmin(llong &x,llong y) {x = min(x,y);}
llong dfs(int l,int r,int sta)
{
if(f[l][r][sta]<INF) {return f[l][r][sta];}
if(r-l<=1) return f[l][r][sta]=0ll;
for(int i=l+1; i<r; i++)
{
updmin(f[l][r][sta],dfs(l,i,((sta<<1)|1)&U)+dfs(i,r,(sta<<1)&U)+a[i]*val[sta]);
}
return f[l][r][sta];
}
int main()
{
scanf("%d",&n);
for(int i=0; i<(1<<n)-1; i++)
{
int len = n; while(i&(1<<len-1)) {len--;} len--;
llong x = 1ll,y = 1ll; for(int j=0; j<len; j++) i&(1<<j)?y+=x:x+=y; val[i] = x+y;
}
for(int i=0; i<n; i++) scanf("%lld",&a[i]);
memset(f,10,sizeof(f));
printf("%lld\n",dfs(0,n-1,(1<<n)-2)+a[0]+a[n-1]);
return 0;
}
不记忆化
#include<bits/stdc++.h>
#define llong long long
#define mkpr make_pair
#define riterator reverse_iterator
#define U ((1<<n)-1)
using namespace std;
inline int read()
{
int x = 0,f = 1; char ch = getchar();
for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
return x*f;
}
const int N = 18;
const llong INF = 1e17;
llong a[N+3];
int n;
void updmin(llong &x,llong y) {x = min(x,y);}
llong dfs(int l,int r,llong x,llong y)
{
if(r-l<=1) return 0ll;
llong ret = INF;
for(int i=l+1; i<r; i++)
{
updmin(ret,dfs(l,i,x,x+y)+dfs(i,r,x+y,y)+a[i]*(x+y));
}
return ret;
}
int main()
{
scanf("%d",&n);
for(int i=0; i<n; i++) scanf("%lld",&a[i]);
printf("%lld\n",dfs(0,n-1,1ll,1ll)+a[0]+a[n-1]);
return 0;
}