题解 束/Light

传送门
传送门

Q:我的学生都不改题怎么办?
A:把没人改的题在几天后的模拟赛里再考一次,记得提前把前一场的题解给他们

于是就改了一场题(

感性发现貌似有好多点的答案都是一样的
并且发现貌似好多点都是等价的
仔细思考发现相邻两个障碍之间的线段中的点是等价的,一定可以通过同样的代价到达
又发现每条横线段可以以 1 的代价到达与之有交的竖线段
于是可以设法建图 BFS
发现图没必要建出来
可以两棵线段树套 set 维护线段
然后直接 BFS,松弛的时候直接查询所有不同方向的与之有交的线段并将其删除就行了
因为被查到的线段以后再被松弛一定不优,所以是正确的
然后发现现在每个点会被一横一竖两条线段统计
发现每条边都连接两个不同方向的线段,所以两次统计一定分别为 \(x\)\(x+1\)
那么多余的部分就容易减去了
\(\sum dis\) 容易求了观察下实现发现 \(\sum dis^2\) 也是可以求的
于是就做完了,复杂度 \(O(n\log^2n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 500010
#define fir first
#define sec second
#define pb push_back
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, m;
pair<int, int> p[N];
const ll mod=998244353, inv2=(mod+1)>>1;
inline ll sqr(ll a) {return a*a%mod;}

// namespace task1{
// 	ll ans1, ans2;
// 	int dis[N<<1], tot;
// 	vector<int> ln[N], col[N];
// 	struct line{int l, r, bel, id, dir;}sta[N<<1];
// 	inline bool operator < (line a, line b) {return a.bel<b.bel;}
// 	queue<line> q;
// 	struct segment{
// 		#define tl(p) tl[p]
// 		#define tr(p) tr[p]
// 		set<line> dat[N<<2];
// 		int tl[N<<2], tr[N<<2];
// 		void build(int p, int l, int r) {
// 			tl(p)=l; tr(p)=r; dat[p].clear();
// 			if (l==r) return ;
// 			int mid=(l+r)>>1;
// 			build(p<<1, l, mid);
// 			build(p<<1|1, mid+1, r);
// 		}
// 		void upd(int p, int l, int r, line val) {
// 			if (l<=tl(p)&&r>=tr(p)) {dat[p].insert(val); return ;}
// 			int mid=(tl(p)+tr(p))>>1;
// 			if (l<=mid) upd(p<<1, l, r, val);
// 			if (r>mid) upd(p<<1|1, l, r, val);
// 		}
// 		void query(int p, int pos, int l, int r, int now) {
// 			for (auto it=dat[p].lower_bound({0, 0, l, 0, 0}); it!=dat[p].end()&&it->bel<=r; it=dat[p].erase(it)) {
// 				if (now+1<dis[it->id]) {
// 					dis[it->id]=now+1;
// 					q.push(*it);
// 				}
// 			}
// 			if (tl(p)==tr(p)) return ;
// 			int mid=(tl(p)+tr(p))>>1;
// 			if (pos<=mid) query(p<<1, pos, l, r, now);
// 			else query(p<<1|1, pos, l, r, now);
// 		}
// 	}seg[2];
// 	void solve() {
// 		ans1=ans2=tot=0;
// 		for (int i=1; i<=n; ++i) ln[i].clear(), col[i].clear();
// 		for (int i=1; i<=m; ++i) ln[p[i].sec].pb(p[i].fir), col[p[i].fir].pb(p[i].sec);
// 		seg[0].build(1, 1, n); seg[1].build(1, 1, n);
// 		for (int i=1; i<=n; ++i) {
// 			sort(ln[i].begin(), ln[i].end());
// 			int lst=1;
// 			for (auto it:ln[i]) {
// 				if (it>lst) ++tot, seg[0].upd(1, lst, it-1, sta[tot]={lst, it-1, i, tot, 0});
// 				lst=it+1;
// 			}
// 			if (lst<=n) ++tot, seg[0].upd(1, lst, n, sta[tot]={lst, n, i, tot, 0});
// 		}
// 		// cout<<"---sta---"<<endl; for (int i=1; i<=tot; ++i) cout<<"("<<sta[i].l<<','<<sta[i].r<<','<<sta[i].bel<<','<<sta[i].id<<','<<sta[i].dir<<")"<<endl;
// 		for (int i=1; i<=n; ++i) {
// 			sort(col[i].begin(), col[i].end());
// 			int lst=1;
// 			for (auto it:col[i]) {
// 				if (it>lst) ++tot, seg[1].upd(1, lst, it-1, sta[tot]={lst, it-1, i, tot, 1});
// 				lst=it+1;
// 			}
// 			if (lst<=n) ++tot, seg[1].upd(1, lst, n, sta[tot]={lst, n, i, tot, 1});
// 		}
// 		for (int i=1; i<=tot; ++i) dis[i]=INF;
// 		dis[1]=0; q.push(sta[1]);
// 		while (q.size()) {
// 			line u=q.front(); q.pop();
// 			ans1=(ans1+sqr(dis[u.id])*(u.r-u.l+1))%mod;
// 			ans2=(ans2+dis[u.id]*(u.r-u.l+1))%mod;
// 			seg[u.dir^1].query(1, u.bel, u.l, u.r, dis[u.id]);
// 		}
// 		ans2=(ans2-sqr(n)+m)%mod;
// 		ans1=(ans1-ans2-sqr(n)+m)*inv2%mod;
// 		printf("%lld\n", (ans1%mod+mod)%mod);
// 	}
// }

namespace task2{
	ll ans1, ans2;
	int dis[N<<1], tot;
	vector<int> ln[N], col[N];
	struct line{int l, r, bel, id, dir;}sta[N<<1];
	queue<int> q;
	struct segment{
		#define tl(p) tl[p]
		#define tr(p) tr[p]
		int tl[N<<2], tr[N<<2];
		set<pair<int, int>> dat[N<<2];
		void build(int p, int l, int r) {
			tl(p)=l; tr(p)=r; dat[p].clear();
			if (l==r) return ;
			int mid=(l+r)>>1;
			build(p<<1, l, mid);
			build(p<<1|1, mid+1, r);
		}
		void upd(int p, int l, int r, pair<int, int> val) {
			if (l<=tl(p)&&r>=tr(p)) {dat[p].insert(val); return ;}
			int mid=(tl(p)+tr(p))>>1;
			if (l<=mid) upd(p<<1, l, r, val);
			if (r>mid) upd(p<<1|1, l, r, val);
		}
		void query(int p, int pos, int l, int r, int now) {
			for (auto it=dat[p].lower_bound({l, 0}); it!=dat[p].end()&&it->fir<=r; it=dat[p].erase(it)) {
				if (now+1<dis[it->sec]) {
					dis[it->sec]=now+1;
					q.push(it->sec);
				}
			}
			if (tl(p)==tr(p)) return ;
			int mid=(tl(p)+tr(p))>>1;
			if (pos<=mid) query(p<<1, pos, l, r, now);
			else query(p<<1|1, pos, l, r, now);
		}
	}seg[2];
	void solve() {
		ans1=ans2=tot=0;
		for (int i=1; i<=n; ++i) ln[i].clear(), col[i].clear();
		for (int i=1; i<=m; ++i) ln[p[i].sec].pb(p[i].fir), col[p[i].fir].pb(p[i].sec);
		seg[0].build(1, 1, n); seg[1].build(1, 1, n);
		for (int i=1; i<=n; ++i) {
			sort(ln[i].begin(), ln[i].end());
			int lst=1;
			for (auto it:ln[i]) {
				if (it>lst) ++tot, sta[tot]={lst, it-1, i, tot, 0}, seg[0].upd(1, lst, it-1, {i, tot});
				lst=it+1;
			}
			if (lst<=n) ++tot, sta[tot]={lst, n, i, tot, 0}, seg[0].upd(1, lst, n, {i, tot});
		}
		// cout<<"---sta---"<<endl; for (int i=1; i<=tot; ++i) cout<<"("<<sta[i].l<<','<<sta[i].r<<','<<sta[i].bel<<','<<sta[i].id<<','<<sta[i].dir<<")"<<endl;
		for (int i=1; i<=n; ++i) {
			sort(col[i].begin(), col[i].end());
			int lst=1;
			for (auto it:col[i]) {
				if (it>lst) ++tot, sta[tot]={lst, it-1, i, tot, 1}, seg[1].upd(1, lst, it-1, {i, tot});
				lst=it+1;
			}
			if (lst<=n) ++tot, sta[tot]={lst, n, i, tot, 1}, seg[1].upd(1, lst, n, {i, tot});
		}
		for (int i=1; i<=tot; ++i) dis[i]=INF;
		dis[1]=0; q.push(1);
		while (q.size()) {
			int u=q.front(); q.pop();
			ans1=(ans1+sqr(dis[sta[u].id])*(sta[u].r-sta[u].l+1))%mod;
			ans2=(ans2+dis[sta[u].id]*(sta[u].r-sta[u].l+1))%mod;
			seg[sta[u].dir^1].query(1, sta[u].bel, sta[u].l, sta[u].r, dis[sta[u].id]);
		}
		ans2=(ans2-sqr(n)+m)%mod;
		ans1=(ans1-ans2-sqr(n)+m)*inv2%mod;
		printf("%lld\n", (ans1%mod+mod)%mod);
	}
}

signed main()
{
	freopen("light.in", "r", stdin);
	freopen("light.out", "w", stdout);

	int T=read();
	while (T--) {
		n=read(); m=read();
		for (int i=1; i<=m; ++i) p[i].fir=read(), p[i].sec=read();
		// task1::solve();
		task2::solve();
	}

	return 0;
}
posted @ 2022-06-27 20:43  Administrator-09  阅读(2)  评论(0编辑  收藏  举报