@@ -105,4 +105,195 @@ get_mad_kernel(sycl::queue &q, size_t n, T *in1, T *in2, T *out, scT val)
105105 return program.get_kernel <mad_kern<T, scT>>();
106106};
107107
108+ template <typename name,
109+ typename localAccessorT,
110+ class KernelFuncArgs ,
111+ class KernelFunctor >
112+ auto make_cgh_nd_function_with_local_memory (const sycl::nd_range<1 > &nd_range,
113+ size_t slm_size,
114+ KernelFuncArgs kern_params)
115+ {
116+ auto Kernel = [&](sycl::handler &cgh) {
117+ localAccessorT lm (slm_size, cgh);
118+ cgh.parallel_for <name>(nd_range, KernelFunctor (kern_params, lm));
119+ };
120+ return Kernel;
121+ };
122+
123+ template <typename name, class KernelFunctor >
124+ auto make_cgh_nd_function (const sycl::nd_range<1 > &nd_range, KernelFunctor kern)
125+ {
126+ auto Kernel = [&](sycl::handler &cgh) {
127+ cgh.parallel_for <name>(nd_range, kern);
128+ };
129+ return Kernel;
130+ };
131+
132+ template <typename T> struct LocalSortArgs
133+ {
134+ T *arr;
135+ size_t global_array_size;
136+ size_t wg_chunk_size;
137+ LocalSortArgs (T *arr, size_t arr_len, size_t wg_len)
138+ : arr(arr), global_array_size(arr_len), wg_chunk_size(wg_len)
139+ {
140+ }
141+ ~LocalSortArgs () {}
142+
143+ T *get_array_pointer () const
144+ {
145+ return arr;
146+ }
147+ size_t get_array_size () const
148+ {
149+ return global_array_size;
150+ }
151+ size_t get_chunk_size () const
152+ {
153+ return wg_chunk_size;
154+ }
155+ };
156+
157+ template <typename T, typename localAccessorT> struct LocalSortFunc
158+ {
159+ /*
160+
161+ */
162+ T *arr;
163+ size_t global_array_size;
164+ size_t wg_chunk_size;
165+ localAccessorT lm;
166+ LocalSortFunc (T *arr, size_t arr_len, size_t wg_len, localAccessorT lm)
167+ : arr(arr), global_array_size(arr_len), wg_chunk_size(wg_len), lm(lm)
168+ {
169+ }
170+ template <class paramsT >
171+ LocalSortFunc (paramsT params, localAccessorT lm)
172+ : arr(params.get_array_pointer()),
173+ global_array_size (params.get_array_size()),
174+ wg_chunk_size(params.get_chunk_size()), lm(lm)
175+ {
176+ }
177+ ~LocalSortFunc () {}
178+ void operator ()(sycl::nd_item<1 > item) const
179+ {
180+ /* Use odd-even merge sort to sort lws chunk of array */
181+ size_t group_id = item.get_group_linear_id ();
182+ size_t chunk_size =
183+ sycl::min ((group_id + 1 ) * wg_chunk_size, global_array_size) -
184+ group_id * wg_chunk_size;
185+
186+ // compute the greatest power of 2 less than chunk_size
187+ size_t sp2 = 1 ;
188+ while (sp2 < chunk_size) {
189+ sp2 <<= 1 ;
190+ }
191+ sp2 >>= 1 ;
192+
193+ size_t gid = item.get_global_linear_id ();
194+ size_t lid = item.get_local_linear_id ();
195+
196+ if (gid < global_array_size) {
197+ lm[lid] = arr[gid];
198+ }
199+ item.barrier (sycl::access::fence_space::local_space);
200+
201+ for (size_t p = sp2; p > 0 ; p >>= 1 ) {
202+ size_t q = sp2;
203+ size_t r = 0 ;
204+ for (size_t d = p; d > 0 ; d = q - p, q >>= 1 , r = p) {
205+ if ((lid < chunk_size - d) && (lid & p) == r) {
206+ size_t i = lid;
207+ size_t j = i + d;
208+ T v1 = lm[i];
209+ T v2 = lm[j];
210+ if (v1 > v2) {
211+ lm[i] = v2;
212+ lm[j] = v1;
213+ }
214+ }
215+ item.barrier (sycl::access::fence_space::local_space);
216+ }
217+ }
218+ if (gid < global_array_size) {
219+ arr[gid] = lm[lid];
220+ }
221+ };
222+ };
223+
224+ template <typename T> class local_sort_kern ;
225+
226+ template <typename T>
227+ sycl::kernel get_local_sort_kernel (sycl::queue &q,
228+ size_t gws,
229+ size_t lws,
230+ T *arr,
231+ size_t arr_len)
232+ {
233+ sycl::program program (q.get_context ());
234+
235+ using local_accessor_t =
236+ sycl::accessor<T, 1 , sycl::access::mode::read_write,
237+ sycl::access::target::local>;
238+
239+ [[maybe_unused]] auto cgh_fn = make_cgh_nd_function_with_local_memory<
240+ local_sort_kern<T>, local_accessor_t , LocalSortArgs<T>,
241+ LocalSortFunc<T, local_accessor_t >>(
242+ sycl::nd_range<1 >(gws, lws), lws, LocalSortArgs<T>(arr, arr_len, lws));
243+
244+ program.build_with_kernel_type <local_sort_kern<T>>();
245+ return program.get_kernel <local_sort_kern<T>>();
246+ };
247+
248+ template <typename T> struct LocalCountExceedanceFunc
249+ {
250+ T *arr;
251+ size_t arr_len;
252+ T threshold_val;
253+ int *count_arr;
254+ LocalCountExceedanceFunc (T *arr,
255+ size_t arr_len,
256+ T threshold_val,
257+ int *count_arr)
258+ : arr(arr), arr_len(arr_len), threshold_val(threshold_val),
259+ count_arr (count_arr)
260+ {
261+ }
262+
263+ void operator ()(sycl::nd_item<1 > item) const
264+ {
265+ /* count number of array elements in group chunk that
266+ exceeds the threshold value */
267+ size_t gid = item.get_global_linear_id ();
268+ int partial_sum = sycl::ONEAPI::reduce (
269+ item.get_group (),
270+ (gid < arr_len) ? int (arr[gid] > threshold_val) : int (0 ),
271+ std::plus<int >());
272+ count_arr[item.get_group_linear_id ()] = partial_sum;
273+ }
274+ };
275+
276+ template <typename T> class local_exceedance_kern ;
277+
278+ template <typename T>
279+ sycl::kernel get_local_count_exceedance_kernel (sycl::queue &q,
280+ size_t gws,
281+ size_t lws,
282+ T *arr,
283+ size_t arr_len,
284+ T threshold_val,
285+ int *counts)
286+ {
287+ sycl::program program (q.get_context ());
288+
289+ [[maybe_unused]] auto cgh_fn =
290+ make_cgh_nd_function<local_exceedance_kern<T>,
291+ LocalCountExceedanceFunc<T>>(
292+ sycl::nd_range<1 >(gws, lws),
293+ LocalCountExceedanceFunc<T>(arr, arr_len, threshold_val, counts));
294+
295+ program.build_with_kernel_type <local_exceedance_kern<T>>();
296+ return program.get_kernel <local_exceedance_kern<T>>();
297+ };
298+
108299} // namespace dpcpp_kernels
0 commit comments