@@ -2868,26 +2868,20 @@ def dpnp_solve(a, b):
28682868 a_usm_arr = dpnp .get_usm_ndarray (a )
28692869 b_usm_arr = dpnp .get_usm_ndarray (b )
28702870
2871- # Due to MKLD-17226 (bug with incorrect checking ldb parameter
2872- # in oneapi::mkl::lapack::gesv_scratchad_size that raises an error
2873- # `invalid argument` when nrhs > n) we can not use _gesv directly.
2874- # This w/a uses _getrf and _getrs instead
2875- # to handle cases where nrhs > n for a.shape = (n x n)
2876- # and b.shape = (n x nrhs).
2877-
2878- # oneMKL LAPACK getrf overwrites `a`.
2879- a_h = dpnp .empty_like (a , order = "C" , dtype = res_type , usm_type = res_usm_type )
2871+ # oneMKL LAPACK getrs overwrites `a` and assumes fortran-like array as
2872+ # input
2873+ a_h = dpnp .empty_like (a , order = "F" , dtype = res_type , usm_type = res_usm_type )
28802874
28812875 _manager = dpu .SequentialOrderManager [exec_q ]
2882- dev_evs = _manager .submitted_events
2876+ dep_evs = _manager .submitted_events
28832877
28842878 # use DPCTL tensor function to fill the сopy of the input array
28852879 # from the input array
28862880 ht_ev , a_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
28872881 src = a_usm_arr ,
28882882 dst = a_h .get_array (),
28892883 sycl_queue = a .sycl_queue ,
2890- depends = dev_evs ,
2884+ depends = dep_evs ,
28912885 )
28922886 _manager .add_event_pair (ht_ev , a_copy_ev )
28932887
@@ -2903,43 +2897,18 @@ def dpnp_solve(a, b):
29032897 src = b_usm_arr ,
29042898 dst = b_h .get_array (),
29052899 sycl_queue = b .sycl_queue ,
2906- depends = dev_evs ,
2900+ depends = dep_evs ,
29072901 )
29082902 _manager .add_event_pair (ht_ev , b_copy_ev )
29092903
2910- n = a .shape [0 ]
2911-
2912- ipiv_h = dpnp .empty_like (
2913- a ,
2914- shape = (n ,),
2915- dtype = dpnp .int64 ,
2904+ # Call the LAPACK extension function _gesv to solve the system of linear
2905+ # equations with the coefficient square matrix and
2906+ # the dependent variables array
2907+ ht_lapack_ev , gesv_ev = li ._gesv (
2908+ exec_q , a_h .get_array (), b_h .get_array (), [a_copy_ev , b_copy_ev ]
29162909 )
2917- dev_info_h = [0 ]
29182910
2919- # Call the LAPACK extension function _getrf
2920- # to perform LU decomposition of the input matrix
2921- ht_ev , getrf_ev = li ._getrf (
2922- exec_q ,
2923- a_h .get_array (),
2924- ipiv_h .get_array (),
2925- dev_info_h ,
2926- depends = [a_copy_ev ],
2927- )
2928- _manager .add_event_pair (ht_ev , getrf_ev )
2929-
2930- _check_lapack_dev_info (dev_info_h )
2931-
2932- # Call the LAPACK extension function _getrs
2933- # to solve the system of linear equations with an LU-factored
2934- # coefficient square matrix, with multiple right-hand sides.
2935- ht_ev , getrs_ev = li ._getrs (
2936- exec_q ,
2937- a_h .get_array (),
2938- ipiv_h .get_array (),
2939- b_h .get_array (),
2940- depends = [b_copy_ev , getrf_ev ],
2941- )
2942- _manager .add_event_pair (ht_ev , getrs_ev )
2911+ _manager .add_event_pair (ht_lapack_ev , gesv_ev )
29432912 return b_h
29442913
29452914
0 commit comments