Commit 23263842 authored by Matthias Werner

simplified constants (use cufftType directly with constexpr).

parent 860798d4
......@@ -13,15 +13,14 @@
namespace gearshifft {
namespace CuFFT {
namespace traits{
// @todo simplify constants
template<typename T_Precision=float>
struct Types
{
using ComplexType = cufftComplex;
using RealType = cufftReal;
struct FFTForward: std::integral_constant< cufftType, CUFFT_R2C >{};
struct FFTComplex: std::integral_constant< cufftType, CUFFT_C2C >{};
struct FFTBackward: std::integral_constant< cufftType, CUFFT_C2R >{};
static constexpr cufftType FFTForward = CUFFT_R2C;
static constexpr cufftType FFTComplex = CUFFT_C2C;
static constexpr cufftType FFTBackward = CUFFT_C2R;
struct FFTExecuteForward{
void operator()(cufftHandle plan, RealType* in, ComplexType* out){
......@@ -46,9 +45,9 @@ namespace CuFFT {
{
using ComplexType = cufftDoubleComplex;
using RealType = cufftDoubleReal;
struct FFTForward: std::integral_constant< cufftType, CUFFT_D2Z >{};
struct FFTComplex: std::integral_constant< cufftType, CUFFT_Z2Z >{};
struct FFTBackward: std::integral_constant< cufftType, CUFFT_Z2D >{};
static constexpr cufftType FFTForward = CUFFT_D2Z;
static constexpr cufftType FFTComplex = CUFFT_Z2Z;
static constexpr cufftType FFTBackward = CUFFT_Z2D;
struct FFTExecuteForward{
void operator()(cufftHandle plan, RealType* in, ComplexType* out){
......@@ -74,38 +73,35 @@ namespace CuFFT {
* Estimates memory reserved by cufft plan depending on FFT transform type
* (CUFFT_R2C, ...) and depending on number of dimensions {1,2,3}.
*/
// Queries cuFFT for the size it reports for `plan` with extents `e` and returns it.
// The query function is chosen by NDim (1, 2 or 3); for any other NDim no query
// runs and the initial value 0 is returned.
// NOTE(review): this text is a rendered diff, so removed and added lines appear
// side by side. Lines using `typename FFTType` / `FFTType::value` look like the
// pre-commit form (FFTType as a type wrapping the constant); lines using
// `cufftType FFTType` / plain `FFTType` look like the post-commit form. Only one
// line of each pair belongs in compilable code — confirm against the repository.
template<typename FFTType, size_t NDim>
template<cufftType FFTType, size_t NDim>
size_t estimateAllocSize(const std::array<unsigned,NDim>& e, cufftHandle& plan)
{
size_t s=0; // receives the size through the &s out-argument below
if(NDim==1){
// CHECK_CUFFT( cufftEstimate1d(e[0], FFTType::value, 1, &s) );
// the literal 1 is the batch argument of cufftGetSize1d (per the cuFFT API)
CHECK_CUFFT( cufftGetSize1d(plan, e[0], FFTType::value, 1, &s) ); // old form
CHECK_CUFFT( cufftGetSize1d(plan, e[0], FFTType, 1, &s) ); // new form
}
if(NDim==2){
// CHECK_CUFFT( cufftEstimate2d(e[0], e[1], FFTType::value, &s) );
CHECK_CUFFT( cufftGetSize2d(plan, e[0], e[1], FFTType::value, &s) ); // old form
CHECK_CUFFT( cufftGetSize2d(plan, e[0], e[1], FFTType, &s) ); // new form
}
if(NDim==3){
// CHECK_CUFFT( cufftEstimate3d(e[0], e[1], e[2], FFTType::value, &s) );
CHECK_CUFFT( cufftGetSize3d(plan, e[0], e[1], e[2], FFTType::value, &s) ); // old form
CHECK_CUFFT( cufftGetSize3d(plan, e[0], e[1], e[2], FFTType, &s) ); // new form
}
return s;
}
/**
 * Plan Creator depending on FFT transform type (CUFFT_R2C, ...).
 * This overload takes a 3-element extent and creates the plan with cufftPlan3d,
 * writing the handle into `plan`.
 * NOTE(review): rendered diff text — the paired lines below appear to be the
 * removed (FFTType as a type, `::value`) and added (FFTType as a cufftType value)
 * versions; only one of each pair belongs in compilable code.
 */
template<typename FFTType> // old form
template<cufftType FFTType> // new form
void makePlan(cufftHandle& plan, const std::array<unsigned,3>& e){
CHECK_CUFFT(cufftPlan3d(&plan, e[0], e[1], e[2], FFTType::value)); // old form
CHECK_CUFFT(cufftPlan3d(&plan, e[0], e[1], e[2], FFTType)); // new form
}
// 1D overload: takes a 1-element extent and creates the plan with cufftPlan1d,
// writing the handle into `plan`. The trailing literal 1 is the batch argument
// of cufftPlan1d (per the cuFFT API).
// NOTE(review): rendered diff text — the paired lines appear to be the removed
// (`typename FFTType`, `FFTType::value`) and added (`cufftType FFTType`, plain
// `FFTType`) versions; only one of each pair belongs in compilable code.
template<typename FFTType> // old form
template<cufftType FFTType> // new form
void makePlan(cufftHandle& plan, const std::array<unsigned,1>& e){
CHECK_CUFFT(cufftPlan1d(&plan, e[0], FFTType::value, 1)); // old form
CHECK_CUFFT(cufftPlan1d(&plan, e[0], FFTType, 1)); // new form
}
// 2D overload: takes a 2-element extent and creates the plan with cufftPlan2d,
// writing the handle into `plan`.
// NOTE(review): rendered diff text — the paired lines appear to be the removed
// (`typename FFTType`, `FFTType::value`) and added (`cufftType FFTType`, plain
// `FFTType`) versions; only one of each pair belongs in compilable code.
template<typename FFTType> // old form
template<cufftType FFTType> // new form
void makePlan(cufftHandle& plan, const std::array<unsigned,2>& e){
CHECK_CUFFT(cufftPlan2d(&plan, e[0], e[1], FFTType::value)); // old form
CHECK_CUFFT(cufftPlan2d(&plan, e[0], e[1], FFTType)); // new form
}
......@@ -121,20 +117,26 @@ namespace CuFFT {
>
struct CuFFTImpl
{
using Types = typename traits::Types<TPrecision>;
using Extent = std::array<unsigned,NDim>;
using Types = typename traits::Types<TPrecision>;
using ComplexType = typename Types::ComplexType;
using RealType = typename Types::RealType;
using Extent = std::array<unsigned,NDim>;
static constexpr auto IsInplace = TFFT::IsInplace;
static constexpr auto IsComplex = TFFT::IsComplex;
// @todo NDim>1 remove
static constexpr auto Padding = IsInplace && IsComplex==false && NDim>1;
using RealOrComplexType = typename std::conditional<IsComplex,ComplexType,RealType>::type;
using FFTForward = typename std::conditional_t<IsComplex, typename Types::FFTComplex, typename Types::FFTForward>;
using FFTBackward = typename std::conditional_t<IsComplex, typename Types::FFTComplex, typename Types::FFTBackward>;
using FFTExecuteForward = typename Types::FFTExecuteForward;
using FFTExecuteBackward = typename Types::FFTExecuteBackward;
static constexpr
bool IsInplace = TFFT::IsInplace;
static constexpr
bool IsComplex = TFFT::IsComplex;
static constexpr
bool Padding = IsInplace && IsComplex==false && NDim>1;
static constexpr
cufftType FFTForward = IsComplex ? Types::FFTComplex : Types::FFTForward;
static constexpr
cufftType FFTBackward = IsComplex ? Types::FFTComplex : Types::FFTBackward;
using RealOrComplexType = typename std::conditional<IsComplex,ComplexType,RealType>::type;
size_t n_; // =[1]*..*[dim]
size_t n_padded_; // =[1]*..*[dim-1]*([dim]/2+1)
Extent extents_;
......@@ -145,7 +147,6 @@ namespace CuFFT {
size_t data_size_;
size_t data_transform_size_;
CuFFTImpl(const Extent& cextents)
: extents_(cextents)
{
......@@ -166,6 +167,7 @@ namespace CuFFT {
/**
 * Returns the sum of data_size_ and data_transform_size_.
 */
size_t getAllocSize() {
const size_t combined_size = data_transform_size_ + data_size_;
return combined_size;
}
/**
* Returns estimated allocated memory on device for FFT plan
*/
......@@ -208,9 +210,11 @@ namespace CuFFT {
void execute_forward() {
FFTExecuteForward()(plan_, data_, data_transform_);
}
void execute_backward() {
FFTExecuteBackward()(plan_, data_transform_, data_);
}
template<typename THostData>
void upload(THostData* input) {
if(Padding) // real + inplace + ndim>1
......@@ -223,6 +227,7 @@ namespace CuFFT {
CHECK_ERROR(cudaMemcpy(data_, input, data_size_, cudaMemcpyHostToDevice));
}
}
template<typename THostData>
void download(THostData* output) {
if(Padding) // real + inplace + ndim>1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment