题解 传统艺能
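首先是「本质不同子序列计数」的经典 DP:设 \(f_c\) 为以字符 \(c\) 结尾的本质不同非空子序列个数,每读入一个字符 \(c\),令 \(f_c \gets 1 + \sum_{c'} f_{c'}\),答案为 \(\sum_c f_c\)。
对每个询问直接在区间上跑一遍这个 DP,就得到了 \(O(n^2)\) 的暴力。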
然后发现这个转移是线性的:把状态写成行向量 \((f_A, f_B, \dots, 1)\),每个字符对应一个转移矩阵,区间询问就变成了区间矩阵乘积。
于是用线段树维护区间矩阵乘积,单点修改只需替换叶子上的矩阵。
但若按 26 个字母处理,矩阵是 \(27\times 27\) 的(26 个字母加一维常数),复杂度为 \(O(27^3 m\log n)\),无法通过。
考完之后仔细阅读题面,才发现字符集大小只有 3。
于是矩阵缩小为 \(4\times 4\),复杂度变为 \(O(4^3 m\log n)\),可以通过。
Code:
#include <bits/stdc++.h>
using namespace std;
// Sentinel "infinity" value (not referenced by the code visible in this file).
#define INF 0x3f3f3f3f
// Capacity of the 1-indexed string buffer; the segment-tree arrays use N<<2 nodes.
#define N 100010
#define ll long long
// Optional toggle that widens every `int` to 64 bits; main is declared
// `signed main` so the program still compiles when this is enabled.
//#define int long long
// n: length of the string, m: number of operations (both read in main).
int n, m;
// The string, stored 1-indexed in s[1..n] (main reads it into s+1).
char s[N];
// Every answer is printed modulo this prime.
const ll mod=998244353;
// Brute-force reference solver: answers each query by rerunning the
// distinct-subsequence DP over s[l..r], i.e. O((r-l+1) * 26) per query.
// Useful for cross-checking the matrix solution on small inputs.
namespace force{
// dp[c]: number of distinct non-empty subsequences of the scanned part of
// s[l..r] whose last character is 'A'+c, reduced modulo mod.
ll dp[30];

// Returns the number of distinct non-empty subsequences of s[l..r], modulo mod.
ll calc(int l, int r) {
    memset(dp, 0, sizeof(dp));
    for (int i=l; i<=r; ++i) {
        // Appending s[i] to every distinct subsequence seen so far, plus the
        // one-character subsequence s[i], yields exactly the distinct
        // subsequences that end with s[i]; this replaces the old dp value.
        ll total=0;
        for (int j=0; j<26; ++j) total=(total+dp[j])%mod;
        dp[s[i]-'A']=(total+1)%mod;
    }
    ll ans=0;
    for (int i=0; i<26; ++i) ans=(ans+dp[i])%mod;
    return ans;
}

// Processes the m operations read from stdin:
//   "1 p c" -> set s[p] = c;   "2 l r" -> print calc(l, r).
// (Any odd op is treated as an update, any even op as a query.)
// Terminates the program when done.
void solve() {
    char c[5];
    for (int i=1,op,p,l,r; i<=m; ++i) {
        scanf("%d", &op);
        if (op&1) {
            scanf("%d%s", &p, c);
            s[p]=*c;
        }
        else {
            scanf("%d%d", &l, &r);
            printf("%lld\n", calc(l, r));
        }
    }
    exit(0);
}
}
// Shortcut solver for a string consisting only of 'A'. The distinct non-empty
// subsequences of a run of k 'A's are "A", "AA", ..., so each query over
// [l, r] is answered with its length r-l+1. Updates are stored into s but are
// not consulted when answering. Terminates the program when done.
namespace task1{
void solve() {
    char buf[5];
    int remaining = m;
    while (remaining-- > 0) {
        int op;
        scanf("%d", &op);
        if (!(op & 1)) {
            // Even op: range query.
            int l, r;
            scanf("%d%d", &l, &r);
            printf("%d\n", r - l + 1);
            continue;
        }
        // Odd op: point assignment s[p] = buf[0].
        int p;
        scanf("%d%s", &p, buf);
        s[p] = buf[0];
    }
    exit(0);
}
}
namespace task2{
struct matrix{
int n, m;
int a[5][5];
matrix(){}
matrix(int x, int y) {n=x; m=y; memset(a, 0, sizeof(a));}
void resize(int x, int y) {n=x; m=y; memset(a, 0, sizeof(a));}
inline int* operator [] (int t) {return a[t];}
inline ll sum() {ll ans=0; for (int i=1; i<=3; ++i) ans=(ans+a[1][i])%mod; return ans;}
inline matrix operator * (matrix b) {
matrix ans(n, b.m);
for (int i=1; i<=n; ++i)
for (int k=1; k<=m; ++k) if (a[i][k])
for (int j=1; j<=b.m; ++j)
ans[i][j]=(ans[i][j]+1ll*a[i][k]*b[k][j])%mod;
return ans;
}
}dat[N<<2], base[30], f0;
int tl[N<<2], tr[N<<2];
#define tl(p) tl[p]
#define tr(p) tr[p]
#define dat(p) dat[p]
#define pushup(p) dat(p)=dat(p<<1)*dat(p<<1|1)
void build(int p, int l, int r) {
tl(p)=l; tr(p)=r;
if (l==r) {dat(p)=base[s[l]-'A']; return ;}
int mid=(l+r)>>1;
build(p<<1, l, mid);
build(p<<1|1, mid+1, r);
pushup(p);
}
void upd(int p, int pos, char c) {
if (tl(p)==tr(p)) {dat(p)=base[c-'A']; return ;}
int mid=(tl(p)+tr(p))>>1;
if (pos<=mid) upd(p<<1, pos, c);
else upd(p<<1|1, pos, c);
pushup(p);
}
matrix query(int p, int l, int r) {
if (l<=tl(p)&&r>=tr(p)) return dat(p);
int mid=(tl(p)+tr(p))>>1;
if (l<=mid && r>mid) return query(p<<1, l, r)*query(p<<1|1, l, r);
else if (l<=mid) return query(p<<1, l, r);
else return query(p<<1|1, l, r);
}
void solve() {
f0.resize(1, 4); f0[1][4]=1;
for (int i=0; i<3; ++i) {
base[i].resize(4, 4);
for (int j=1; j<=4; ++j) {
base[i][j][j]=1;
if (j!=i+1) base[i][j][i+1]=1;
}
}
build(1, 1, n);
char c[5];
for (int i=1,op,p,l,r; i<=m; ++i) {
scanf("%d", &op);
if (op&1) {
scanf("%d%s", &p, c);
upd(1, p, *c);
}
else {
scanf("%d%d", &l, &r);
printf("%lld\n", (f0*query(1, l, r)).sum());
}
}
exit(0);
}
}
// Entry point. Redirects contest I/O to string.in / string.out, reads n, m and
// the 1-indexed string, then answers every operation with the matrix solver.
signed main()
{
    freopen("string.in", "r", stdin);
    freopen("string.out", "w", stdout);
    scanf("%d%d", &n, &m);
    scanf("%s", s+1);
    // force::solve();   // brute-force alternative, handy for cross-checking
    // Always use the segment-tree solver. It is correct for every string,
    // including all-'A' ones, in O(4^3 log n) per operation. The task1
    // shortcut was selected by inspecting only the initial string. An update
    // that later wrote 'B' or 'C' would silently make its r-l+1 answers wrong.
    task2::solve();
    return 0;
}