7#ifndef HEFFTE_STOCK_VEC_TYPES_H
8#define HEFFTE_STOCK_VEC_TYPES_H
10#include "heffte_config.h"
12#ifdef Heffte_ENABLE_AVX
21using is_float = std::is_same<float, typename std::remove_cv<T>::type>;
24using is_double = std::is_same<double, typename std::remove_cv<T>::type>;
27using is_fcomplex = std::is_same<std::complex<float>,
typename std::remove_cv<T>::type>;
30using is_dcomplex = std::is_same<std::complex<double>,
typename std::remove_cv<T>::type>;
34 static constexpr bool value = is_float<T>::value || is_double<T>::value;
39 static constexpr bool value = is_fcomplex<T>::value || is_dcomplex<T>::value;
/*!
 * \brief Maps a scalar type and a count to the storage type used by the mm_* helpers in this header.
 *
 * The primary template is intentionally empty: it has no nested \c type, so any
 * (T, N) combination without an explicit specialization cannot be used as a pack.
 *
 * \tparam T scalar type (float or double in the specializations in this header)
 * \tparam N count selecting the specialization
 */
template<typename T, int N>
struct pack {};

/*!
 * \brief Non-vectorized specialization for float: the pack is a single std::complex<float>.
 */
template<>
struct pack<float, 1> { using type = std::complex<float>; };

/*!
 * \brief Non-vectorized specialization for double: the pack is a single std::complex<double>.
 */
