@@ -439,52 +439,45 @@ template <typename argT, typename resT> struct TrueDivideInplaceFunctor
439439 }
440440};
441441
442- // cannot use the out of place table, as it permits real lhs and complex rhs
443- // T1 corresponds to the type of the rhs, while T2 corresponds to the lhs
444- // the type of the result must be the same as T2
445- template <typename T1, typename T2> struct TrueDivideInplaceOutputType
442+ /* @brief Types supported by in-place divide */
443+ template <typename argTy, typename resTy>
444+ struct TrueDivideInplaceTypePairSupport
446445{
447- using value_type = typename std::disjunction< // disjunction is C++17
448- // feature, supported by DPC++
449- td_ns::BinaryTypeMapResultEntry<T1,
450- sycl::half,
451- T2,
452- sycl::half,
453- sycl::half>,
454- td_ns::BinaryTypeMapResultEntry<T1, float , T2, float , float >,
455- td_ns::BinaryTypeMapResultEntry<T1, double , T2, double , double >,
456- td_ns::BinaryTypeMapResultEntry<T1,
457- std::complex <float >,
458- T2,
459- std::complex <float >,
460- std::complex <float >>,
461- td_ns::BinaryTypeMapResultEntry<T1,
462- float ,
463- T2,
464- std::complex <float >,
465- std::complex <float >>,
466- td_ns::BinaryTypeMapResultEntry<T1,
467- std::complex <double >,
468- T2,
469- std::complex <double >,
470- std::complex <double >>,
471- td_ns::BinaryTypeMapResultEntry<T1,
472- double ,
473- T2,
474- std::complex <double >,
475- std::complex <double >>,
476- td_ns::DefaultResultEntry<void >>::result_type;
446+
447+ /* value if true a kernel for <argTy, resTy> must be instantiated */
448+ static constexpr bool is_defined = std::disjunction< // disjunction is C++17
449+ // feature, supported
450+ // by DPC++ input bool
451+ td_ns::TypePairDefinedEntry<argTy, sycl::half, resTy, sycl::half>,
452+ td_ns::TypePairDefinedEntry<argTy, float , resTy, float >,
453+ td_ns::TypePairDefinedEntry<argTy, double , resTy, double >,
454+ td_ns::TypePairDefinedEntry<argTy, float , resTy, std::complex <float >>,
455+ td_ns::TypePairDefinedEntry<argTy,
456+ std::complex <float >,
457+ resTy,
458+ std::complex <float >>,
459+ td_ns::TypePairDefinedEntry<argTy, double , resTy, std::complex <double >>,
460+ td_ns::TypePairDefinedEntry<argTy,
461+ std::complex <double >,
462+ resTy,
463+ std::complex <double >>,
464+ // fall-through
465+ td_ns::NotDefinedEntry>::is_defined;
477466};
478467
479- template <typename fnT, typename T1 , typename T2 >
468+ template <typename fnT, typename argT , typename resT >
480469struct TrueDivideInplaceTypeMapFactory
481470{
482471 /* ! @brief get typeid for output type of divide(T1 x, T2 y) */
483472 std::enable_if_t <std::is_same<fnT, int >::value, int > get ()
484473 {
485- using rT = typename TrueDivideInplaceOutputType<T1, T2>::value_type;
486- static_assert (std::is_same_v<rT, T2> || std::is_same_v<rT, void >);
487- return td_ns::GetTypeid<rT>{}.get ();
474+ if constexpr (TrueDivideInplaceTypePairSupport<argT, resT>::is_defined)
475+ {
476+ return td_ns::GetTypeid<resT>{}.get ();
477+ }
478+ else {
479+ return td_ns::GetTypeid<void >{}.get ();
480+ }
488481 }
489482};
490483
@@ -537,10 +530,7 @@ struct TrueDivideInplaceContigFactory
537530{
538531 fnT get ()
539532 {
540- if constexpr (std::is_same_v<typename TrueDivideInplaceOutputType<
541- T1, T2>::value_type,
542- void >)
543- {
533+ if constexpr (!TrueDivideInplaceTypePairSupport<T1, T2>::is_defined) {
544534 fnT fn = nullptr ;
545535 return fn;
546536 }
@@ -579,10 +569,7 @@ struct TrueDivideInplaceStridedFactory
579569{
580570 fnT get ()
581571 {
582- if constexpr (std::is_same_v<typename TrueDivideInplaceOutputType<
583- T1, T2>::value_type,
584- void >)
585- {
572+ if constexpr (!TrueDivideInplaceTypePairSupport<T1, T2>::is_defined) {
586573 fnT fn = nullptr ;
587574 return fn;
588575 }
@@ -627,8 +614,7 @@ struct TrueDivideInplaceRowMatrixBroadcastFactory
627614{
628615 fnT get ()
629616 {
630- using resT = typename TrueDivideInplaceOutputType<T1, T2>::value_type;
631- if constexpr (!std::is_same_v<resT, T2>) {
617+ if constexpr (!TrueDivideInplaceTypePairSupport<T1, T2>::is_defined) {
632618 fnT fn = nullptr ;
633619 return fn;
634620 }
0 commit comments