您的位置:首页 > 编程语言 > C语言/C++

我的C++实践(8):表达式模板技术

2016-07-29 00:00 537 查看
表达式模板技术主要是为了提高存放数值的容器型对象的计算效率。对效率效率要求苛刻(比如大量的数值计算)的计算,使用表达式模板技术可以获得很高的效率,而且客户端的使用代码并没有什么变化,仍然非常紧凑。先看一个简单数组容器模板的实现:

//sarray.hpp:简单的数组容器类型实现
#ifndef SARRAY_HPP
#define SARRAY_HPP
#include <cstddef>
#include <cassert>
template<typename T>
class SArray{
private:
T* storage; //数组元素的存储空间
size_t storage_size; //元素的个数
protected:
void init(){
for(size_t idx=0;idx<size();++idx)
storage[idx]=T(); //初始化数组的各个元素
}
void copy(SArray<T> const& orig){
assert(size()==orig.size()); //验证两个数组大小是否一致
for(size_t idx=0;idx<size();++idx){ //拷贝另一个数组的值
storage[idx]=orig.storage[idx];
}
}
public:
explicit SArray(size_t s)
:storage(new T[s]),storage_size(s){ //创建一个具有初始值大小的数组
//由s个T类型的元素构成的数组
init();
}
SArray(SArray<T> const& orig)
:storage(new T[orig.size()]),storage_size(orig.size()){ //拷贝构造函数
copy(orig);
}
~SArray(){
delete[] storage; //释放数组的内存空间
}
SArray<T>& operator=(SArray<T> const& orig){ //赋值运算符
if(this!=&orig)
copy(orig);
return *this;
}
size_t size() const{ //返回数组大小
return storage_size;
}
T& operator[](size_t const idx){ //对SArray变量对象进行下标运算符
return storage[idx];
}
T operator[](size_t const idx) const{ //对Array常量对象(即const对象)进行下标运算符
return storage[idx];
}
SArray<T>& operator+=(SArray<T> const& b);
SArray<T>& operator*=(SArray<T> const& b);
SArray<T>& operator*=(T const& s);
//...
};
//对两个SArray求和
template<typename T>
SArray<T> operator+(SArray<T> const& a,SArray<T> const& b){
SArray<T> result(a.size()); //创建了临时的数组
for(size_t k=0;k<a.size();++k){
result[k]=a[k]+b[k];
}
return result;
}
//对两个SArray求积
template<typename T>
SArray<T> operator*(SArray<T> const& a,SArray<T> const& b){
SArray<T>  result(a.size());
for(size_t k=0;k<a.size();++k){
result[k]=a[k]*b[k];
}
return result;
}
//让一个SArray乘以一个放大位数(scalar)
template<typename T>
SArray<T> operator*(T const& s,SArray<T> const& a){
SArray<T> result(a.size());
for(size_t k=0;k<a.size();++k){
result[k]=s*a[k];
}
return result;
}
//对SArray和scalar求积
//对scalar和SArray求和
//对SArray和scalar求和
//SArray的自加运算
template<typename T>
SArray<T>& SArray<T>::operator+=(SArray<T> const& b){
for(size_t k=0;k<size();++k){
(*this)[k]+=b[k];
}
return *this;
}
//SArray的自乘运算
template<typename T>
SArray<T>& SArray<T>::operator*=(SArray<T> const& b){
for(size_t k=0;k<size();++k){
(*this)[k]*=b[k];
}
return *this;
}
//针对放大位数的自乘运算符
template<typename T>
SArray<T>& SArray<T>::operator*=(T const& s){
for(size_t k=0;k<size();++k){
(*this)[k]*=s;
}
return *this;
}
//...
#endif


数组SArray的每个运算符实现中都要创建一个临时数组对象,并扫描一次所有的元素来进行计算,当对复合的表达式进行计算时,效率会非常低。比如我们定义Array<double> x(1000),y(1000); 然后计算表达式x=1.2*x+x*y,这相当于tmp1=1.2*x,tmp2=x*y,tmp3=tmp1+tmp2,x=tmp3,每运算都要创建一个临时数组并扫描一次所有元素(各循环1000次)来计算,还要加上创建和删除tmp对象,最后的赋值又有1000次读和1000次写操作,性能大大降低。对于x=1.2*x+x*y这样的表达式,我们更希望只扫描一次所有的元素,并计算x[i]=1.2*x[i]+x[i]*y[i]。也就是说,进行一次扫描就把表达式的结果计算出来,这样性能会大大提高。我们可以创建表达式模板来实现这一点,把表达式1.2*x+x*y转化成如下类型的对象:
A_Add< double, A_Mult<double,A_Scalar<double>,Array<double> >, A_Mult<double,Array<double>,Array<double> > >
其中double是数组中元素的类型,A_Scalar<T>是对1.2这样的放大倍数的包装,而A_Add<T,OP1,OP2>、A_Mult<T,OP1,OP2>就是所谓的表达式模板,一个是加法表达式,一个是乘法表达式。这里是针对数组运算的表达式,因此T是数组元素的类型,OP1和OP2是数组类型Array<T>或者放大倍数类型A_Scalar<T>。数组类型Array<T>是对使用了表达式模板后的SArray的重新实现,其实Array还可以重用SArray的实现(后面我们会看到)。表达式模板对我们要计算的表达式进行模板化的封装,并最终对数值型的数组元素执行表达式所表示计算。
对数组元素进行计算的表达式模板如下:

