expression template

表达式模板是一种C++模板元编程(template metaprogram)技术。典型情况下,表达式模板自身代表一种操作,模板参数代表该操作的操作数。模板表达式可将子表达式的计算推迟,这样 有利于优化(特别是减少临时变量的使用)。表达式模板也可以作为参数传递给一个函数。

例子:我们实现一个用来求表达式 x = 1.2*x + x*y 的模板表达式,其中x、y为数组

//exprarray.h
#include <stddef.h>
#include <cassert>
#include "sarray.h"

template<typename T>
class A_Scale
{
public:
	A_Scale(T const& t):value(t){}
	T operator[](size_t) const
	{
		return value;
	}
	size_t size() const
	{
		return 0;
	}
private:
	T const& value;
};

template<typename T>
class A_Traits
{
public:
	typedef T const& exprRef;
};
template<typename T>
class A_Traits<A_Scale<T> >
{
public:
	typedef A_Scale<T> exprRef;
};

template<typename T,typename L1,typename R2>
class A_Add
{
private:
	typename A_Traits<L1>::exprRef op1;
	typename A_Traits<R2>::exprRef op2;
public:
	A_Add(L1 const& a,R2 const& b):op1(a),op2(b)
	{
	}
	T operator[](size_t indx) const
	{
		return op1[indx] + op2[indx];
	}
	size_t size() const
	{
		assert(op1.size()==0 || op2.size()==0 || op1.size() == op2.size());
		return op1.size() != 0 ? op1.size() : op2.size();
	}
};

template<typename T,typename L1,typename R2>
class A_Mul
{
private:
	typename A_Traits<L1>::exprRef op1;
	typename A_Traits<R2>::exprRef op2;
public:
	A_Mul(L1 const& a,R2 const& b):op1(a),op2(b)
	{
	}
	T operator[](size_t indx) const
	{
		return op1[indx] * op2[indx];
	}
	size_t size() const
	{
		assert(op1.size()==0 || op2.size()==0 || op1.size() == op2.size());
		return op1.size() != 0 ? op1.size():op2.size();
	}
};

template<typename T,typename Rep = SArray<T> >
class Array
{
public:
	explicit Array(size_t N):expr_Rep(N){}
	Array(Rep const& rep):expr_Rep(rep){}
	Array& operator=(Array<T> const& orig)
	{
		assert(size() == orig.size());
		for (size_t indx=0;indx < orig.size();indx++)
		{
			expr_Rep[indx] = orig[indx];
		}
		return *this;
	}
	template<typename T2,typename Rep2>
	Array& operator=(Array<T2,Rep2> const& orig)
	{
		assert(size() == orig.size());
		for (size_t indx=0;indx<orig.size();indx++)
		{
			expr_Rep[indx] = orig[indx];
		}
		return *this;
	}
	size_t size() const
	{
		return expr_Rep.size();
	}
	T operator[](size_t indx) const
	{
		assert(indx < size());
		return expr_Rep[indx];
	}
	T& operator[](size_t indx)
	{
		assert(indx < size());
		return expr_Rep[indx];
	}
	Rep const& rep() const
	{
		return expr_Rep;
	}
	Rep& rep()
	{
		return expr_Rep;
	}
private:
	Rep expr_Rep;
};

template<typename T,typename L1,typename R2>
Array<T,A_Add<T,L1,R2> >
operator+(Array<T,L1> const& a,Array<T,R2> const& b)
{
	return Array<T,A_Add<T,L1,R2> >(A_Add<T,L1,R2>(a.rep(),b.rep()));
}

template<typename T,typename L1,typename R2>
Array<T,A_Mul<T,L1,R2> >
operator*(Array<T,L1> const& a,Array<T,R2> const& b)
{
	return Array<T,A_Mul<T,L1,R2> >(A_Mul<T,L1,R2>(a.rep(),b.rep()));
}

template<typename T,typename R2>
Array<T,A_Mul<T,A_Scale<T>,R2> >
operator*(T const& a,Array<T,R2> const& b)
{
	return Array<T,A_Mul<T,A_Scale<T>,R2> >(A_Mul<T,A_Scale<T>,R2>(A_Scale<T>(a),b.rep()));
}

 测试代码(求解表达式1.2*x+x*y):

//test.cpp
#include "exprarray.h"
#include <iostream>
using namespace std;

template <typename T>
void print (T const& c)
{
	for (int i=0; i<8; ++i) {
		std::cout << c[i] << ' ';
	}
	std::cout << "..." << std::endl;
}

