#include "DSMoE.h"
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/CPUBlas.h>
#include <aten/utils/amx.h>
#include <aten/utils/common.h>
#include <torch/all.h>
#include <torch/csrc/autograd/function.h>

namespace torch_ipex {
namespace cpu {

IPEX_DEFINE_DISPATCH(fused_experts_impl_stub);
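
// Thin wrapper around the fused MoE experts kernel: records a profiler
// event and forwards every argument to the CPU dispatch stub
// (fused_experts_impl_stub), where the actual computation is implemented.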
at::Tensor fused_experts(
    const at::Tensor& hidden_states,
    const at::Tensor& w1,
    const at::Tensor& w2,
    const at::Tensor& topk_weights,
    const at::Tensor& topk_ids,
    bool inplace,
    bool is_vnni,
    bool is_distributed,
    bool is_woq,
    int64_t woq_weight_dtype,
    int64_t woq_group_size,
    int64_t woq_lowp_mode,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w1_zp,
    const std::optional<at::Tensor>& w1_compensation,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<at::Tensor>& w2_zp,
    const std::optional<at::Tensor>& w2_compensation) {
  RECORD_FUNCTION("ipex::fused_experts", c10::ArrayRef<c10::IValue>({}));

  return fused_experts_impl_stub(
      kCPU,
      hidden_states,
      w1,
      w2,
      topk_weights,
      topk_ids,
      inplace,
      is_vnni,
      is_distributed,
      is_woq,
      woq_weight_dtype,
      woq_group_size,
      woq_lowp_mode,
      w1_scale,
      w1_zp,
      w1_compensation,
      w2_scale,
      w2_zp,
      w2_compensation);
}

constexpr int block_size_m() {
  return 1 * TILE_M;
}
constexpr int block_size_n() {
  return 8 * TILE_N;
}
// Convert weight to VNNI format for bfloat16 and float16:
// from row-major [N, K] (viewed as [N, K/2, 2]) to [K/2, N, 2],
// i.e. each pair of consecutive K elements of a row stays contiguous.
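// A minimal worked example, assuming VNNI_BLK == 2, with N = 2, K = 4:
//   weight ([N, K], row-major): w00 w01 w02 w03
//                               w10 w11 w12 w13
//   packed ([K/2, N, 2]):       w00 w01 w10 w11 | w02 w03 w12 w13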
template <typename scalar_t>
inline void pack_vnni(
    scalar_t* __restrict__ packed,
    const scalar_t* __restrict__ weight,
    int N,
    int K) {
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K / VNNI_BLK; ++k) {
      for (int d = 0; d < VNNI_BLK; ++d) {
        packed[k * N * VNNI_BLK + n * VNNI_BLK + d] =
            weight[n * K + k * VNNI_BLK + d];
      }
    }
  }
}

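// Repack a 3D weight tensor [E, OC, IC] (one [OC, IC] matrix per expert)
// into VNNI layout, parallelizing over experts and over BLOCK_N-sized
// column blocks. The output keeps the nominal shape [E, OC, IC]; only the
// element order inside each expert's block changes.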
at::Tensor convert_weight_packed_bf16(at::Tensor& weight) {
  // weight : [E, OC, IC]
  //     w1 : [E, 2N, K]
  //     w2 : [E, K, N]
  CHECK_DIM(3, weight);
  const auto st = weight.scalar_type();
  const int E = weight.size(0);
  const int OC = weight.size(1);
  const int IC = weight.size(2);
  // we handle 2 TILE_N at a time.
  TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
  TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
  constexpr int BLOCK_N = block_size_n();
  // use phony sizes here [E, OC, IC]; for each [E], [OC, IC] -> [IC / 2, OC, 2]
  auto packed_weight = at::empty({E, OC, IC}, weight.options());
  const int stride = OC * IC;
  // TODO: add float8 support
  TORCH_CHECK(
      st == at::kBFloat16 || st == at::kHalf,
      "expect weight to be bfloat16 or float16.");
  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "convert_weight_packed_impl", [&] {
    const scalar_t* w_data = weight.data_ptr<scalar_t>();
    scalar_t* packed_data = packed_weight.data_ptr<scalar_t>();
    // parallel on {E}
    at::parallel_for(0, E, 0, [&](int begin, int end) {
      for (int e = begin; e < end; ++e) {
        for (int n = 0; n < OC; n += BLOCK_N) {
          int n_size = std::min(BLOCK_N, OC - n);
          pack_vnni<scalar_t>(
              packed_data + e * stride + n * IC,
              w_data + e * stride + n * IC,
              n_size,
              IC);
        }
      }
    });
  });

  return packed_weight;
}

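// Vectorized elementwise sigmoid over a compile-time-sized row: loads
// bfloat16/float16 input, converts to float, and writes
// 1 / (1 + exp(-x)) into the float output buffer.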
template <typename scalar_t, int SIZE>
inline void sigmoid(
    float* __restrict__ out,
    const scalar_t* __restrict__ input) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  // step 0: convert input to float
  fVec one_fvec = fVec(1.0);
  if constexpr (SIZE < kVecSize) {
    // SIZE = 1, 2, 4, 8, 16; only the lower half (x_fvec0) is used
    bVec x_bvec = bVec::loadu(input, SIZE);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
    x_fvec0.store(out, SIZE);
  } else {
    for (int d = 0; d < SIZE; d += kVecSize) {
      bVec x_bvec = bVec::loadu(input + d);
      fVec x_fvec0, x_fvec1;
      std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
      x_fvec0.store(out + d);
      x_fvec1.store(out + d + fVec::size());
    }
  }

  fVec zero_fvec = fVec(0.0);
  // step 1: out = 1 / (1 + exp(-x))
  if constexpr (SIZE < fVec::size()) {
    // SIZE = 1, 2, 4, 8
    fVec x_fvec =
        one_fvec / (one_fvec + (zero_fvec - fVec::loadu(out, SIZE)).exp_u20());
    x_fvec.store(out, SIZE);
  } else {
    for (int d = 0; d < SIZE; d += fVec::size()) {
      fVec x_fvec =
          one_fvec / (one_fvec + (zero_fvec - fVec::loadu(out + d)).exp_u20());
      x_fvec.store(out + d);
    }
  }
}
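
// Grouped top-k routing kernel. Per token:
//   1. apply sigmoid to the gating logits to get per-expert scores,
//   2. add e_score_correction_bias to the scores used for selection only,
//   3. rank expert groups by the sum of their top-2 biased scores and keep
//      the best topk_group groups,
//   4. pick the global top-k experts inside those groups,
//   5. emit the original (unbiased) sigmoid scores as weights, optionally
//      renormalized to sum to 1, then scaled by routed_scaling_factor.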
template <typename scalar_t, int NUM_EXPERTS>
void grouped_topk_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    int num_tokens,
    int topk,
    int num_groups,
    int topk_group,
    bool renormalize,
    float* __restrict__ e_score_correction_bias,
    float* routed_scaling_factor) {
  const int num_experts_per_group = NUM_EXPERTS / num_groups;
  parallel_for(num_tokens, [&](int begin, int end) {
    static thread_local float scores[NUM_EXPERTS];
    static thread_local float ori_scores[NUM_EXPERTS];
    using elem_t = std::pair<float, int32_t>;
    std::vector<elem_t> queue_temp(num_groups);
    std::vector<elem_t> queue(num_groups);
    std::vector<elem_t> queue2(topk_group * num_experts_per_group);

    for (int i = begin; i < end; ++i) {
      // apply sigmoid to get scores
      sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
      for (int g = 0; g < NUM_EXPERTS; ++g) {
        ori_scores[g] = scores[g];
        scores[g] = scores[g] + e_score_correction_bias[g];
      }
      // find max score per group
      for (int g = 0; g < num_groups; ++g) {
        float gmax = -std::numeric_limits<float>::infinity();
        for (int e = 0; e < num_experts_per_group; ++e) {
          gmax = std::max(gmax, scores[g * num_experts_per_group + e]);
        }
        queue_temp[g] = {gmax, g};
      }
      // group score = sum of the top-2 expert scores within the group
      for (int g = 0; g < num_groups; ++g) {
        float previous_max = queue_temp[g].first;
        int count_previous_max = 1;
        float gmax = -std::numeric_limits<float>::infinity();
        for (int e = 0; e < num_experts_per_group; ++e) {
          if (count_previous_max == 1 &&
              scores[g * num_experts_per_group + e] == previous_max) {
            count_previous_max--;
          } else {
            gmax = std::max(gmax, scores[g * num_experts_per_group + e]);
          }
        }
        queue[g] = {gmax + previous_max, g};
      }
      // find group topk
      std::partial_sort(
          queue.begin(),
          queue.begin() + topk_group,
          queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });

      for (int g = 0; g < topk_group; ++g) {
        int32_t group_idx = queue[g].second;
        for (int e = 0; e < num_experts_per_group; ++e) {
          int32_t expert_idx = group_idx * num_experts_per_group + e;
          queue2[g * num_experts_per_group + e] = {
              scores[expert_idx], expert_idx};
        }
      }
      // find global topk
      std::partial_sort(
          queue2.begin(),
          queue2.begin() + topk,
          queue2.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });
      for (int j = 0; j < topk; ++j) {
        topk_weights[i * topk + j] = ori_scores[queue2[j].second];
        topk_ids[i * topk + j] = queue2[j].second;
      }
      if (renormalize) {
        float sum = 0.f;
        for (int j = 0; j < topk; ++j) {
          sum += topk_weights[i * topk + j];
        }
        float scale = 1.f / sum;
        for (int j = 0; j < topk; ++j) {
          topk_weights[i * topk + j] *= scale;
        }
      }
      for (int j = 0; j < topk; ++j) {
        topk_weights[i * topk + j] =
            topk_weights[i * topk + j] * routed_scaling_factor[0];
      }
    }
  });
}

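// Instantiate the routing kernel for a compile-time expert count NE.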
#define LAUNCH_GROUPED_TOPK_KERNEL(NE)           \
  grouped_topk_kernel_impl<at::BFloat16, NE>(    \
      topk_weights.data_ptr<float>(),            \
      topk_ids.data_ptr<int32_t>(),              \
      gating_output.data_ptr<at::BFloat16>(),    \
      num_tokens,                                \
      topk,                                      \
      num_expert_group,                          \
      topk_group,                                \
      renormalize,                               \
      e_score_correction_bias.data_ptr<float>(), \
      routed_scaling_factor.data_ptr<float>());

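// grouped_topk: CPU entry point for grouped top-k routing. Validates the
// inputs, allocates [num_tokens, topk] outputs for the weights (float) and
// expert ids (int32), and dispatches to the templated kernel based on the
// number of experts. Returns (topk_ids, topk_weights).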
std::tuple<at::Tensor, at::Tensor> grouped_topk(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
    int64_t topk,
    bool renormalize,
    int64_t num_expert_group,
    int64_t topk_group,
    at::Tensor& e_score_correction_bias,
    at::Tensor& routed_scaling_factor) {
  const auto st = hidden_states.scalar_type();
  CHECK_EQ(gating_output.scalar_type(), st);

  int64_t num_tokens = hidden_states.size(0);
  int64_t num_experts = gating_output.size(1);
  TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
  auto topk_weights = at::empty({num_tokens, topk}, at::kFloat);
  auto topk_ids = at::empty_like(topk_weights, at::kInt);
  switch (num_experts) {
    case 1:
      LAUNCH_GROUPED_TOPK_KERNEL(1);
      break;
    case 2:
      LAUNCH_GROUPED_TOPK_KERNEL(2);
      break;
    case 4:
      LAUNCH_GROUPED_TOPK_KERNEL(4);
      break;
    case 8:
      LAUNCH_GROUPED_TOPK_KERNEL(8);
      break;
    case 16:
      LAUNCH_GROUPED_TOPK_KERNEL(16);
      break;
    case 32:
      LAUNCH_GROUPED_TOPK_KERNEL(32);
      break;
    case 64:
      LAUNCH_GROUPED_TOPK_KERNEL(64);
      break;
    case 128:
      LAUNCH_GROUPED_TOPK_KERNEL(128);
      break;
    case 256:
      LAUNCH_GROUPED_TOPK_KERNEL(256);
      break;
    default:
      TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
  }
  return std::make_tuple(topk_ids, topk_weights);
}

} // namespace cpu
} // namespace torch_ipex

namespace {

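// Register the ops with the torch_ipex library so they can be called as
// torch.ops.torch_ipex.<name> from Python and resolved on the CPU
// dispatch key.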
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
  m.def(
      "fused_experts(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, \
       Tensor topk_ids, bool inplace, bool is_vnni, \
       bool is_distributed, bool is_woq, int woq_weight_dtype, int woq_group_size, int woq_lowp_mode, \
       Tensor? w1_scale, Tensor? w1_zp, Tensor? w1_compensation, Tensor? w2_scale, Tensor? w2_zp, Tensor? w2_compensation) -> Tensor");
  m.impl(
      "fused_experts", c10::DispatchKey::CPU, torch_ipex::cpu::fused_experts);
  m.def(
      "grouped_topk(Tensor hidden_states, Tensor gating_output, \
       int topk, bool renormalize, int num_expert_group, int topk_group, Tensor e_score_correction_bias, Tensor routed_scaling_factor) -> (Tensor, Tensor)");
  m.impl("grouped_topk", c10::DispatchKey::CPU, torch_ipex::cpu::grouped_topk);
  m.def("convert_weight_packed_bf16(Tensor weight) -> Tensor");
  m.impl(
      "convert_weight_packed_bf16",
      c10::DispatchKey::CPU,
      torch_ipex::cpu::convert_weight_packed_bf16);
}

} // namespace