//exprops.hpp:对数组元素进行计算的表达式模板
#ifndef EXPR_OPS_HPP
#define EXPR_OPS_HPP
#include <cstddef>
#include <cassert>
#include "exprtrait.hpp"
//这里的OP1和OP2可以是数组类型Array<T>或放大倍数类型A_Scalar<T>
template<typename T,typename OP1,typename OP2>
class A_Add{ //两个数组之和的表达式模板
private:
typename A_Traits<OP1>::ExprRef op1; //第1个操作数
typename A_Traits<OP2>::ExprRef op2; //第2个操作数
public:
//构造函数,对指向操作数的引用进行初始化
A_Add(OP1 const& a,OP2 const& b):op1(a),op2(b){
}
T& operator[](size_t const idx){ //在求值时计算和(非const版本)
return op1[idx]+op2[idx]; //对给定下标处的数组元素做加运算
}
T operator[](size_t const idx) const{ //const版本
return op1[idx]+op2[idx];
}
size_t size() const{
assert(op1.size()==0 || op2.size()==0
|| op1.size()==op2.size());
return op1.size()!=0?op1.size():op2.size();
}
};
template<typename T,typename OP1,typename OP2>
class A_Mult{ //两个数组之积的表达式模板
private:
typename A_Traits<OP1>::ExprRef op1;
typename A_Traits<OP2>::ExprRef op2;
public:
//构造函数,对指向操作数的引用进行初始化
A_Mult(OP1 const& a,OP2 const& b):op1(a),op2(b){
}
T& operator[](size_t const idx){ //在求值时计算积(非const版本)
return op1[idx]*op2[idx]; //对给定下标处的数组元素做乘运算
}
T operator[](size_t const idx) const{ //const版本
return op1[idx]*op2[idx];
}
size_t size() const{ //size表示最大容量
assert(op1.size()==0 || op2.size()==0
|| op1.size()==op2.size());
return op1.size()!=0?op1.size():op2.size();
}
};
#endif


//exprtrait.hpp:表达式模板参数的trait,根据不同的类型决定传值还是传引用
#ifndef A_TRAITS_HPP
#define A_TRAITS_HPP
#include "exprscalar.hpp"
template<typename T>
class A_Traits{ //基本模板
public:
typedef T const& ExprRef; //对一般的类型传const引用
};
template<typename T>
class A_Traits<A_Scalar<T> >{ //对scalar(放大倍数)的局部特化
public:
typedef A_Scalar<T> ExprRef; //对A_Scalar则传值
};
#endif


//exprscalar.hpp:封装放大倍数(scalar)的类模板
#ifndef A_SCALAR_HPP
#define A_SCALAR_HPP
template<typename T>
class A_Scalar{
private:
T const& s; //scalar的值
public:
A_Scalar(T const& v):s(v){
}
T& operator[](size_t const idx){ //对A_Scalar变量对象进行下标运算符
return s; //直接返回倍数值
}
T operator[](size_t const idx) const{ //对A_Scalar常量对象进行下标运算符
return s;
}
size_t size() const{
return 0; //元素个数为0
}
};
#endif


使用了表达模板后的数组类型Array<T>如下:

//exprarray.hpp: 使用了表达模板后的数组类型Array<T>,及其运算符的实现,
//第二个参数用于传表达式模板(如A_Add),默认
//使用的是普通的数组实现SArray
#ifndef ARRAY_HPP
#define ARRAY_HPP
#include <cstddef>
#include <cassert>
#include "sarray.hpp"
#include "exprscalar.hpp"
#include "exprops.hpp"
template<typename T,typename Rep=SArray<T> >
class Array{
private:
Rep expr_rep; //持有的对象,可以是底层的SArray数组对象(持有实际的数组数据),
//也可以是A_Add,A_Mult这样的高层的表达式对象
public:
explicit Array(size_t s):expr_rep(s){ //创建具有初始大小的数组
}
Array(Rep const& rb):expr_rep(rb){ //根据其他可能的表示来创建数组
}
Array<T,Rep>& operator=(Array<T,Rep> const& b){ //针对相同类型的赋值运算符
assert(size()==b.size());
for(size_t idx=0;idx<b.size();++idx){
expr_rep[idx]=b[idx];
}
return *this;
}
template<typename T2,typename Rep2>
Array<T,Rep>& operator=(Array<T2,Rep2> const& b){ //针对不同类型的赋值运算符
assert(size()==b.size());
for(size_t idx=0;idx<b.size();++idx){
expr_rep[idx]=b[idx];
}
return *this;
}
size_t size() const{
return expr_rep.size();
}
T& operator[](size_t const idx){ //对Array变量对象进行下标运算符
assert(idx<size());
return expr_rep[idx];
}
T operator[](size_t const idx) const{ //对Array常量对象进行下标运算符
assert(idx<size());
return expr_rep[idx];
}
Rep& rep(){ //返回数组现在所表示的对象
return expr_rep;
}
Rep const& rep() const{
return expr_rep;
}
};
//实际运算符实现:两个数组相加
template<typename T,typename R1,typename R2>
Array<T,A_Add<T,R1,R2> >
operator+(Array<T,R1> const& a,Array<T,R2> const& b){
return Array<T,A_Add<T,R1,R2> >(A_Add<T,R1,R2>(a.rep(),b.rep()));
}
//两个数组相乘
template<typename T,typename R1,typename R2>
Array<T,A_Mult<T,R1,R2> >
operator*(Array<T,R1> const& a,Array<T,R2> const& b){
return Array<T,A_Mult<T,R1,R2> >(A_Mult<T,R1,R2>(a.rep(),b.rep()));
}
//scalar和数组相乘
template<typename T,typename R2>
Array<T,A_Mult<T,A_Scalar<T>,R2> >
operator*(T const& s,Array<T,R2> const& b){
return Array<T,A_Mult<T,A_Scalar<T>,R2> >(A_Mult<T,A_Scalar<T>,R2>(A_Scalar<T>(s),b.rep()));
}
//数组和scalar相乘
//scalar和数组相乘
//数组和scalar相加
#endif