int main()
{
	Array<double> x(1000), y(1000);

	for (int i=0; i<1000; ++i) {
		x[i] = i;
		y[i] = x[i]+x[i];
	}

	std::cout << "x: ";
	print(x);

	std::cout << "y: ";
	print(y);

	x = 1.2 * x;
	std::cout << "x = 1.2 * x: ";
	print(x);

	x = 1.2*x + x*y;
	std::cout << "1.2*x + x*y: ";
	print(x);

	x = y;
	std::cout << "after x = y: ";
	print(x);

	return 0;
}

 下面我们来分析一下模板表达式的解析过程:
我们以表达式 x = 1.2*x + x*y为例
当编译器解析表达式:x = 1.2*x + x*y 的时候,编译器首先会应用最左边的*运算符,它是一个Scale-Array运算符。于是重载解析规则将会选择operator*的Scale-Array形式:

template<typename T,typename R2>
Array<T,A_Mul<T,A_Scale<T>,R2> >
operator*(T const& a,Array<T,R2> const& b)
{
	return Array<T,A_Mul<T,A_Scale<T>,R2> >(A_Mul<T,A_Scale<T>,R2>(A_Scale<T>(a),b.rep()));
}

 其中操作数的类型是double和Array<double,SArray<double> >,因此实际的结果类型是:

Array<double,A_Mul<double,A_Scale<double>,SArray<double> > >

接下来,编译器会对第二个乘法进行求值:x*y是一个array-array操作,这一次,我们将会选择operator*的Array-Array重载操作:

template<typename T,typename L1,typename R2>
Array<T,A_Mul<T,L1,R2> >
operator*(Array<T,L1> const& a,Array<T,R2> const& b)
{
	return Array<T,A_Mul<T,L1,R2> >(A_Mul<T,L1,R2>(a.rep(),b.rep()));
}

 其中两个操作数类型都是Array<double,SArray<double> >,因此结果类型为:

Array<double,A_Mul<double,SArray<double>,SArray<double> > >

这一次,A_Mul所封装的连个参数对象都引用了一个SArray<double>表示:即一个表示x对象,一个表示y对象。
现在开始对+运算符进行求值。这次还是Array-Array操作,因此调用Array-Array版本的operator+:

template<typename T,typename L1,typename R2>
Array<T,A_Add<T,L1,R2> >
operator+(SArray<T,L1> const& a,SArray<T,R2> const& b)
{
	return Array<T,L1,R2>(A_Add<T,L1,R2>(a.rep(),b.rep()));
}

 其中用double来替换T,则R1为:

A_Mul<double,A_Scale<double>,SArray<double> >

 R2为:

A_Mul<double,SArray<double>,SArray<double> >

 因此赋值表达式 x = 1.2*x + x*y的右边经过编译器解析后的最终类型为:

Array<double,
	A_Add<double,
		A_Mul<double,A_Scale<double>,SArray<double> >
		A_Mul<double,SArray<double>,SArray<double> > > >

 这个类型将与Array模板的赋值运算符模板进行匹配:

//针对不同类型数组的赋值运算符
template<typename T2,typename Rep2>
	Array& operator=(Array<T2,Rep2> const& orig)
	{
		assert(size() == orig.size());
		for (size_t indx=0;indx<orig.size();indx++)
		{
			expr_Rep[indx] = orig[indx];
		}
		return *this;
	}

 此时,赋值运算符将会运用右边Array的下标运算符来计算目标数组的每一个元素,而Array的实际类型为:

Array<double,
	A_Add<double,
		A_Mul<double,A_Scale<double>,SArray<double> >
		A_Mul<double,SArray<double>,SArray<double> > > >

 我们记为:ArrayTgt
此时,ArrayTgt[indx]将会匹配模板类A_Add中的重载操作符operator[],即:

T operator[](size_t indx) const
	{
		return op1[indx] + op2[indx];
	}

匹配之后就变成:

A_Mul<double,A_Scale<double>,SArray<double> >[indx]
+
A_Mul<double,SArray<double>,SArray<double> >[indx]; 

 而A_Mul[indx]又会匹配模板类A_Mul中的重载操作符operator[],即:

T operator[](size_t indx) const
	{
		return op1[indx] * op2[indx];
	}

 匹配之后就变成:

A_Scale<double>[indx] * SArray<double>[indx] 
+
SArray<double>[indx] * SArray<double>[indx]

 而A_Scale[indx]又会匹配模板类A_Scale中的重载操作符operator[],即:

T operator[](size_t) const
	{
		return value;
	}

 这样最终的结果就表达式就变成:

value[indx] * SArray<double>[indx] 
+
SArray<double>[indx] * SArray<double>[indx]

至此,整个模板表达式的解析工作已经完成,只需进行计算即可。在整个计算过程中,没有产生任何的中间变量,所以程序的效率得以大幅的提高。

