题解 传统艺能

传送门

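首先是求本质不同子序列个数的经典 DP:设 \(f_c\) 为以字符 \(c\) 结尾的本质不同非空子序列个数,依次读入字符,读到 \(c\) 时令 \(f_c \leftarrow 1+\sum_x f_x\),答案为 \(\sum_x f_x\)
对每个询问直接跑一遍,就得到 \(O(nm)\)(即 \(n^2\) 级别)的暴力
这个转移是线性的:把 \((f_A,\dots,f_Z,1)\) 看作行向量,每个字符对应一个转移矩阵,于是询问就变成了矩阵的区间乘积
于是用线段树维护,单点修改只需改一个叶子
但矩阵大小为 \(26+1=27\),复杂度 \(O(27^3 m\log n)\),不可过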

等到考完之后仔细阅读题面,才发现字符集大小只有 3
于是矩阵只有 \(3+1=4\) 阶,复杂度变为 \(O(4^3 m\log n)\),可以通过

Code:
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
constexpr int INF = 0x3f3f3f3f;  // conventional "infinity" sentinel (not referenced in this file)
constexpr int N = 100010;        // capacity for 1-indexed positions; also sizes the segment tree
//#define int long long

int n, m;                        // string length, number of operations
char s[N];                       // the string, stored 1-indexed in s[1..n]
constexpr ll mod = 998244353;    // every answer is reported modulo this prime

namespace force{
	// Brute-force reference: number of distinct non-empty subsequences of
	// s[l..r], modulo `mod`, in O(r-l+1) time.
	//
	// dp[c] = number of distinct non-empty subsequences ending with letter c.
	// Appending letter c: each existing distinct subsequence extended by c,
	// plus the single letter c, yields every subsequence ending with c, so
	//     dp[c] <- (sum of all dp) + 1.
	// `total` maintains sum(dp) incrementally instead of re-summing all
	// letters on each step; the result is identical modulo `mod`.
	ll calc(int l, int r) {
		ll dp[30]={};
		ll total=0;
		for (int i=l; i<=r; ++i) {
			ll &cur=dp[s[i]-'A'];
			ll nxt=(total+1)%mod;
			total=((total-cur+nxt)%mod+mod)%mod;  // replace old dp[c] by the new one
			cur=nxt;
		}
		return total;
	}
	// Process m operations from stdin: "1 p c" sets s[p]=c, otherwise
	// "op l r" prints calc(l, r). Terminates the program when done.
	void solve() {
		char c[5];
		for (int i=1,op,p,l,r; i<=m; ++i) {
			scanf("%d", &op);
			if (op&1) {
				scanf("%d%4s", &p, c);  // width bound keeps the read inside c[5]
				s[p]=*c;
			}
			else {
				scanf("%d%d", &l, &r);
				printf("%lld\n", calc(l, r));
			}
		}
		exit(0);
	}
}

namespace task1{
	// Special case for a string made only of 'A': the distinct non-empty
	// subsequences of "A...A" (length k) are "A", "AA", ..., i.e. exactly k,
	// so a query on [l, r] answers r-l+1.
	// NOTE: updates are written into s but never consulted when answering, so
	// the output is only correct while every character remains 'A'.
	void solve() {
		char c[5];
		for (int i=1,op,p,l,r; i<=m; ++i) {
			scanf("%d", &op);
			if (op&1) {
				scanf("%d%4s", &p, c);  // width bound keeps the read inside c[5]
				s[p]=*c;
			}
			else {
				scanf("%d%d", &l, &r);
				printf("%d\n", r-l+1);
			}
		}
		exit(0);
	}
}

namespace task2{
	struct matrix{
		int n, m;
		int a[5][5];
		matrix(){}
		matrix(int x, int y) {n=x; m=y; memset(a, 0, sizeof(a));}
		void resize(int x, int y) {n=x; m=y; memset(a, 0, sizeof(a));}
		inline int* operator [] (int t) {return a[t];}
		inline ll sum() {ll ans=0; for (int i=1; i<=3; ++i) ans=(ans+a[1][i])%mod; return ans;}
		inline matrix operator * (matrix b) {
			matrix ans(n, b.m);
			for (int i=1; i<=n; ++i)
				for (int k=1; k<=m; ++k) if (a[i][k])
					for (int j=1; j<=b.m; ++j)
						ans[i][j]=(ans[i][j]+1ll*a[i][k]*b[k][j])%mod;
			return ans;
		}
	}dat[N<<2], base[30], f0;
	int tl[N<<2], tr[N<<2];
	#define tl(p) tl[p]
	#define tr(p) tr[p]
	#define dat(p) dat[p]
	#define pushup(p) dat(p)=dat(p<<1)*dat(p<<1|1)
	void build(int p, int l, int r) {
		tl(p)=l; tr(p)=r;
		if (l==r) {dat(p)=base[s[l]-'A']; return ;}
		int mid=(l+r)>>1;
		build(p<<1, l, mid);
		build(p<<1|1, mid+1, r);
		pushup(p);
	}
	void upd(int p, int pos, char c) {
		if (tl(p)==tr(p)) {dat(p)=base[c-'A']; return ;}
		int mid=(tl(p)+tr(p))>>1;
		if (pos<=mid) upd(p<<1, pos, c);
		else upd(p<<1|1, pos, c);
		pushup(p);
	}
	matrix query(int p, int l, int r) {
		if (l<=tl(p)&&r>=tr(p)) return dat(p);
		int mid=(tl(p)+tr(p))>>1;
		if (l<=mid && r>mid) return query(p<<1, l, r)*query(p<<1|1, l, r);
		else if (l<=mid) return query(p<<1, l, r);
		else return query(p<<1|1, l, r);
	}
	void solve() {
		f0.resize(1, 4); f0[1][4]=1;
		for (int i=0; i<3; ++i) {
			base[i].resize(4, 4);
			for (int j=1; j<=4; ++j) {
				base[i][j][j]=1;
				if (j!=i+1) base[i][j][i+1]=1;
			}
		}
		build(1, 1, n);
		char c[5];
		for (int i=1,op,p,l,r; i<=m; ++i) {
			scanf("%d", &op);
			if (op&1) {
				scanf("%d%s", &p, c);
				upd(1, p, *c);
			}
			else {
				scanf("%d%d", &l, &r);
				printf("%lld\n", (f0*query(1, l, r)).sum());
			}
		}
		exit(0);
	}
}

signed main()
{
	freopen("string.in", "r", stdin);
	freopen("string.out", "w", stdout);

	scanf("%d%d", &n, &m);
	scanf("%s", s+1);  // string is stored 1-indexed
	// force::solve();  // brute-force reference, for cross-checking only

	// Always use the segment-tree solution: O(4^3 log n) per operation and
	// correct for every input, all-'A' strings included. The former shortcut
	// (task1 when n>2000 and the initial string is all 'A') answered r-l+1
	// regardless of later updates, which is wrong as soon as an update writes
	// a character other than 'A'.
	task2::solve();

	return 0;
}
posted @ 2021-10-26 16:17  Administrator-09  阅读(0)  评论(0)  编辑  收藏  举报