食物(矩阵快速幂)(DP)
这个题。。我们可以想到用递推写!!qwq(好吧,其实我的DP水平不高啊qwq)
就是我们以两个为单位(一共九种组合情况),然后往后面推下一位的情况。
通过手动模拟,我们可以找到它们之间的递推关系(详见代码)
先放上我的暴力代码。。。。(60分)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define mod 1000000007
using namespace std;
long long f[2][10];
int t,cnt=1;
struct Node{int id,que;long long ans=0;}node[1010];
bool cmp1(struct Node x,struct Node y)
{
if(x.que<y.que) return 1;
else return 0;
}
bool cmp2(struct Node x,struct Node y)
{
if(x.id<y.id) return 1;
else return 0;
}
inline int read()
{
int f=1,x=0; char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch<='9'&&ch>='0')
{
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
int main()
{
freopen("food.in","r",stdin);
freopen("food.out","w",stdout);
t=read();
//scanf("%d",&t);
for(int i=1;i<=t;i++) node[i].que=read(),node[i].id=i;
//scanf("%d",&node[i].que),node[i].id=i;
sort(node+1,node+1+t,cmp1);
//cout<<endl;
//for(int i=1;i<=t;i++)
// printf("%d ",node[i].que);
f[0][1]=2;
f[0][2]=3;
f[0][3]=2;
f[0][4]=3;
f[0][5]=2;
f[0][6]=2;
f[0][7]=2;
f[0][8]=2;
f[0][9]=2;
while(node[cnt].que==1) node[cnt].ans=3,cnt++;
while(node[cnt].que==2) node[cnt].ans=9,cnt++;
while(node[cnt].que==3) node[cnt].ans=20,cnt++;
for(int i=1;i<=node[t].que-3;i++)
{
f[1][1]=(f[0][4]+f[0][8])%mod;
f[1][2]=(f[0][1]+f[0][4]+f[0][8])%mod;
f[1][3]=(f[0][1]+f[0][4])%mod;
f[1][4]=(f[0][2]+f[0][5]+f[0][7])%mod;
f[1][5]=(f[0][2]+f[0][7])%mod;
f[1][6]=(f[0][2]+f[0][5])%mod;
f[1][7]=(f[0][6]+f[0][9])%mod;
f[1][8]=(f[0][3]+f[0][9])%mod;
f[1][9]=(f[0][3]+f[0][6])%mod;
//for(int j=1;j<10;j++)
// printf("f[1][%d]=%lld\n",j,f[1][j]);
for(int j=1;j<10;j++)
swap(f[0][j],f[1][j]);
while(i==node[cnt].que-3)
{
long long ans=0;
for(int j=1;j<10;j++)
ans=(ans+f[0][j])%mod;
node[cnt].ans=ans;
// cout<<"ans="<<ans<<endl;
cnt++;
}
}
sort(node+1,node+1+t,cmp2);
for(int i=1;i<=t;i++)
printf("%lld\n",node[i].ans%mod);
return 0;
}
然后我们看到数据范围。。。好大呀qwq线性算法肯定会T啊qwq,那。。。。写矩阵加速吧!qwq
其实有了暴力程序之后矩阵很好写(就是把对应的行和列上面的数设成1,然后做一次矩阵乘法就相当于一次转移。
详见代码:
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdio>
#define mod 1000000007
using namespace std;
int t;
int f[10];
struct Node{long long m[10][10];}node;
inline Node mul(Node x,Node y)
{
Node cur;
for(int i=1;i<=9;i++)
for(int j=1;j<=9;j++)
cur.m[i][j]=0;
for(int i=1;i<=9;i++)
for(int j=1;j<=9;j++)
for(int k=1;k<=9;k++)
cur.m[i][j]=(cur.m[i][j]+x.m[i][k]*y.m[k][j])%mod;
return cur;
}
inline void solve(Node x)
{
int cur[10];
memset(cur,0,sizeof(cur));
for(int j=1;j<=9;j++)
for(int k=1;k<=9;k++)
cur[j]=(cur[j]+f[k]*x.m[k][j])%mod;
for(int i=1;i<=9;i++)
f[i]=cur[i];
}
int main()
{
scanf("%d",&t);
while(t--)
{
for(int i=1;i<=9;i++)
for(int j=1;j<=9;j++)
node.m[i][j]=0;
node.m[1][4]=1,node.m[1][8]=1;
node.m[2][1]=1,node.m[2][4]=1,node.m[2][8]=1;
node.m[3][1]=1,node.m[3][4]=1;
node.m[4][2]=1,node.m[4][5]=1,node.m[4][7]=1;
node.m[5][2]=1,node.m[5][7]=1;
node.m[6][2]=1,node.m[6][5]=1;
node.m[7][6]=1,node.m[7][9]=1;
node.m[8][3]=1,node.m[8][9]=1;
node.m[9][3]=1,node.m[9][6]=1;
memset(f,0,sizeof(f));
int n;
for(int i=1;i<=9;i++) f[i]=1;
scanf("%d",&n);
n-=2;
while(n)
{
if(n&1) solve(node);
node=mul(node,node);
n>>=1;
}
long long ans=0;
for(int i=1;i<=9;i++)
ans=(ans+f[i])%mod;
printf("%lld\n",ans%mod);
}
return 0;
}