[UOJ UR#16]破坏发射台

来自FallDream的博客,未经允许,请勿转载,谢谢。


传送门

 

先考虑n是奇数的情况,很容易想到一个dp,f[i][0/1]表示转移到第i个数,第i个数是不是第一个数的方案数,然后用矩阵乘法优化一下就好了。

然后考虑n是偶数的情况,发现可以把圈分成两个半圆,dp就多了几维,需要表示两个数分别是是第一个数/第二个数/都不是的方案数。

#include<iostream>
#include<cstring>
#include<cstdio>
#define MN 10
#define mod 998244353
using namespace std;
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
struct Matrix
{
    int s[MN+1][MN+1];    
    Matrix(){memset(s,0,sizeof(s));}
    Matrix operator * (const Matrix&b)
    {
        Matrix c;    
        for(int i=0;i<=MN;++i)
            for(int j=0;j<=MN;++j)
                for(int k=0;k<=MN;++k)
                    c.s[i][j]=(c.s[i][j]+1LL*s[i][k]*b.s[k][j])%mod;
        return c;
    }
}A,B;
int n,m;

int Calc(int x,int y)
{
    if(!x)
    {
        if(!y) return m-3;    
        else return 1;
    }
    if(x==1)
    {
        if(y==0) return m-2;
        else if(y==2) return 1;
        else return 0;     
    }
    if(x==2)
    {
        if(y==0) return m-2;    
        else if(y==1) return 1;
        else return 0;        
    }
}

int main()
{
    n=read();m=read();
    if(m==1) return 0*printf("%d\n",n==1);
    if(n&1)
    {
        B.s[1][0]=1;
        A.s[1][0]=1;
        A.s[0][0]=m-2;
        A.s[0][1]=m-1;
        for(--n;n;n>>=1,A=A*A) if(n&1) B=A*B;
        printf("%d\n",1LL*B.s[0][0]*m%mod);
    }
    else
    {
        if(n==4&&m==2) return 0*puts("0");
        if(m==2) return 0*puts("2");B.s[5][0]=1;
        for(int i=0;i<3;++i)
            for(int j=0;j<3;++j) if(i!=j||i==0) 
                for(int k=0;k<3;++k)    
                    for(int l=0;l<3;++l) if(k!=l||k==0)    
                    {
                        A.s[k*3+l][i*3+j]=1LL*Calc(i,k)*Calc(j,l)%mod;
                        if(k==0&&l==0) 
                        {
                            if(1LL*(m-2)*(m-2)+m-2-(i==0)-(j==0)<0) A.s[k*3+l][i*3+j]=0;
                            else (A.s[k*3+l][i*3+j]+=mod-(m-2-(i==0)-(j==0)))%=mod;
                        }
                    }
        int ans=0;
        for(n>>=1,--n;n;n>>=1,A=A*A) if(n&1) B=A*B;
        for(int i=0;i<9;++i) 
        {
            int j=i%3,k=i/3;
            if(j==1||k==2) continue;
            ans=(ans+B.s[i][0])%mod;
        }
        printf("%d\n",1LL*ans*m%mod*(m-1)%mod);
    }
    return 0;
}

 

posted @ 2017-09-04 19:54  FallDream  阅读(328)  评论(0编辑  收藏  举报