hdu5785--Interesting(manacher)
题意:求给定字符串的三元组(I,J,K) 使得S[i..j] 和 S[j+1..k] 都是回文串。求所有满足条件的三元组 ∑(i*k)
题解:求出以j为结尾的回文串起始位置的和记为lv[j],和以j+1为开始的回文串末位置的和rv[j+1]
答案就是∑[j:1-n](lv[j] * rv[j+1])
因为……
(a+b+c....)*(x+y+z.....) = a*x + a*y + a*z + ....
看了题解之后才恍然大悟ˊ_>ˋ有多蠢
然后就是自己写的代码
一直wa,以为哪里没有取模,瞪了一个小时,发现,哦,有一个除法,÷2,应该算逆元
天啦噜。。。
看到很多人分了奇偶,我也没想那么多,感觉是一样的,可能效率差一些吧……
我的想法是对于每一个i,它所能到达的地方就是,i+mp[i](manacher中数组),那么对于所有它能到达的位置,设为j,j所对应的起始位置就是i*2-j,于是每次只要把所能到达的点加i,记为rv[],也就是rv[j]+i, 每个点所有前面点的贡献值就是rv[j]*2-j*ti(所能到达j点的次数)
#include <cmath> #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #define ll long long using namespace std; const int N = 3000010; const ll MOD = 1000000007LL; const ll inv = 500000004; char str[N]; char ma[N]; int mp[N]; ll ti[N], lv[N], rv[N]; int Manacher() { int len = strlen(str); int l = 0; ma[l++] = '$'; ma[l++] = '#'; for (int i = 0 ; i < len ; i++) { ma[l++] = str[i]; ma[l++] = '#'; } ma[l] = 0; int mx = 0,id = 0; for (int i = 0 ; i < l ; i++) { mp[i] = mx > i ? min(mp[2 * id - i], mx - i) : 1; while (ma[i + mp[i]] == ma[i - mp[i]]) mp[i]++; if (i + mp[i] > mx) { mx = i + mp[i]; id = i; } } return l; } inline void up(ll &x, ll y) { x += y; if (x >= MOD) x -= MOD; if (x < 0) x += MOD; } ll solve() { int l = Manacher(); memset(lv, 0, sizeof lv); memset(ti, 0, sizeof ti); for (int i = 1; i < l; ++i) { up(lv[i], i); up(lv[i+mp[i]], -i); ti[i]++; ti[i+mp[i]]--; } for (int i = 1; i < l; ++i) { up(lv[i], lv[i-1]); up(ti[i], ti[i-1]); } for (int i = 1; i < l; ++i) { lv[i] = ((lv[i] * 2 % MOD - ti[i] * i % MOD) % MOD + MOD) % MOD; } memset(rv, 0, sizeof rv); memset(ti, 0, sizeof ti); for (int i = l-1; i > 0; --i) { up(rv[i], i); up(rv[i-mp[i]], -i); ti[i]++; ti[i-mp[i]]--; } for (int i = l-1; i > 0; --i) { up(rv[i], rv[i+1]); up(ti[i], ti[i+1]); } for (int i = l-1; i > 0; --i) { rv[i] = ((rv[i] * 2 % MOD - ti[i] * i % MOD) + MOD) % MOD; } ll ans = 0; for (int i = 2; i < l; i += 2) { ans = (ans + (lv[i] * inv % MOD) * (rv[i+2] * inv % MOD) % MOD) % MOD; } return ans; } int main() { //freopen("in", "r", stdin); while (~scanf("%s",str)) { cout << solve() << endl; } return 0; }