template<>
struct pack<double, 1> { using type = std::complex<double>; };
55template<
typename F,
int L>
61template<
typename F,
int L>
62inline typename pack<F,L>::type mm_load(F
const *src) {
return typename pack<F,L>::type {src[0], src[1]}; }
67template<
typename F,
int L>
69 dest[0] = src.real(); dest[1] = src.imag();
75template<
typename F,
int L>
81template<
typename F,
int L>
87template<
typename F,
int L>
88inline typename pack<F,L>::type mm_complex_load(std::complex<F>
const *src) {
return *src; }
93template<
typename F,
int L>
94inline typename pack<F,L>::type mm_complex_load(std::complex<F>
const *src,
int) {
return *src; }
159#ifdef Heffte_ENABLE_AVX
162template<>
struct pack<double, 2> {
using type = __m128d; };
164template<>
struct pack<float, 4> {
using type = __m128; };
166template<>
struct pack<double, 4> {
using type = __m256d; };
168template<>
struct pack<float, 8> {
using type = __m256; };
180inline typename pack<float, 4>::type mm_load<float, 4>(
float const *src) {
return _mm_loadu_ps(src); }
184inline void mm_store<float, 4>(
float *dest,
pack<float, 4>::type const &src) { _mm_storeu_ps(dest, src); }
188inline typename pack<float, 4>::type mm_pair_set<float, 4>(
float x,
float y) {
return _mm_setr_ps(x, y, x, y); }
196inline typename pack<float, 4>::type mm_complex_load<float,4>(std::complex<float>
const *src,
int stride) {
197 return _mm_setr_ps(src[0].real(), src[0].imag(), src[stride].real(), src[stride].imag());
202 return mm_complex_load<float,4>(src, 1);
215inline typename pack<float, 8>::type mm_load<float, 8>(
float const *src) {
return _mm256_loadu_ps(src); }
219inline void mm_store<float, 8>(
float *dest,
pack<float, 8>::type const &src) { _mm256_storeu_ps(dest, src); }
223inline typename pack<float, 8>::type mm_pair_set<float, 8>(
float x,
float y) {
return _mm256_setr_ps(x, y, x, y, x, y, x, y); }
231inline typename pack<float, 8>::type mm_complex_load<float, 8>(std::complex<float>
const *src,
int stride) {
232 return _mm256_setr_ps(src[0*stride].real(), src[0*stride].imag(),
233 src[1*stride].real(), src[1*stride].imag(),
234 src[2*stride].real(), src[2*stride].imag(),
235 src[3*stride].real(), src[3*stride].imag());
240 return mm_complex_load<float,8>(src, 1);
253inline typename pack<double, 2>::type mm_load<double, 2>(
double const *src) {
return _mm_loadu_pd(src); }
257inline void mm_store<double, 2>(
double *dest,
pack<double, 2>::type const &src) { _mm_storeu_pd(dest, src); }
261inline typename pack<double, 2>::type mm_pair_set<double, 2>(
double x,
double y) {
return _mm_setr_pd(x, y); }
269inline typename pack<double,2>::type mm_complex_load<double,2>(std::complex<double>
const *src,
int) {
270 return _mm_setr_pd(src[0].real(), src[0].imag());
273inline typename pack<double,2>::type mm_complex_load<double,2>(std::complex<double>
const *src) {
274 return mm_complex_load<double,2>(src, 1);
287inline typename pack<double, 4>::type mm_load<double, 4>(
double const *src) {
return _mm256_loadu_pd(src); }
291inline void mm_store<double, 4>(
double *dest,
pack<double, 4>::type const &src) { _mm256_storeu_pd(dest, src); }
295inline typename pack<double, 4>::type mm_pair_set<double, 4>(
double x,
double y) {
return _mm256_setr_pd(x, y, x, y); }
303inline typename pack<double,4>::type mm_complex_load<double,4>(std::complex<double>
const *src,
int stride) {
304 return _mm256_setr_pd(src[0].real(), src[0].imag(), src[stride].real(), src[stride].imag());
308inline typename pack<double,4>::type mm_complex_load<double,4>(std::complex<double>
const *src) {
309 return mm_complex_load<double,4>(src, 1);
320 return _mm_add_ps(x, y);
325 return _mm256_add_ps(x, y);
330 return _mm_add_pd(x, y);
335 return _mm256_add_pd(x, y);
342 return _mm_sub_ps(x, y);
347 return _mm256_sub_ps(x, y);
352 return _mm_sub_pd(x, y);
357 return _mm256_sub_pd(x, y);
364 return _mm_mul_ps(x, y);
369 return _mm256_mul_ps(x, y);
374 return _mm_mul_pd(x, y);
379 return _mm256_mul_pd(x, y);
386 return _mm_div_ps(x, y);
391 return _mm256_div_ps(x, y);
396 return _mm_div_pd(x, y);
401 return _mm256_div_pd(x, y);
408 return _mm_xor_ps(x, (mm_set1<float, 4>(-0.f)));
413 return _mm256_xor_ps(x, (mm_set1<float, 8>(-0.f)));
418 return _mm_xor_pd(x, (mm_set1<double, 2>(-0.)));
423 return _mm256_xor_pd(x, (mm_set1<double, 4>(-0.)));
558 return _mm_or_ps(_mm_dp_ps(x, x, 0b11001100), _mm_dp_ps(x, x, 0b00110011));
563 return _mm256_or_ps(_mm256_dp_ps(x, x, 0b11001100), _mm256_dp_ps(x, x, 0b00110011));
568 return _mm_dp_pd(x, x, 0b11111111);
574 return _mm256_hadd_pd(a, a);
581 return _mm_sqrt_ps(mm_complex_sq_mod(x));
586 return _mm256_sqrt_ps(mm_complex_sq_mod(x));
591 return _mm_sqrt_pd(mm_complex_sq_mod(x));
596 return _mm256_sqrt_pd(mm_complex_sq_mod(x));
601 return _mm_blend_ps(x, (mm_neg(x)), 0b1010);
606 return _mm256_blend_ps(x, (mm_neg(x)), 0b10101010);
611 return _mm_blend_pd(x, (mm_neg(x)), 0b10);
616 return _mm256_blend_pd(x, (mm_neg(x)), 0b1010);
622 return _mm_permute_ps( (mm_complex_conj(x)), 0b10110001);
627 return _mm256_permute_ps( (mm_complex_conj(x)), 0b10110001);
632 return _mm_permute_pd( (mm_complex_conj(x)), 0b00000001);
637 return _mm256_permute_pd( (mm_complex_conj(x)), 0b00000101);
642 return mm_complex_conj(_mm_permute_ps(x, 0b10110001));
647 return mm_complex_conj(_mm256_permute_ps(x, 0b10110001));
652 return mm_complex_conj(_mm_permute_pd(x, 0b0000001));
657 return mm_complex_conj(_mm256_permute_pd(x, 0b00000101));
664 return _mm_div_ps(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
669 return _mm256_div_ps(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
674 return _mm_div_pd(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
679 return _mm256_div_pd(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
683#ifdef Heffte_ENABLE_AVX512
686template<>
struct pack<double, 8> {
using type = __m512d; };
688template<>
struct pack<float, 16> {
using type = __m512; };
700inline typename pack<float, 16>::type mm_load<float, 16>(
float const *src) {
return _mm512_loadu_ps(src); }
704inline void mm_store<float, 16>(
float *dest,
pack<float, 16>::type const &src) { _mm512_storeu_ps(dest, src); }
708inline typename pack<float, 16>::type mm_pair_set<float, 16>(
float x,
float y) {
return _mm512_setr_ps(x, y, x, y, x, y, x, y, x, y, x, y, x, y, x, y); }
716inline typename pack<float, 16>::type mm_complex_load<float, 16>(std::complex<float>
const *src,
int stride) {
717 return _mm512_setr_ps(src[0*stride].real(), src[0*stride].imag(), src[1*stride].real(), src[1*stride].imag(),
718 src[2*stride].real(), src[2*stride].imag(), src[3*stride].real(), src[3*stride].imag(),
719 src[4*stride].real(), src[4*stride].imag(), src[5*stride].real(), src[5*stride].imag(),
720 src[6*stride].real(), src[6*stride].imag(), src[7*stride].real(), src[7*stride].imag());
726 return mm_complex_load<float, 16>(src, 1);
739inline typename pack<double, 8>::type mm_load<double, 8>(
double const *src) {
return _mm512_loadu_pd(src); }
743inline void mm_store<double, 8>(
double *dest,
pack<double, 8>::type const &src) { _mm512_storeu_pd(dest, src); }
747inline typename pack<double, 8>::type mm_pair_set<double, 8>(
double x,
double y) {
return _mm512_setr_pd(x, y, x, y, x, y, x, y); }
755inline typename pack<double, 8>::type mm_complex_load<double, 8>(std::complex<double>
const *src,
int stride) {
756 return _mm512_setr_pd(src[0*stride].real(), src[0*stride].imag(), src[1*stride].real(), src[1*stride].imag(),
757 src[2*stride].real(), src[2*stride].imag(), src[3*stride].real(), src[3*stride].imag());
762 return mm_complex_load<double, 8>(src, 1);
773 return _mm512_add_ps(x, y);
778 return _mm512_add_pd(x, y);
785 return _mm512_sub_ps(x, y);
790 return _mm512_sub_pd(x, y);
797 return _mm512_mul_ps(x, y);
802 return _mm512_mul_pd(x, y);
809 return _mm512_div_ps(x, y);
814 return _mm512_div_pd(x, y);
820 return _mm512_xor_ps(x, (mm_set1<float, 16>(-0.f)));
825 return _mm512_xor_pd(x, (mm_set1<double, 8>(-0.f)));
918 return _mm512_sqrt_ps(mm_complex_sq_mod(x));
923 return _mm512_sqrt_pd(mm_complex_sq_mod(x));
930 return _mm512_mask_blend_ps(0b1010101010101010, x, mm_neg(x));
935 return _mm512_mask_blend_pd(0b10101010, x, mm_neg(x));
941 return _mm512_permute_ps( (mm_complex_conj(x)), 0b10110001);
946 return _mm512_permute_pd( (mm_complex_conj(x)), 0b01010101);
951 return mm_complex_conj(_mm512_permute_ps(x, 0b10110001));
956 return mm_complex_conj(_mm512_permute_pd(x, 0b01010101));
963 return _mm512_div_ps(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
968 return _mm512_div_pd(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
Namespace containing all HeFFTe methods and classes.
Definition heffte_backend_cuda.h:38
Struct determining whether a type is a complex number (std::complex<float> or std::complex<double>).
Definition heffte_stock_vec_types.h:38
Struct determining whether a type is a real floating-point number (float or double).
Definition heffte_stock_vec_types.h:33
Struct to retrieve the vector type associated with the number of elements stored "per unit".
Definition heffte_stock_vec_types.h:43