Skip to content

Commit 1b26f0e

Browse files
committed
Refactor in-place division to use TypePairDefinedEntry
This makes the code easier to understand
1 parent 2588eb5 commit 1b26f0e

File tree

1 file changed

+34
-48
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+34
-48
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp

Lines changed: 34 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -439,52 +439,45 @@ template <typename argT, typename resT> struct TrueDivideInplaceFunctor
439439
}
440440
};
441441

442-
// cannot use the out of place table, as it permits real lhs and complex rhs
443-
// T1 corresponds to the type of the rhs, while T2 corresponds to the lhs
444-
// the type of the result must be the same as T2
445-
template <typename T1, typename T2> struct TrueDivideInplaceOutputType
442+
/* @brief Types supported by in-place divide */
443+
template <typename argTy, typename resTy>
444+
struct TrueDivideInplaceTypePairSupport
446445
{
447-
using value_type = typename std::disjunction< // disjunction is C++17
448-
// feature, supported by DPC++
449-
td_ns::BinaryTypeMapResultEntry<T1,
450-
sycl::half,
451-
T2,
452-
sycl::half,
453-
sycl::half>,
454-
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
455-
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
456-
td_ns::BinaryTypeMapResultEntry<T1,
457-
std::complex<float>,
458-
T2,
459-
std::complex<float>,
460-
std::complex<float>>,
461-
td_ns::BinaryTypeMapResultEntry<T1,
462-
float,
463-
T2,
464-
std::complex<float>,
465-
std::complex<float>>,
466-
td_ns::BinaryTypeMapResultEntry<T1,
467-
std::complex<double>,
468-
T2,
469-
std::complex<double>,
470-
std::complex<double>>,
471-
td_ns::BinaryTypeMapResultEntry<T1,
472-
double,
473-
T2,
474-
std::complex<double>,
475-
std::complex<double>>,
476-
td_ns::DefaultResultEntry<void>>::result_type;
446+
447+
/* value if true a kernel for <argTy, resTy> must be instantiated */
448+
static constexpr bool is_defined = std::disjunction< // disjunction is C++17
449+
// feature, supported
450+
// by DPC++ input bool
451+
td_ns::TypePairDefinedEntry<argTy, sycl::half, resTy, sycl::half>,
452+
td_ns::TypePairDefinedEntry<argTy, float, resTy, float>,
453+
td_ns::TypePairDefinedEntry<argTy, double, resTy, double>,
454+
td_ns::TypePairDefinedEntry<argTy, float, resTy, std::complex<float>>,
455+
td_ns::TypePairDefinedEntry<argTy,
456+
std::complex<float>,
457+
resTy,
458+
std::complex<float>>,
459+
td_ns::TypePairDefinedEntry<argTy, double, resTy, std::complex<double>>,
460+
td_ns::TypePairDefinedEntry<argTy,
461+
std::complex<double>,
462+
resTy,
463+
std::complex<double>>,
464+
// fall-through
465+
td_ns::NotDefinedEntry>::is_defined;
477466
};
478467

479-
template <typename fnT, typename T1, typename T2>
468+
template <typename fnT, typename argT, typename resT>
480469
struct TrueDivideInplaceTypeMapFactory
481470
{
482471
/*! @brief get typeid for output type of divide(T1 x, T2 y) */
483472
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
484473
{
485-
using rT = typename TrueDivideInplaceOutputType<T1, T2>::value_type;
486-
static_assert(std::is_same_v<rT, T2> || std::is_same_v<rT, void>);
487-
return td_ns::GetTypeid<rT>{}.get();
474+
if constexpr (TrueDivideInplaceTypePairSupport<argT, resT>::is_defined)
475+
{
476+
return td_ns::GetTypeid<resT>{}.get();
477+
}
478+
else {
479+
return td_ns::GetTypeid<void>{}.get();
480+
}
488481
}
489482
};
490483

@@ -537,10 +530,7 @@ struct TrueDivideInplaceContigFactory
537530
{
538531
fnT get()
539532
{
540-
if constexpr (std::is_same_v<typename TrueDivideInplaceOutputType<
541-
T1, T2>::value_type,
542-
void>)
543-
{
533+
if constexpr (!TrueDivideInplaceTypePairSupport<T1, T2>::is_defined) {
544534
fnT fn = nullptr;
545535
return fn;
546536
}
@@ -579,10 +569,7 @@ struct TrueDivideInplaceStridedFactory
579569
{
580570
fnT get()
581571
{
582-
if constexpr (std::is_same_v<typename TrueDivideInplaceOutputType<
583-
T1, T2>::value_type,
584-
void>)
585-
{
572+
if constexpr (!TrueDivideInplaceTypePairSupport<T1, T2>::is_defined) {
586573
fnT fn = nullptr;
587574
return fn;
588575
}
@@ -627,8 +614,7 @@ struct TrueDivideInplaceRowMatrixBroadcastFactory
627614
{
628615
fnT get()
629616
{
630-
using resT = typename TrueDivideInplaceOutputType<T1, T2>::value_type;
631-
if constexpr (!std::is_same_v<resT, T2>) {
617+
if constexpr (!TrueDivideInplaceTypePairSupport<T1, T2>::is_defined) {
632618
fnT fn = nullptr;
633619
return fn;
634620
}

0 commit comments

Comments
 (0)