【HDOJ】5657 CA Loves Math

1. 题目描述
对于给定的$a, n, mod, a \in [2,11], n \in [0, 10^9], mod \in [1, 10^9]$求出在$[1, a^n]$内的所有$a$进制下的数并且不含重复数字。

2. 基本思路
这题比赛的时候,没人做出来,但是基本思路大家都有。显然可以直接将$n$改写为$\min(n,a)$。
我比赛的代码TLE,思路是这样的:首先$mod$很小时,可以数位DP解;当$mod$很大时,可以先找到所有的排列然后,然后令$delta = fact(a)/fact(a-n)$,然后以这个作为循环间隔找到满足不重复的数字,然后再判断是否是$mod$的倍数。
hack的时候,其实可以直接以$mod$作为阈值解。题解也提到了这个思路。
这样原问题可以分两种情况:
(1) 大于阈值,枚举$mod$的倍数,然后判断是否包含重复数字;
(2) 小于等于阈值,数位DP。
然后,赛后交还是wa了几次。这里有几个特殊情况需要单独考虑:
(1) n = 0时,只能取1,直接判断是否为$mod$倍数。
(2) n = 1时,可以取[1, a],同样需要判断是否为$mod$倍数。
并且,数位DP是累加DP的。即长度为$[1,n]$的满足条件的总和。

3. 代码

  1 /*  */
  2 #include <iostream>
  3 #include <sstream>
  4 #include <string>
  5 #include <map>
  6 #include <queue>
  7 #include <set>
  8 #include <stack>
  9 #include <vector>
 10 #include <deque>
 11 #include <bitset>
 12 #include <algorithm>
 13 #include <cstdio>
 14 #include <cmath>
 15 #include <ctime>
 16 #include <cstring>
 17 #include <climits>
 18 #include <cctype>
 19 #include <cassert>
 20 #include <functional>
 21 #include <iterator>
 22 #include <iomanip>
 23 using namespace std;
 24 //#pragma comment(linker,"/STACK:102400000,1024000")
 25 
 26 #define sti                set<int>
 27 #define stpii            set<pair<int, int> >
 28 #define mpii            map<int,int>
 29 #define vi                vector<int>
 30 #define pii                pair<int,int>
 31 #define vpii            vector<pair<int,int> >
 32 #define rep(i, a, n)     for (int i=a;i<n;++i)
 33 #define per(i, a, n)     for (int i=n-1;i>=a;--i)
 34 #define clr                clear
 35 #define pb                 push_back
 36 #define mp                 make_pair
 37 #define fir                first
 38 #define sec                second
 39 #define all(x)             (x).begin(),(x).end()
 40 #define SZ(x)             ((int)(x).size())
 41 #define lson            l, mid, rt<<1
 42 #define rson            mid+1, r, rt<<1|1
 43 #define INF                0x3f3f3f3f
 44 #define mset(a, val)    memset(a, (val), sizeof(a))
 45 
 46 #define LL __int64
 47 
 48 const int bound = 23333;
 49 const int maxn = 12;
 50 int dp[1<<11][bound];
 51 int Bits[1<<11];
 52 vector<int> St[maxn];
 53 int Sz[maxn];
 54 int a, n, mod;
 55 
 56 void solve();
 57 void _solve();
 58 
 59 inline int lowest(int x) {
 60     return -x & x;
 61 }
 62 
 63 inline int getBits(int x) {
 64     int ret = 0;
 65     
 66     while (x) {
 67         ++ret;
 68         x -= lowest(x);
 69     }
 70     return ret;
 71 }
 72 
 73 void init() {
 74     int mst = 1 << 11;
 75     
 76     rep(i, 0, mst) {
 77         Bits[i] = getBits(i);
 78         St[Bits[i]].pb(i);
 79     }
 80     
 81     rep(i, 0, maxn)
 82         Sz[i] = SZ(St[i]);
 83 }
 84 
 85 bool vis[12];
 86 inline bool judge(LL x) {
 87     if (x == 0)    return false;
 88     
 89     memset(vis, false, sizeof(vis));
 90     while (x) {
 91         int tmp = x % a;
 92         if (vis[tmp])
 93             return false;
 94         x /= a;
 95         vis[tmp] = true;
 96     }
 97     return true;
 98 }
 99 
100 void solve() {
101     if (n == 0) {
102         printf("%d\n", 1%mod==0 ? 1:0);
103         return ;
104     }
105     if (n == 1) {
106         int ans = 0;
107         rep(i, 1, a+1)
108             ans += i%mod == 0 ? 1:0;
109         printf("%d\n", ans);
110         return ;
111     }
112     
113     n = min(n, a);
114     if (mod > bound) {
115         _solve();
116         return ;
117     }
118 
119     int mst = 1<<a;
120     memset(dp, 0, sizeof(dp));
121     
122     rep(j, 1, a)
123         ++dp[1<<j][j%mod];
124     
125     rep(l, 1, n) {
126         rep(j, 0, Sz[l]) {
127             const int st = St[l][j];
128             if (st >= mst)
129                 continue;
130             rep(k, 0, mod) {
131                 const int& cnt = dp[st][k];
132                 if (cnt == 0)
133                     continue;
134                 
135                 rep(i, 0, a) {
136                     if (st & (1<<i))
137                         continue;
138                     
139                     int nst = st | (1<<i);
140                     int nk = (k * a + i) % mod;
141                     dp[nst][nk] += cnt;
142                 }
143             }
144         }
145     }
146     
147     int ans = 0;
148     
149     rep(l, 1, n+1) {
150         rep(j, 0, Sz[l]) {
151             const int& st = St[l][j];
152             ans += dp[st][0];
153         }
154     }
155     
156     printf("%d\n", ans);
157 }
158 
159 LL Pow(LL base, int n) {
160     LL ret = 1;
161     
162     while (n) {
163         if (n & 1)
164             ret = ret * base;
165         base = base * base;
166         n >>= 1;
167     }
168     
169     return ret;
170 }
171 
172 void _solve() {
173     LL tmp = mod, ubound = Pow(a, n);
174     int ans = 0;
175     
176     while (tmp <= ubound) {
177         if (judge(tmp))
178             ++ans;
179         tmp += mod;
180     }
181     
182     printf("%d\n", ans);
183 }
184 
185 int main() {
186     ios::sync_with_stdio(false);
187     #ifndef ONLINE_JUDGE
188         freopen("data.in", "r", stdin);
189         freopen("data.out", "w", stdout);
190     #endif
191 
192     int t;
193 
194     init();
195     scanf("%d", &t);
196     while (t--) {
197         scanf("%d%d%d",&a,&n,&mod);
198         solve();
199     }
200 
201     #ifndef ONLINE_JUDGE
202         printf("time = %ldms.\n", clock());
203     #endif
204 
205     return 0;
206 }

 

posted on 2016-04-02 22:29  Bombe  阅读(269)  评论(0编辑  收藏  举报

导航