FFT与游戏开发(四)

FFT与游戏开发(四)

在海浪的计算中,实际上用的是FFT的逆运算,IFFT,它的套路和FFT是类似的。

推导过程

  1. iFFT的原始公式。

    \[x(m) = \frac{1}{N} \sum_{n=0}^{N-1} X(n)e^{j2\pi nm/N} \]

  2. 这里我们先把归一化用的$ 1/N $去掉,方便后面推导。

    \[\begin{aligned} x(m) =& \sum_{n=0}^{N-1} X(n)e^{j2\pi nm/N}\\ =& \sum_{n=0}^{N/2-1} X(2n)e^{j2\pi (2n)m/N} + \sum_{n=0}^{N/2-1} X(2n+1)e^{j2\pi (2n+1)m/N} \\ =& \sum_{n=0}^{N/2-1} X(2n)e^{j2\pi nm/(N/2)} + e^{j2\pi m / N} \sum_{n=0}^{N/2-1} X(2n+1)e^{j2\pi nm/(N/2)} \\ \end{aligned} \]

  3. 类似的,可以用 $ W_{N/2}^{nm} = e^{j2\pi nm / (N/2)} $ 以及 $ W_N^{m} = e^{j2\pi m / N} $ 去代入。

    \[x(m) = \sum_{n=0}^{N/2-1} X(2n) W_{N/2}^{nm} + W_N^m \sum_{n=0}^{N/2-1} X(2n+1) W_{N/2}^{nm} \]

  4. 对于 $ m \geq N/2 $ 的情况来说,可以用 $ m = m' + N/2 $ 进行代入。

    \[\begin{aligned} x(m) =& \sum_{n=0}^{N/2-1} X(2n) W_{N/2}^{n(m' + N/2)} + W_N^{m' + N/2} \sum_{n=0}^{N/2-1} X(2n+1) W_{N/2}^{n(m' + N/2)} \\ =& \sum_{n=0}^{N/2-1} X(2n) W_{N/2}^{nm'} - W_N^{m'} \sum_{n=0}^{N/2-1} X(2n+1) W_{N/2}^{nm'} \\ \end{aligned} \]

    1. 可以看到,只需要改变中间的符号即可,这和FFT的思路是一致的,只是twiddle factor指数的符号相反(iFFT用 $ e^{j2\pi m/N} $,FFT用 $ e^{-j2\pi m/N} $)。

GPU实现

iFFT的GPU实现和FFT的非常类似,这里我放了两者的实现,可以方便对比他们的区别。

// Kernel entry points declared for this compute shader.
#pragma kernel FFT
#pragma kernel iFFT

// Number of complex samples transformed by one kernel invocation.
static const uint FFT_DIMENSION = 8;
// Butterflies per stage (FFT_DIMENSION / 2); also the thread-group width,
// since each thread processes one butterfly (two elements) per stage.
static const uint FFT_BUTTERFLYS = 4;
// Number of butterfly stages (log2(FFT_DIMENSION)).
static const uint FFT_STAGES = 3;
static const float PI = 3.14159265;
// Group-shared scratch storage holding 2 * FFT_DIMENSION complex values;
// the kernels read from one region and write to another on each stage.
groupshared float2 pingPongArray[FFT_DIMENSION * 2];

// Complex input samples (x = real part, y = imaginary part).
RWStructuredBuffer<float2> srcData;
// Complex output samples written by the final stage.
RWStructuredBuffer<float2> dstData;

// Reverses the order of the lowest `count` bits of `index`.
// reversebits() mirrors all 32 bits, so the bits of interest end up at the
// top of the word; shifting right by (32 - count) brings them back down.
uint ReverseBits(uint index, uint count) {
	uint mirrored = reversebits(index);
	uint unusedBits = 32 - count;
	return mirrored >> unusedBits;
}

// Multiplies two complex numbers stored as float2 (x = real, y = imaginary):
// (a.x + j*a.y) * (b.x + j*b.y).
float2 ComplexMultiply(float2 a, float2 b) {
	float realPart = a.x * b.x - a.y * b.y;
	float imagPart = a.y * b.x + a.x * b.y;
	return float2(realPart, imagPart);
}

// Performs one radix-2 butterfly.
//   input0, input1 : the two complex inputs
//   twiddleFactor  : complex factor applied to input1 before combining
//   output0        : input0 + twiddleFactor * input1
//   output1        : input0 - twiddleFactor * input1
void ButterFlyOnce(float2 input0, float2 input1, float2 twiddleFactor, out float2 output0, out float2 output1) {
	float2 rotated = ComplexMultiply(twiddleFactor, input1);
	output1 = input0 - rotated;
	output0 = input0 + rotated;
}

// Evaluates e^(j*theta) via Euler's formula and returns it as a complex
// number: x = cos(theta), y = sin(theta).
float2 Euler(float theta) {
	float sinTheta;
	float cosTheta;
	sincos(theta, sinTheta, cosTheta);
	return float2(cosTheta, sinTheta);
}

// Forward radix-2 FFT over FFT_DIMENSION complex samples.
// Each thread owns one butterfly per stage. Input is read from srcData in
// bit-reversed order, intermediate stages ping-pong between the two halves of
// pingPongArray, and the last stage writes straight to dstData.
// id.x is used directly as the butterfly index, so it must stay below
// FFT_BUTTERFLYS for the group-shared indices to be in range.
[numthreads(FFT_BUTTERFLYS, 1, 1)]
void FFT(uint3 id : SV_DispatchThreadID)
{
	uint butterFlyID = id.x;

	// Stage 0: gather the two inputs of this thread in bit-reversed order.
	uint index0 = butterFlyID * 2;
	uint index1 = index0 + 1;
	pingPongArray[index0] = srcData[ReverseBits(index0, FFT_STAGES)];
	pingPongArray[index1] = srcData[ReverseBits(index1, FFT_STAGES)];

	// offset.x = base of the region read this stage, offset.y = base of the
	// region written. The two bases must be a full FFT_DIMENSION apart: every
	// stage touches FFT_DIMENSION elements, and a smaller gap would let one
	// thread overwrite an element another thread has not read yet.
	uint2 offset = uint2(0, FFT_DIMENSION);
	[unroll]
	for (uint s = 1; s <= FFT_STAGES; s++) {
		// Make the previous stage's writes visible to every thread.
		GroupMemoryBarrierWithGroupSync();
		// Width of each independent sub-FFT in this stage.
		uint m = 1 << s;
		uint halfWidth = m >> 1;
		// Which sub-FFT this butterfly belongs to.
		uint nFFT = butterFlyID / halfWidth;
		// Position of this butterfly inside its sub-FFT.
		uint k = butterFlyID % halfWidth;
		index0 = k + nFFT * m;
		index1 = index0 + halfWidth;

		// Forward transform uses the twiddle factor e^(-j*2*pi*k/m).
		float2 output0;
		float2 output1;
		ButterFlyOnce(
			pingPongArray[offset.x + index0], pingPongArray[offset.x + index1],
			Euler(-2 * PI * k / m),
			output0, output1);

		if (s != FFT_STAGES) {
			pingPongArray[offset.y + index0] = output0;
			pingPongArray[offset.y + index1] = output1;
			// Swap read/write regions for the next stage.
			offset.xy = offset.yx;
		} else {
			// Final stage: results go to the output buffer.
			dstData[index0] = output0;
			dstData[index1] = output1;
		}
	}
}

// Inverse radix-2 FFT over FFT_DIMENSION complex samples.
// Same structure as the forward transform, with two differences: the twiddle
// factor is e^(+j*2*pi*k/m), and the final results are divided by
// FFT_DIMENSION (the 1/N normalization of the inverse transform).
// id.x is used directly as the butterfly index, so it must stay below
// FFT_BUTTERFLYS for the group-shared indices to be in range.
[numthreads(FFT_BUTTERFLYS, 1, 1)]
void iFFT(uint3 id : SV_DispatchThreadID)
{
	uint butterFlyID = id.x;

	// Stage 0: gather the two inputs of this thread in bit-reversed order.
	uint index0 = butterFlyID * 2;
	uint index1 = index0 + 1;
	pingPongArray[index0] = srcData[ReverseBits(index0, FFT_STAGES)];
	pingPongArray[index1] = srcData[ReverseBits(index1, FFT_STAGES)];

	// offset.x = base of the region read this stage, offset.y = base of the
	// region written. The two bases must be a full FFT_DIMENSION apart: every
	// stage touches FFT_DIMENSION elements, and a smaller gap would let one
	// thread overwrite an element another thread has not read yet.
	uint2 offset = uint2(0, FFT_DIMENSION);
	[unroll]
	for (uint s = 1; s <= FFT_STAGES; s++) {
		// Make the previous stage's writes visible to every thread.
		GroupMemoryBarrierWithGroupSync();
		// Width of each independent sub-iFFT in this stage.
		uint m = 1 << s;
		uint halfWidth = m >> 1;
		// Which sub-iFFT this butterfly belongs to.
		uint nFFT = butterFlyID / halfWidth;
		// Position of this butterfly inside its sub-iFFT.
		uint k = butterFlyID % halfWidth;
		index0 = k + nFFT * m;
		index1 = index0 + halfWidth;

		// Inverse transform uses the twiddle factor e^(+j*2*pi*k/m).
		float2 output0;
		float2 output1;
		ButterFlyOnce(
			pingPongArray[offset.x + index0], pingPongArray[offset.x + index1],
			Euler(2 * PI * k / m),
			output0, output1);

		if (s != FFT_STAGES) {
			pingPongArray[offset.y + index0] = output0;
			pingPongArray[offset.y + index1] = output1;
			// Swap read/write regions for the next stage.
			offset.xy = offset.yx;
		} else {
			// Final stage: apply the 1/N normalization and store the result.
			dstData[index0] = output0 / FFT_DIMENSION;
			dstData[index1] = output1 / FFT_DIMENSION;
		}
	}
}

参考

  1. Understanding Digital Signal Processing
  2. Introduction to Algorithms, 3rd
posted @ 2020-03-18 21:43  马子哥  阅读(197)  评论(0)  编辑  收藏  举报