【HDOJ】5657 CA Loves Math
1. 题目描述
对于给定的$a, n, mod, a \in [2,11], n \in [0, 10^9], mod \in [1, 10^9]$求出在$[1, a^n]$内的所有$a$进制下的数并且不含重复数字。
2. 基本思路
这题比赛的时候,没人做出来,但是基本思路大家都有。显然可以直接将$n$改写为$\min(n,a)$。
我比赛的代码TLE,思路是这样的:首先$mod$很小时,可以数位DP解;当$mod$很大时,可以先找到所有的排列然后,然后令$delta = fact(a)/fact(a-n)$,然后以这个作为循环间隔找到满足不重复的数字,然后再判断是否是$mod$的倍数。
hack的时候,其实可以直接以$mod$作为阈值解。题解也提到了这个思路。
这样原问题可以分两种情况:
(1) 大于阈值,枚举$mod$的倍数,然后判断是否包含重复数字;
(2) 小于等于阈值,数位DP。
然后,赛后交还是wa了几次。这里有几个特殊情况需要单独考虑:
(1) n = 0时,只能取1,直接判断是否为$mod$倍数。
(2) n = 1时,可以取[1, a],同样需要判断是否为$mod$倍数。
并且,数位DP是累加DP的。即长度为$[1,n]$的满足条件的总和。
3. 代码
1 /* */ 2 #include <iostream> 3 #include <sstream> 4 #include <string> 5 #include <map> 6 #include <queue> 7 #include <set> 8 #include <stack> 9 #include <vector> 10 #include <deque> 11 #include <bitset> 12 #include <algorithm> 13 #include <cstdio> 14 #include <cmath> 15 #include <ctime> 16 #include <cstring> 17 #include <climits> 18 #include <cctype> 19 #include <cassert> 20 #include <functional> 21 #include <iterator> 22 #include <iomanip> 23 using namespace std; 24 //#pragma comment(linker,"/STACK:102400000,1024000") 25 26 #define sti set<int> 27 #define stpii set<pair<int, int> > 28 #define mpii map<int,int> 29 #define vi vector<int> 30 #define pii pair<int,int> 31 #define vpii vector<pair<int,int> > 32 #define rep(i, a, n) for (int i=a;i<n;++i) 33 #define per(i, a, n) for (int i=n-1;i>=a;--i) 34 #define clr clear 35 #define pb push_back 36 #define mp make_pair 37 #define fir first 38 #define sec second 39 #define all(x) (x).begin(),(x).end() 40 #define SZ(x) ((int)(x).size()) 41 #define lson l, mid, rt<<1 42 #define rson mid+1, r, rt<<1|1 43 #define INF 0x3f3f3f3f 44 #define mset(a, val) memset(a, (val), sizeof(a)) 45 46 #define LL __int64 47 48 const int bound = 23333; 49 const int maxn = 12; 50 int dp[1<<11][bound]; 51 int Bits[1<<11]; 52 vector<int> St[maxn]; 53 int Sz[maxn]; 54 int a, n, mod; 55 56 void solve(); 57 void _solve(); 58 59 inline int lowest(int x) { 60 return -x & x; 61 } 62 63 inline int getBits(int x) { 64 int ret = 0; 65 66 while (x) { 67 ++ret; 68 x -= lowest(x); 69 } 70 return ret; 71 } 72 73 void init() { 74 int mst = 1 << 11; 75 76 rep(i, 0, mst) { 77 Bits[i] = getBits(i); 78 St[Bits[i]].pb(i); 79 } 80 81 rep(i, 0, maxn) 82 Sz[i] = SZ(St[i]); 83 } 84 85 bool vis[12]; 86 inline bool judge(LL x) { 87 if (x == 0) return false; 88 89 memset(vis, false, sizeof(vis)); 90 while (x) { 91 int tmp = x % a; 92 if (vis[tmp]) 93 return false; 94 x /= a; 95 vis[tmp] = true; 96 } 97 return true; 98 } 99 100 void solve() { 101 if (n == 0) { 102 printf("%d\n", 1%mod==0 ? 1:0); 103 return ; 104 } 105 if (n == 1) { 106 int ans = 0; 107 rep(i, 1, a+1) 108 ans += i%mod == 0 ? 1:0; 109 printf("%d\n", ans); 110 return ; 111 } 112 113 n = min(n, a); 114 if (mod > bound) { 115 _solve(); 116 return ; 117 } 118 119 int mst = 1<<a; 120 memset(dp, 0, sizeof(dp)); 121 122 rep(j, 1, a) 123 ++dp[1<<j][j%mod]; 124 125 rep(l, 1, n) { 126 rep(j, 0, Sz[l]) { 127 const int st = St[l][j]; 128 if (st >= mst) 129 continue; 130 rep(k, 0, mod) { 131 const int& cnt = dp[st][k]; 132 if (cnt == 0) 133 continue; 134 135 rep(i, 0, a) { 136 if (st & (1<<i)) 137 continue; 138 139 int nst = st | (1<<i); 140 int nk = (k * a + i) % mod; 141 dp[nst][nk] += cnt; 142 } 143 } 144 } 145 } 146 147 int ans = 0; 148 149 rep(l, 1, n+1) { 150 rep(j, 0, Sz[l]) { 151 const int& st = St[l][j]; 152 ans += dp[st][0]; 153 } 154 } 155 156 printf("%d\n", ans); 157 } 158 159 LL Pow(LL base, int n) { 160 LL ret = 1; 161 162 while (n) { 163 if (n & 1) 164 ret = ret * base; 165 base = base * base; 166 n >>= 1; 167 } 168 169 return ret; 170 } 171 172 void _solve() { 173 LL tmp = mod, ubound = Pow(a, n); 174 int ans = 0; 175 176 while (tmp <= ubound) { 177 if (judge(tmp)) 178 ++ans; 179 tmp += mod; 180 } 181 182 printf("%d\n", ans); 183 } 184 185 int main() { 186 ios::sync_with_stdio(false); 187 #ifndef ONLINE_JUDGE 188 freopen("data.in", "r", stdin); 189 freopen("data.out", "w", stdout); 190 #endif 191 192 int t; 193 194 init(); 195 scanf("%d", &t); 196 while (t--) { 197 scanf("%d%d%d",&a,&n,&mod); 198 solve(); 199 } 200 201 #ifndef ONLINE_JUDGE 202 printf("time = %ldms.\n", clock()); 203 #endif 204 205 return 0; 206 }