【XSY3921】简单的数据结构题(多项式,拉格朗日插值,线段树)

题面

简单的数据结构题

题解

直接考虑我们要计算的式子。为了方便,我们先设 \(l=1,r=n\)

\[\begin{aligned} &\sum_{i=1}^na_i^k\prod_{j\neq i}\frac{1-a_ia_j}{a_i-a_j}\\ =&\sum_{i=1}^na_i^k \left(\prod_{j\neq i}\frac{1}{a_i-a_j}\right) \left(\sum_{l=0}^{n-1}\sum_{1\leq j_1<j_2<\cdots<j_l\leq n \atop j_1,j_2,\cdots,j_l\neq i}(-a_ia_{j_1})(-a_ia_{j_2})\cdots(-a_ia_{j_l})\right)\\ =&\sum_{i=1}^na_i^k \left(\prod_{j\neq i}\frac{1}{a_i-a_j}\right) \left(\sum_{l=0}^{n-1} a_i^l[x^{n-1-l}]\left(\prod_{j\neq i}(x-a_j)\right)\right)\\ =&\sum_{l=0}^{n-1}[x^{n-1-1}]\left(\sum_{i=1}^n a_i^{k+l}\prod_{j\neq i}\frac{x-a_j}{a_i-a_j}\right)\\ \end{aligned} \]

发现 \(\sum\limits_{i=1}^n a_i^{k+l} \prod\limits_{j\neq i}\frac{x-a_j}{a_i-a_j}\) 是拉格朗日插值的形式。

\(f(x)=\sum\limits_{i=1}^n a_i^{k+l} \prod\limits_{j\neq i}\frac{x-a_j}{a_i-a_j}\),那么 \(f(x)\) 是恰好经过 \((a_1,a_1^{k+l}),(a_2,a_2^{k+l}),\cdots,(a_n,a_n^{k+l})\)\(n\) 个点的小于等于 \(n-1\) 次的唯一的多项式。注意,满足这个条件的多项式是唯一的,即 \(f(x)\)

考虑构造出 \(f(x)\)

首先注意到 \(g(x)=x^{k+l}\) 肯定是经过这 \(n\) 个点的,但它不一定小于等于 \(n-1\) 次。

然后注意到对于多项式取模 \(A(x) \bmod B(x)=C(x)\) 来说(其中 \(A(x),B(x),C(x)\) 均为多项式),若有一 \(x_0\) 满足 \(B(x_0)=0\),则 \(C(x_0)=A(x_0) \bmod B(x_0)=A(x_0)\)。也就是在多项式取模中,若除数是 \(0\),则商也是 \(0\),余数和被除数相同。

那么我们设 \(h(x)=g(x)\bmod (x-a_1)(x-a_2)\cdots(x-a_n)=x^{k+l}\bmod (x-a_1)(x-a_2)\cdots(x-a_n)\)。显然,当 \(x=a_i\) 时,模数是 \(0\),有 \(h(a_i)=g(a_i)=a_i^{k+l}\)

所以 \(h(x)\) 也是满足 “恰好经过 \((a_1,a_1^{k+l}),(a_2,a_2^{k+l}),\cdots,(a_n,a_n^{k+l})\)\(n\) 个点的小于等于 \(n-1\) 次” 的多项式,而满足这个条件的多项式又是唯一的,所以 \(h(x)=f(x)\)

所以原式转化为:

\[=\sum_{l=0}^{n-1}[x^{n-l-1}]\left(x^{k+l}\bmod (x-a_1)(x-a_2)\cdots(x-a_n)\right) \]

注意到 \(k\) 很小,而 \(l\leq n-1\),所以 \(k+l\geq n\) 的情况不多,分类讨论即可:

偷一张 solution 的图

在这里插入图片描述

(图片最后一行最后应该是 “\(=(-1)^{n-1}(a_1+a_2+\cdots+a_n)a_1a_2\cdots a_n\)”)

于是只需要维护 \(\sum\limits_i a_i\)\(\prod\limits_i a_i\)\(\sum\limits_{i}\prod\limits_{j\neq i}a_j\),在区间乘的操作下用线段树简单维护即可。