//exprtest.cpp:测试代码
#include "exprarray.hpp"
int main(){
Array<double> x(1000),y(1000); //创建两个数组
x=1.2*x+x*y; //对数组进行计算
return 0;
}


这里我们对OP1及OP2封装了一个trait类,如果它是数组类型Array<T>,则使用const引用来创建成员op1(或op2),如果它是放大倍数类型A_Scalar<T>,则直接使用传值的方式来创建成员,这样实现会更高效一点,因为数组类型的引用无需再调用昂贵的拷贝构造函数。表达式模板对给定下标处的数组元素进行这个表达式所表示的实际计算,并通过下标运算符返回计算出来的结果。可见现在表达式(如加A_Add、乘A_Mult,以及把它们结合起来的复杂表达式)的计算最终被映射成对数组元素的实际计算,就像前面说的最终映射成x[i]=1.2*x[i]+x[i]*y[i]。这是关键,在前面的SArray中,每个表达式只是被映射到了数组对象上,其计算调用的是重载运算符,要对整个数组扫描一次。
数组类型Array<T>被设计成既可以使用高层的表达模板,也可以直接使用底层的普通SArray<T>。Rep对象expr_rep持有底层的实际数组数据,rep()返回这个对象,它要么是一个高层的表达式对象,要么是一个底层的SArray数组对象(持有实际的数组数据)。它的赋值运算符、下标运算符可以对这两种情况进行统一的处理。这主要是因为表达式模板的下标运算符和SArray<T>的下标运算符映射到的都是数组元素。现在数组类型的重载运算符就要用表达式模板来实现了。
下面结合测试例子来理解这些运算符。比如x*y,调用operator*的第一个版本,两个操作数x和y的类型都是Array<double,SArray<double> >,R1和R2都使用默认的SArray<double>。调用rep()返回x和y中各自的SArray<double>对象expr_rep,用这两个对象创建一个临时的A_Mult<double,SArray<double>,SArray<double> >类型的表达式对象,然后用这个表达式对象作为Rep对象创建一个要返回的数组类型对象,返回的数组类型为
Array<double,A_Mult<double,SArray<double>,SArray<double> > >
对1.2*x(调用opeator*的第二个版本)类似分析,返回一个Array<double,R3>型的数组对象,这里R3是一个A_Mult的表达式对象。然后两个表达式相加,返回一个Array<double,R4>型的数组对象,这里R4是一个A_Add的表达式对象。最后,调用那个针对不同类型的赋值运算符,把它赋值给x对象。
可见,与原来SArray中的重载运算符相比,现在Array的重载运算符没有进行任何的计算,它只是根据传进来的操作数(数组对象)构造一个与表达式模板相关的新的数组对象并返回,所有的计算都委托给了表达式模板。现在整个表达式x=1.2*x+x*y没有进行任何运算,每一步的运算都只是返回一个新的数组对象,最后是一个数组对象的赋值操作。当调用下标操作x[3]来访问元素时,就会使用表达式模板来进行实际的计算x[3]=1.2*x[3]+x[3]*y[3](跟踪一下下标运算符操作就知道了),以获得我们所需要的值,这是能实现高性能的关键所在。
总结出表达式模板技术的基本思想: 当类需要重载运算符以进行对象的直接运算时,为每个运算符提供一个表达式模板,并把类设计成与表达式模板相关(即把类设计成模板,提供一个模板参数来传递表达式)。重载运算符的实现中并不做实际的运算,而只是根据传进来的操作数构造一个该类型的新对象并返回即可。所有的运算都委托给对应的表达式模板来完成。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: