此为用C++写微分方程数值解实验的前置,目标是封装个多维数组,重载四则运算,实现一些矩阵的操作。
背景
实验需要求解线性方程组,我记得我之前写过线性方程组求解(C++版),然后看了一眼我以前写的什么东西,于是决定重新写。
本来只是想用array实现一个矩阵类型就行,最近比较喜欢用array
预想实现的效果是这样mat<int,2,2>就是一个2x2的矩阵。但是我突然灵机一动为什么不把维数作为可变模板参数,实现一个多维数组,矩阵只是2维数组的一个特例而已。
然后在想名字的时候,C++的数组叫array,在想多维数组的时候突然想起python的ndarray,于是就变成了ndarray<int,2,2,2>,看起来相当的完美。
封装多维数组
名字和写法想好了,接下来就是实现了。首先我们希望ndarray<int,1,2,3>的实际类型应该是几个array的复合,它实际上应该是这样的1
array<array<array<int,3>,2>,1>
我们可以通过可变模板参数来递归的构造出来,就是每次提取一个维度作为当前这一层array的长度,然后类型就是用剩下的参数继续构造的array。
下面是代码实现:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15template<typename T,size_t...Args>
struct ndarray_helper;
template<typename T,size_t N>
struct ndarray_helper<T,N>{
using type=std::array<T,N>;
};
template<typename T,size_t N,size_t...Args>
struct ndarray_helper<T,N,Args...>{
using type=std::array<typename ndarray_helper<T,Args...>::type,N>;
};
template<typename T, size_t...Args>
using ndarray=typename ndarray_helper<T,Args...>::type;
没有完全自己实现一个全新的类型,只是用模板来实现了一种方便的array嵌套类型,所以本质上还是array的一个别名。
因为别名不能重复定义,但是需要特化模板来终止递归,所以借助了struct来保存中间类型来进行递归,最后再将ndarray作为struct保存的类型的别名。
重载四则运算
四则运算既是矩阵的基本运算,又是为后面的初等变换做准备。
因为使用的是array,所以需要重载array的四则运算,这次就用ranges库来写。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24template<typename T,size_t N>
constexpr std::array<T,N> operator+(const std::array<T,N>&a,const std::array<T,N>&b){
std::array<T,N>c{};
std::ranges::transform(a,b,c.begin(),[](auto x,auto y){return x+y;});
return c;
}
template<typename T,size_t N>
constexpr std::array<T,N> operator-(const std::array<T,N>&a,const std::array<T,N>&b){
std::array<T,N>c{};
std::ranges::transform(a,b,c.begin(),[](auto x,auto y){return x-y;});
return c;
}
template<typename T,size_t N>
constexpr std::array<T,N> operator*(const std::array<T,N>&a,const std::array<T,N>&b){
std::array<T,N>c{};
std::ranges::transform(a,b,c.begin(),[](auto x,auto y){return x*y;});
return c;
}
template<typename T,size_t N>
constexpr std::array<T,N> operator/(const std::array<T,N>&a,const std::array<T,N>&b){
std::array<T,N>c{};
std::ranges::transform(a,b,c.begin(),[](auto x,auto y){return x/y;});
return c;
}
看起来很简单,就是一行transform的事情,然后复制四个就行了。显然这个重载对于嵌套的array也是通用的。
实际效果
然后再来个标量的乘除,虽然说是标量,但是我比较懒,不想检查标量,所以就对b一直递归的计算到能算为止,如果不能算的话就等它报错吧,起码下面的例子是能够计算标量乘每个元素的。1
2
3
4
5
6
7
8
9
10
11
12template<typename T1,typename T2,size_t N>
constexpr std::array<T1,N> operator*(const T1&a,const std::array<T2,N>&b){
std::array<T1,N>c{};
std::ranges::transform(b,c.begin(),[a](auto x){return a*x;});
return c;
}
template<typename T1,typename T2,size_t N>
constexpr std::array<T1,N> operator/(const std::array<T1,N>&a,const T1&b){
std::array<T1,N>c{};
std::ranges::transform(a,c.begin(),[b](auto x){return x/b;});
return c;
}
看看效果
多维数组拼接
接下来该见证极致的烧脑了!用过python的应该都知道,concat是拼接的意思,它可以指定拼接的维度,或者说轴。它拼接的规则是,只有要拼的那个维度的长度可以不同,其他维度的长度必须相同。举个例子1
2
3
4
5
6
7
8//可以拼接,每个维度的长度都一样
concat(ndarray<2,2,2>,ndarray<2,2,2>,1)
//可以拼接,只有第一个维度(从0开始)长度不同,其他都是2
concat(ndarray<2,1,2>,ndarray<2,2,2>,1)
//不行,其他维度的长度不一样
concat(ndarray<1,2,3>,ndarray<2,2,2>,1)
拼接之后的形状就是指定维度相加,其他维度不变,举个例子,下面简写1
2concat(<1,1,1>,<1,1,1>,0)=<2,1,1>
concat(<1,2,1>,<1,1,1>,1)=<1,3,1>
在举例子之前,先来规定一下,如果用二维来举例,第一个维度叫行,从上往下,长度是几就有几行,第二个维度叫列,从左往右,长度为几就有几列。下面是形象的例子1
2
3
4
5
6
7
8
9
10
11
12
13ndarray<2,2>A
0,1
2,3
concat(A,A,0)
0,1
2,3
0,1
2,3
concat(A,A,1)
0,1,0,1
2,3,2,3
所以我们先来把拼接后的类型确定一下,虽然我们能够用一句话描述出拼接后的类型就是ndarray<类型不变,指定维度相加其他维度不变>,但是ndarray只是别名,我们创建的变量类型是array的嵌套,所以我们还是需要像一开始写ndarray的时候一样递归的处理array的类型。
先写个主模板,它的参数显然就是两个数组的类型和一个维度,然后先处理最简单的情况,也就是拼接的维度为0的情况,这个也是我们的终止条件,为什么它是终止条件有点绕,可以先自己思考一下,下面先看代码1
2
3
4
5
6
7template<typename T1,typename T2,size_t I>
struct concat_helper;
template<typename T,size_t N1,size_t N2>
struct concat_helper<std::array<T,N1>,std::array<T,N2>,0>{
using type=std::array<T,N1+N2>;
};
根据规则,拼接的那个维度的长度可以不同,所以设置了两个模板参数N1,N2来匹配不同长度,但是其他长度要相同,也就是说后面的维度的长度是相同的,类型就都为T,然后0就是当前维度,或者说当前的最外层的array。
用二维的来举例子就是两个两行两列的数组拼成了一个四行两列的数组。数组能直接控制的就是它自己当前的维度的长度,也就是它控制不了列的长度。所以如果能够修改它的后一个维度呢?
如果两个数组要拼接后面的维度,起码说明前面的维度的长度是相等的,同为N,但是后面的类型可能不一样(要拼接的维度的长度可以不一样),所以有两个类型T1和T2,然后拼接之后是的类型应该是一个确定的类型(废话,如果拼接之后不是一个类型那就不会有这篇文章了),那个类型就是下一个维度拼接之后的类型,所以就可以写出递归的类型了1
2
3
4template<typename T1,typename T2,size_t N,size_t I>
struct concat_helper<std::array<T1,N>,std::array<T2,N>,I>{
using type=std::array<typename concat_helper<T1,T2,I-1>::type,N>;
};
然后仿照前面的,我们可以得到拼接后的类型1
2template<typename T1,typename T2,size_t I>
using concat_type=typename concat_helper<T1,T2,I>::type;
然后我迫不及待的去试了一下,果然报错了。因为I=0且N相等的情况可以匹配到两种特化,所以我们在第二种特化上加上条件I!=01
2
3
4
5template<typename T1,typename T2,size_t N,size_t I>
requires(I!=0)
struct concat_helper<std::array<T1,N>,std::array<T2,N>,I>{
using type=std::array<typename concat_helper<T1,T2,I-1>::type,N>;
};
现在类型有了,只差把数据搬到我们创建的新数组里了。先来写个函数,因为拼接的维度需要是编译期常数,所以不能把它作为函数参数,应该作为模板参数(实际使用中应该也不会存在运行时决定拼接位置的情况)。然后调整一下参数顺序,T1、T2可以自动推导,我们只需要指定拼接维度就可以了。1
2
3
4
5
6template<size_t I,typename T1,typename T2>
constexpr auto concat(T1 a,T2 b){
concat_type<T1,T2,I>c;
concat_helper<T1,T2,I>::copy(a,b,c);
return c;
}
先创建了c然后再把数据从a和b复制到c,复制显然也是递归的,等会需要实现我们的工具函数copy。还是先从最简单的情况开始1
2
3
4
5
6
7
8template<typename T,size_t N1,size_t N2>
struct concat_helper<std::array<T,N1>,std::array<T,N2>,0>{
using type=std::array<T,N1+N2>;
static constexpr void copy(std::array<T,N1>&a,std::array<T,N2>&b,type&c){
std::ranges::copy(a,c.begin());
std::ranges::copy(b,c.begin()+N1);
}
};
如果I=0,直接把a和b复制到c就可以了,因为是需要修改c,所以参数的类型是引用。接下来就是递归的情况
1 | template<typename T1,typename T2,size_t N,size_t I> |
如果I!=0,说明当前维度不是需要拼接的,但是长度应该是相等的,所以将a和b对应的每一行复制到c,因为用zip打包后遍历的类型是元组,所以需要写个lambda表达式用apply把元组的每个元素作为参数来应用函数。
下面来看看实现的效果
提取指定列
这个操作其实我还没有想好怎么设计,因为使用的时候需要提取矩阵的某一列,但是如果只能提取矩阵的那就只是维度为2的特例,所以先写了个简单的来用着,不多解释了1
2
3
4
5
6
7
8
9
10
11
12
13
14
15template<size_t I,typename T1,typename T2>
constexpr void getCols_helper(const T1&A,T2&B){}
template<size_t I,size_t C,size_t...Cols,typename T1,typename T2>
constexpr void getCols_helper(const T1&A,T2&B){
std::ranges::for_each(std::views::zip(A,B),[](auto&&p){auto&[a,b]=p;b[I]=a[C];});
getCols_helper<I+1,Cols...>(A,B);
}
template<size_t...Cols,typename T,size_t N,size_t M>
constexpr auto getCols(const std::array<std::array<T,M>,N>&A){
std::array<std::array<T,sizeof...(Cols)>,N>res{};
getCols_helper<0,Cols...>(A,res);
return res;
}
实际效果:
化为行最简矩阵
此为大坑,我是一边写的代码一边写的文章,在我写完这里之后我才发现问题,这还是个数值分析的问题。我写的是最朴素的高斯消元法,就是没有考虑到什么大数除小数的那种,实际写实验的时候应该是用不了的。所以我最后还是用会matlab把我的实验补了,这个东西就等我明年重修数值分析的时候再补吧。