代码如下:

#include<bits/stdc++.h>

#define N 300010

using namespace std;

namespace modular
{
	const int mod=998244353;
	inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
	inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
	inline int mul(int x,int y){return 1ll*x*y%mod;}
}using namespace modular;

inline int poww(int a,int b)
{
	int ans=1;
	while(b)
	{
		if(b&1) ans=mul(ans,a);
		a=mul(a,a);
		b>>=1;
	}
	return ans;
}

inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^'0');
		ch=getchar();
	}
	return x*f;
}

struct data
{
	int sum,prod,ans;
	data(){sum=ans=0,prod=1;}
	data(int a,int b,int c){sum=a,prod=b,ans=c;}
}t[N<<2],now;

data merge(data a,data b)
{
	data c;
	c.sum=add(a.sum,b.sum);
	c.prod=mul(a.prod,b.prod);
	c.ans=add(mul(a.prod,b.ans),mul(a.ans,b.prod));
	return c;
}

int n,m,a[N];
int lazy[N<<2];

void up(int k)
{
	t[k]=merge(t[k<<1],t[k<<1|1]);
}

void downn(int k,int l,int r,int val)
{
	int tmp=poww(val,r-l);
	t[k].sum=mul(t[k].sum,val);
	t[k].prod=mul(t[k].prod,mul(tmp,val));
	t[k].ans=mul(t[k].ans,tmp);
	lazy[k]=mul(lazy[k],val);
}

void down(int k,int l,int r,int mid)
{
	if(lazy[k]!=1)
	{
		downn(k<<1,l,mid,lazy[k]);
		downn(k<<1|1,mid+1,r,lazy[k]);
		lazy[k]=1;
	}
}

void build(int k,int l,int r)
{
	lazy[k]=1;
	if(l==r)
	{
		t[k]=data(a[l],a[l],1);
		return;
	}
	int mid=(l+r)>>1;
	build(k<<1,l,mid);
	build(k<<1|1,mid+1,r);
	up(k);
}

void update(int k,int l,int r,int ql,int qr,int x)
{
	if(ql<=l&&r<=qr)
	{
		downn(k,l,r,x);
		return;
	}
	int mid=(l+r)>>1;
	down(k,l,r,mid);
	if(ql<=mid) update(k<<1,l,mid,ql,qr,x);
	if(qr>mid) update(k<<1|1,mid+1,r,ql,qr,x);
	up(k);
}

void query(int k,int l,int r,int ql,int qr)
{
	if(ql<=l&&r<=qr)
	{
		now=t[k];
		return;
	}
	int mid=(l+r)>>1;
	down(k,l,r,mid);
	data ans;
	if(ql<=mid)
	{
		query(k<<1,l,mid,ql,qr);
		ans=now;
	}
	if(qr>mid)
	{
		query(k<<1|1,mid+1,r,ql,qr);
		ans=merge(ans,now);
	}
	now=ans;
}

int main()
{
	n=read(),m=read();
	for(int i=1;i<=n;i++) a[i]=read();
	build(1,1,n);
	while(m--)
	{
		int opt=read(),l=read(),r=read(),k=read();
		if(opt==1) update(1,1,n,l,r,k);
		else
		{
			int nn=r-l+1,ans=0;
			if(!((nn-k-1)&1)) ans++;
			if(k==1)
			{
				query(1,1,n,l,r);
				if(nn&1) ans=add(ans,now.prod);
				else ans=dec(ans,now.prod);
			}
			if(k==2)
			{
				query(1,1,n,l,r);
				if(nn&1) ans=dec(ans,now.ans);
				else ans=add(ans,now.ans);
				int tmp=mul(now.sum,now.prod);
				if(nn&1) ans=add(ans,tmp);
				else ans=dec(ans,tmp);
			}
			printf("%d\n",ans);
		}
	}
	return 0;
}
/*
3 3
1 2 3
2 1 2 0
1 1 3 5
2 1 2 0
*/
posted @ 2022-10-30 14:26  ez_lcw  阅读(24)  评论(0编辑  收藏  举报