题解 浑水摸鱼

传送门

求最小表示法意义下的不同子串数

当字符集很小时存在一个做法是枚举映射
设字符集大小为 \(m\)
则若我们保证了一个子串每种 原字符 和 最小表示法下对应的字符 的映射都在原串中出现过,
那答案就是本质不同子串数除以 \(m!\)
为了保证每种映射都出现过我们可以枚举映射将原串映射一遍接在后面
但这样在子串中出现的字符数不是 \(m\) 时会出问题
所以要算子串中出现次数 \(\leqslant x\) 的再容斥
这时问题转化为求出现字符数 \(\leqslant x\) 的不同子串数量
听说 SA 可以求,但我先跑路了

正解考虑魔改一下 SA 求不同子串数
当比较两个后缀的字典序时,将这两个后缀都转成最小表示法比较
并且比较后缀的时候可以二分+hash 优化
暴力预处理 hash 可以做到 \(O(n\log^2*size)\)

再优化一下:考虑一种与这个字符到底是什么无关的 hash
\(h(a_i)=nxt_{a_i}*base^i\),其中 \(nxt_{k}\) 为字符 \(k\) 下一次出现的位置
可以主席树求出,于是可以优化到 \(O(n\log^3 n)\)

但本题卡常
一个可能的优化是主席树维护 hash 的时候不用 val(p1)=val(ls(p1))+val(rs(p1))*pw[mid-tl+1]; 合并
相反的,我们令位置 \(i\) 的初始值为 \(a_i*base^i\)
这样查询区间 hash 值就是区间和乘 \(base^{-l}\)

  • 欧拉定理:若 \(\gcd(a, m)=1\),则 \(a^{\varphi(m)}\equiv 1\pmod m\)
    注意一些不太显然的事情是 \(\varphi(2^{64})=2^{63}\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define fir first
#define sec second
#define ll long long
#define ull unsigned long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n;
int a[N];

// namespace force{
// 	const ull base=13131;
// 	int id[N], rub[N], rtop;
// 	// unordered_map<ull, bool> mp;
// 	struct hash_map{
// 		static const int SIZE=10000010;
// 		int head[SIZE], ecnt;
// 		hash_map(){memset(head, -1, sizeof(head));}
// 		struct edge{ull val; int next;}e[SIZE*10];
// 		inline int size() {return ecnt;}
// 		inline void insert(ull t) {
// 			int h=t*13131%SIZE;
// 			for (int i=head[h]; ~i; i=e[i].next)
// 				if (e[i].val==t) return ;
// 			// if (ecnt>SIZE-10) {puts("error"); exit(0);}
// 			e[++ecnt]={t, head[h]}; head[h]=ecnt;
// 		}
// 	}mp;
// 	void solve() {
// 		for (int i=1; i<=n; ++i) {
// 			while (rtop) id[rub[rtop--]]=0;
// 			ull h=0; int tot=0;
// 			for (int j=i; j<=n; ++j) {
// 				if (!id[a[j]]) id[a[j]]=++tot, rub[++rtop]=a[j];
// 				h=h*base+id[a[j]];
// 				// mp[h]=1;
// 				mp.insert(h);
// 			}
// 		}
// 		cout<<mp.size()<<endl;
// 	}
// }

namespace task{
	ll ans;
	vector<int> s[N];
	const ull base=13131;
	// map<pair<int, int>, int> mp;
	int lst[N], sa[N], h[N], cnt;
	ull val[N*10], pw[N*10], inv[N];
	int rot[N*10], lson[N*10], rson[N*10], tot;
	#define val(p) val[p]
	#define ls(p) lson[p]
	#define rs(p) rson[p]
	inline ull qpow(ull a, ull b) {ull ans=1; for (; b; a=a*a,b>>=1) if (b&1) ans=ans*a; return ans;}
	void upd(int& p1, int p2, int tl, int tr, int pos, int dat) {
		p1=++tot; ls(p1)=ls(p2); rs(p1)=rs(p2);
		if (tl==tr) {val(p1)=dat*pw[tl]; return ;}
		int mid=(tl+tr)>>1;
		if (pos<=mid) upd(ls(p1), ls(p2), tl, mid, pos, dat);
		else upd(rs(p1), rs(p2), mid+1, tr, pos, dat);
		val(p1)=val(ls(p1))+val(rs(p1)); //*pw[mid-tl+1];
	}
	ull query(int p, int tl, int tr, int l, int r) {
		if (!p) return 0;
		if (l<=tl&&r>=tr) return val(p);
		int mid=(tl+tr)>>1;
		if (l<=mid&&r>mid) return query(ls(p), tl, mid, l, r)+query(rs(p), mid+1, tr, l, r); //*pw[mid-(tl>=l?tl:l)+1];
		else if (l<=mid) return query(ls(p), tl, mid, l, r);
		else return query(rs(p), mid+1, tr, l, r);
	}
	struct hash_map{
		static const int SIZE=1000010;
		int head[SIZE], ecnt;
		hash_map(){memset(head, -1, sizeof(head));}
		struct edge{pair<int, int> val; int dat, next;}e[SIZE*10];
		inline int end() {return -1;}
		inline int find(pair<int, int> t) {
			int h=(t.fir*13131ll+t.sec*1919810ll)%SIZE;
			for (int i=head[h]; ~i; i=e[i].next)
				if (e[i].val==t) return 1;
			return -1;
		}
		inline int& operator [] (pair<int, int> t) {
			int h=(t.fir*13131ll+t.sec*1919810ll)%SIZE;
			for (int i=head[h]; ~i; i=e[i].next)
				if (e[i].val==t) return e[i].dat;
			e[++ecnt]={t, 0, head[h]}; head[h]=ecnt;
			return e[ecnt].dat;
		}
	}mp;
	int lcp(int i, int j) {
		if (i>j) swap(i, j);
		if (mp.find({i, j})!=mp.end()) return mp[{i, j}];
		int len=min(n-i+1, n-j+1);
		int l=1, r=len, mid;
		while (l<=r) {
			mid=(l+r)>>1;
			if (query(rot[i+mid-1], 1, n, i, i+mid-1)*inv[i]==query(rot[j+mid-1], 1, n, j, j+mid-1)*inv[j]) l=mid+1;
			else r=mid-1;
		}
		return mp[{i, j}]=l-1;
		// return l-1;
	}
	bool cmp(int i, int j) {
		int len=min(n-i+1, n-j+1), mid=lcp(i, j);
		if (mid==len) return i>j;
		// return (*s[a[i+mid]].lower_bound(i)-i)<(*s[a[j+mid]].lower_bound(j)-j);
		return (*lower_bound(s[a[i+mid]].begin(), s[a[i+mid]].end(), i)-i)<(*lower_bound(s[a[j+mid]].begin(), s[a[j+mid]].end(), j)-j);
	}
	void solve() {
		pw[0]=inv[0]=1; ans=1ll*n*(n+1)/2;
		ull iv=qpow(13131, (1ull<<63)-1);
		cout<<(1ull<<63)<<endl; exit(0);
		for (int i=1; i<=n; ++i) pw[i]=pw[i-1]*base;
		for (int i=1; i<=n; ++i) inv[i]=inv[i-1]*iv;
		for (int i=1; i<=n; ++i) s[a[i]].push_back(i);
		for (int i=1; i<=n; ++i) {
			if (lst[a[i]]) upd(rot[i], rot[i-1], 1, n, lst[a[i]], i-lst[a[i]]);
			else rot[i]=rot[i-1];
			lst[a[i]]=i;
		}
		for (int i=1; i<=n; ++i) sa[i]=i;
		stable_sort(sa+1, sa+n+1, cmp);
		// cout<<"sa: "; for (int i=1; i<=n; ++i) cout<<sa[i]<<' '; cout<<endl;
		for (int i=2; i<=n; ++i) ans-=lcp(sa[i-1], sa[i]); //, cout<<lcp(sa[i-1], sa[i])<<endl;
		cout<<ans<<endl;
		// cout<<lcp(1, 2)<<endl;
		// cout<<query(rot[4], 1, n, 2, 4)<<' '<<query(rot[5], 1, n, 3, 5)<<endl;
		// cout<<"cnt: "<<cnt<<endl;
	}
}

signed main()
{
	freopen("waterflow.in", "r", stdin);
	freopen("waterflow.out", "w", stdout);

	n=read();
	bool all_one=1;
	for (int i=1; i<=n; ++i) if ((a[i]=read())!=1) all_one=0;
	// if (all_one) cout<<n<<endl;
	// else force::solve();
	task::solve();
	
	return 0;
}
posted @ 2022-03-10 17:09  Administrator-09  阅读(3)  评论(0编辑  收藏  举报