程序注意事项:
1.在上述代码中,如果将模板类Array的代码:

Array& operator=(Array<T2,Rep2> const& orig)
	{
		assert(size() == orig.size());
		for (size_t indx=0;indx<orig.size();indx++)
		{
			expr_Rep[indx] = orig[indx];
		}
		return *this;
	}

 中的参数改为Array<T2,Rep2> & orig,即变成:

Array& operator=(Array<T2,Rep2>& orig)
	{
		assert(size() == orig.size());
		for (size_t indx=0;indx<orig.size();indx++)
		{
			expr_Rep[indx] = orig[indx];
		}
		return *this;
	}

将会导致编译出错,原因是:
在test.cpp文件中,我们使用了表达式:x = 1.2 * x ,这个表达式的右边将会被编译器解析为如下形式的表达式:

Array<double,A_Mul<double,A_Scale<double>,SArray<doube> > >

 这样在进行重载操作符operator[]的匹配时,将会变成如下形式:

SArray[indx] = A_Scale[indx] * SArray[indx]

 到了这一步,问题就出现了,因为A_Scale[indx]会匹配模板类Array中的重载操作符operator[],但是我们发现在模板类Array代码中,有两个重载的operator[],即:

T operator[](size_t indx) const
	{
		assert(indx < size());
		return expr_Rep[indx];
	}
T& operator[](size_t indx)
	{
		assert(indx < size());
		return expr_Rep[indx];
	}

如果我们没在重载操作符operator=的参数中写入const的话,这里会优先调用无const的operator[]重载函数,但是A_Scale[indx]是个常数,在本例中也就是一个double类型,这样最后在调用operator[]返回的时候就出现了类型不匹配的现象,因为无const的operator[]返回的类型是double&,所以会报错。当然,我们可以将test.cpp程序中的表达式1.2*x去掉,我们会发现,这个时候无const的operator=就会编译通过。

2.在上述代码中,模板类Array的构造函数代码为:
explicit Array(size_t N):expr_Rep(N){}
这表明定义Array必须通过显式转型,不能通过隐式转型。下述代码会导致编译出错:
Array a = 5;
我们只能使用
Array a(5);
进行显式初始化。
下面用一个例子来区别显式转型和隐式转型的细微区别:
X x;
Y y(x); //显式转型
Y y = x;//隐式转型
其中前者通过使用从X到Y类型的显式转型,新建一个类型为Y的对象。后者使用了一个从类型X到Y类型的隐式转型,新建了一个类型Y的对象。

3.在上述代码中,模板类Array的两个重载操作符operator[]代码:

T operator[](size_t indx) const
	{
		assert(indx < size());
		return expr_Rep[indx];
	}
T& operator[](size_t indx)
	{
		assert(indx < size());
		return expr_Rep[indx];
	}

 注意在一个重载操作符函数后面的const一定不能少,否则会导致编译错误。因为没有const的话,函数

T operator[](size_t indx)
	{
		assert(indx < size());
		return expr_Rep[indx];
	}

 和

T& operator[](size_t indx)
	{
		assert(indx < size());
		return expr_Rep[indx];
	}

会被认为是一个函数,因为他们静静是返回类型不同而已。函数
int test(){}

int test() const{}
会被编译器理解为两个不同的函数。

最后将SArray的代码附上:

#ifndef SARRAY_H
#define SARRAY_H

#include <stddef.h>
#include <cassert>

template<typename T>
class SArray
{
public:
	explicit SArray(size_t N):ptr(new T[N]),_size(N)
	{
		init();
	}
	SArray(SArray<T> const& orig):ptr(new T[orig.size()]),_size(orig.size())
	{
		copy(orig);
	}
	~SArray()
	{
		delete[] ptr;
	}
	size_t size() const
	{
		return _size;
	}
	T operator[](size_t indx) const
	{
		return ptr[indx];
	}
	T& operator[](size_t indx)
	{
		return ptr[indx];
	}
	SArray<T>& operator=(SArray<T> const& orig)
	{
		if (&orig != this)
		{
			copy(orig);
		}
		return *this;
	}
protected:
	void copy(SArray<T> const& orig)
	{
		assert(size() == orig.size());
		for (size_t indx=0;indx<orig.size();indx++)
		{
			ptr[indx] = orig[indx];
		}
	}
	void init()
	{
		for(size_t i=0;i<size();i++)
		{
			ptr[i] = T();
		}
	}
private:
	T* ptr;
	size_t _size;
};
#endif

 

posted @ 2011-11-17 10:58  MagiCube  阅读(1039)  评论(0编辑  收藏  举报