Codeforces997C Sky Full of Stars 【FMT】【组合数】

题目大意:

一个$n*n$的格子,每个格子由你填色,有三种允许填色的方法,问有一行或者一列相同的方案数。

题目分析:

标题的FMT是我吓人用的。

一行或一列的问题不好解决,转成它的反面,没有一行和一列相同的方案数。

从一个方向入手,比如列,把一列看成一个整体。把颜色看成二进制数,$001$,$010$,$100$。

那么一列构成了一个长度为$3n$的二进制数,$n$列之间互相与出来的结果为$0$。实际我要统计这个东西。

注意到每一列的取法是不能取相同颜色的,所以剔除相同。之后我们得到了每一列可选的情况。

将它做FMT,之后做$n$次方,然后做IFMT,$0$位上的就是答案。用$9^n$减去这个数字就行。

直接做的时间复杂度是$O(n*2^n)$的,我们远不能承受。

但是我们有用的状态却不多,甚至还有规律。比如FMT后的某个位$bit$如果每三位出现两个$1$那么这个的FMT值一定是$0$,然后如果每三位只有$1$个$1$那么该位贡献$1$次,否则贡献$3$次。

然后是IFMT的还原问题,经过观察,不难发现如果某个位$bit$的$1$的个数为奇数,那么对$0$位产生减的影响,否则产生加的影响。

综合上面两个因素,可以利用组合数来统计方案数。值得注意的是如果每三位的1的位置相同那么要提防填充出相同结果。

时间复杂度$O(n)$

代码:

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 
 4 const int maxn = 1020000;
 5 const int mod = 998244353;
 6 
 7 int n;
 8 
 9 int f[maxn][2],g[maxn][2];
10 int pw3[maxn];
11 int c[maxn];
12 
13 int fast_pow(int now,int pw){
14     int ans = 1,dt = now,bit = 1;
15     while(bit <= pw){
16         if(bit & pw) ans = (1ll*ans*dt)%mod;
17         bit <<=1;dt = (1ll*dt*dt)%mod;
18     }
19     return ans;
20 }
21 
22 void work(){
23     if(n == 1) {puts("3");return;}
24     pw3[0] = 1;for(int i=1;i<=n;i++) pw3[i] = 3ll*pw3[i-1]%mod;
25     c[0] = 1;
26     for(int i=1;i<=n;i++){
27         c[i]=(1ll*c[i-1]*(n-i+1))%mod;
28         c[i]=(1ll*c[i]*fast_pow(i,mod-2))%mod;
29     }
30     int sum = fast_pow(pw3[n],n);
31     f[0][1] = 1; g[0][1] = (pw3[n]-3+mod)%mod;
32     f[n][0] = (pw3[n]-3+mod)%mod; g[n][0] = 1;
33     for(int i=1;i<n;i++){
34         f[i][1] = 3ll*c[i]%mod;
35         f[i][0] = (1ll*pw3[i]*c[i])%mod;
36         f[i][0] -= f[i][1]; f[i][0] += mod; f[i][0] %= mod;
37         g[i][0] = pw3[n-i];    g[i][1] = (pw3[n-i]-1+mod)%mod;
38     }
39     for(int i=0;i<=n;i++){
40         g[i][0] = fast_pow(g[i][0],n); g[i][1] = fast_pow(g[i][1],n);
41         int dr = ((i&1)?1:-1);
42         sum += dr*(1ll*g[i][0]*f[i][0])%mod;
43         if(sum >= mod) sum-=mod; if(sum < 0) sum += mod;
44         sum += dr*(1ll*g[i][1]*f[i][1])%mod;
45         if(sum >= mod) sum-=mod; if(sum < 0) sum += mod;
46     }
47     printf("%d",sum);
48     
49 }
50 
51 int main(){
52     scanf("%d",&n);
53     work();
54     return 0;
55 }

 

posted @ 2018-07-07 16:26  menhera  阅读(322)  评论(0编辑  收藏  举报