Luogu 3943 星空
很妙的题。
首先我们发现区间操作不太好弄,我们想办法把它转化成单点操作,这样子处理的办法会多一点。
方法当然是差分了。
定义差分数组$b_i = a_i \^ a_{i + 1}$($b_i$的下标从$0$开始),在这里将$\^$记为异或。
那么$a_i = b_0 \^ b_1 \^ b_2 \^ ... \^ b_i$,如果我们将所有一开始没有点亮的灯记为$1$,而将所有点亮的灯记为$0$,如果要点亮所有灯,那么我们最后要使$\forall i \in [0, n] \ a_i == 0 $成立,就相当于使$\forall i \in [0, n] \ b_i == 0 $成立。
这样子我们改一段区间$[l, r]$的时候就相当于把$b_{l - 1}, b_{r}$都异或上$1$。
发现这样子最多不会超过有$2k$个为$1$的位置,直接状压起来。
我们可以推出$dp$了,假设$f_s$表示当前选的$s$集合最少需要多少操作的步数,那么$f_s + cost(i, j)$可以更新$f_{s \cup i \cup j\ (i \notin s,j \notin s)}$。
如果能预处理这个$cost(i, j)$,那么这个转移就可以写成$O(k2^k)$的,顺便一提在本题中写成$O(k^2 * 2^k)$也是能过的。
发现我们在选择点在更新的时候一定至少会选择有一个$0$的去更新(因为更新两个$1$没有意义……)。因此我们如果要更新$i$和$j$,假如$j - i$恰好等于一个可以更新的区间长度,那么只需要一步;但是如果没有怎么办,我们需要“绕路走”,引入一个或多个中间点$k$来使$i$和$j$同时更新,发现这样子所有的$k$都被异或了偶数次,所以最后的结果仍然只有$i$和$j$被更新。
发现了吧,这是一个最短路,预处理的时候对每一个点跑一遍$dij$或者$spfa$即可。
我实现的代码是$O(k^2 * 2^k)$的。
Code:
#include <cstdio> #include <cstring> #include <queue> #include <iostream> using namespace std; typedef pair <int, int> pin; const int N = 1e4 + 5; const int M = 105; const int W = 22; const int S = (1 << 20) + 5; const int inf = 0x3f3f3f3f; int n, m, K, a[N], b[N], idCnt = 0, id[W]; int len[M], dis[N], c[W][W], f[S]; bool vis[N]; inline void read(int &X) { X = 0; char ch = 0; int op = 1; for(; ch > '9' || ch < '0'; ch = getchar()) if(ch == '-') op = -1; for(; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } priority_queue <pin> Q; void dij(int st) { memset(dis, 0x3f, sizeof(dis)); memset(vis, 0, sizeof(vis)); Q.push(pin(dis[st] = 0, st)); for(; !Q.empty(); ) { int x = Q.top().second; Q.pop(); if(vis[x]) continue; vis[x] = 1; for(int i = 1; i <= m; i++) { int y = x + len[i]; if(y <= n && dis[y] > dis[x] + 1) { dis[y] = dis[x] + 1; Q.push(pin(-dis[y], y)); } y = x - len[i]; if(y >= 0 && dis[y] > dis[x] + 1) { dis[y] = dis[x] + 1; Q.push(pin(-dis[y], y)); } } } } inline void chkMin(int &x, int y) { if(y < x) x = y; } int main() { read(n), read(K), read(m); for(int pos, i = 1; i <= K; i++) read(pos), a[pos] = 1; for(int i = 1; i <= m; i++) read(len[i]); for(int i = 0; i <= n; i++) b[i] = a[i] ^ a[i + 1]; for(int i = 0; i <= n; i++) if(b[i]) id[++idCnt] = i; /* for(int i = 1; i <= idCnt; i++) printf("%d ", id[i]); printf("\n"); */ memset(c, 0x3f, sizeof(c)); for(int i = 1; i <= idCnt; i++) { dij(id[i]); for(int j = 1; j <= idCnt; j++) if(i != j) c[i][j] = dis[id[j]]; } /* for(int i = 1; i <= idCnt; i++, printf("\n")) for(int j = 1; j <= idCnt; j++) printf("%d ", c[i][j]); */ memset(f, 0x3f, sizeof(f)); f[0] = 0; for(int s = 0; s < (1 << idCnt); s++) { if(f[s] == inf) continue; for(int i = 1; i <= idCnt; i++) { if((s >> (i - 1)) & 1) continue; for(int j = 1; j <= idCnt; j++) { if((s >> (j - 1)) & 1) continue; int to = s | (1 << (i - 1)) | (1 << (j - 1)); chkMin(f[to], f[s] + c[i][j]); } } } int curS = (1 << idCnt) - 1; if(f[curS] == inf) puts("-1"); else printf("%d\n", f[curS]); return 0; }