Skip to content

Commit 5f748ed

Browse files
authored
support punica bgmv (#3422)
* enable punica * change API * enable bgmv_expand * improve code style * add kernel punica_bgmv_expand_slice * add more ut; pass all ut * add frontend APIs and docs; refine code * revert APIs changes * aligned APIs with xpu * add docs * add docs * fix clang
1 parent 30ecffa commit 5f748ed

File tree

8 files changed

+1732
-0
lines changed

8 files changed

+1732
-0
lines changed

csrc/cpu/aten/Punica.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#include "Punica.h"
2+
#include <torch/all.h>
3+
#include <torch/csrc/autograd/function.h>
4+
#include "csrc/utils/CustomOperatorRegistration.h"
5+
6+
namespace torch_ipex {
7+
namespace cpu {
8+
9+
// Define the three dispatch stubs that Punica.h declares with
// IPEX_DECLARE_DISPATCH. The *_forward_cpu entry points in this file call
// through these stubs.
IPEX_DEFINE_DISPATCH(punica_bgmv_shrink_kernel_stub);
10+
IPEX_DEFINE_DISPATCH(punica_bgmv_expand_kernel_stub);
11+
IPEX_DEFINE_DISPATCH(punica_bgmv_expand_slice_kernel_stub);
12+
13+
/// CPU entry point registered for the `punica_bgmv_shrink` operator.
///
/// Passes every argument, unchanged and in the same order, to
/// `punica_bgmv_shrink_kernel_stub` for the CPU device. No validation or
/// computation happens at this level.
///
/// NOTE(review): `out` is presumably the tensor the kernel writes into and
/// `scale` a multiplier applied by the kernel -- confirm against the kernel
/// implementation, which is not part of this file.
void punica_bgmv_shrink_forward_cpu(
    at::Tensor& out,
    at::Tensor& input,
    at::Tensor& weights,
    at::Tensor& indicies,
    const double scale) {
  punica_bgmv_shrink_kernel_stub(kCPU, out, input, weights, indicies, scale);
}
22+
23+
/// CPU entry point registered for the `punica_bgmv_expand` operator.
///
/// Passes every argument, unchanged and in the same order, to
/// `punica_bgmv_expand_kernel_stub` for the CPU device. No validation or
/// computation happens at this level.
///
/// NOTE(review): `add_inputs` presumably selects accumulate-into-`out` versus
/// overwrite -- confirm against the kernel implementation, which is not part
/// of this file.
void punica_bgmv_expand_forward_cpu(
    at::Tensor& out,
    at::Tensor& input,
    at::Tensor& weights,
    at::Tensor& indicies,
    bool add_inputs) {
  punica_bgmv_expand_kernel_stub(
      kCPU, out, input, weights, indicies, add_inputs);
}
32+
33+
/// CPU entry point registered for the `punica_bgmv_expand_slice` operator.
///
/// Passes every argument, unchanged and in the same order, to
/// `punica_bgmv_expand_slice_kernel_stub` for the CPU device. No validation
/// or computation happens at this level.
///
/// NOTE(review): `slice_offset` / `slice_size` presumably select a region of
/// `out` -- confirm against the kernel implementation, which is not part of
/// this file.
void punica_bgmv_expand_slice_forward_cpu(
    at::Tensor& out,
    at::Tensor& input,
    at::Tensor& weights,
    at::Tensor& indicies,
    int64_t slice_offset,
    int64_t slice_size,
    bool add_inputs) {
  punica_bgmv_expand_slice_kernel_stub(
      kCPU,
      out,
      input,
      weights,
      indicies,
      slice_offset,
      slice_size,
      add_inputs);
}
51+
52+
} // namespace cpu
53+
} // namespace torch_ipex
54+
55+
// Operator registration for the three punica BGMV entry points. Each
// TORCH_LIBRARY_FRAGMENT below adds one operator name to the `torch_ipex`
// library and binds it to its *_forward_cpu function under the CPU dispatch
// key.
namespace {
56+
57+
// "punica_bgmv_shrink" -> punica_bgmv_shrink_forward_cpu (CPU).
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
58+
IPEX_OP_REGISTER_DISPATCH(
59+
"punica_bgmv_shrink",
60+
torch_ipex::cpu::punica_bgmv_shrink_forward_cpu,
61+
c10::DispatchKey::CPU);
62+
}
63+
64+
// "punica_bgmv_expand" -> punica_bgmv_expand_forward_cpu (CPU).
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
65+
IPEX_OP_REGISTER_DISPATCH(
66+
"punica_bgmv_expand",
67+
torch_ipex::cpu::punica_bgmv_expand_forward_cpu,
68+
c10::DispatchKey::CPU);
69+
}
70+
71+
// "punica_bgmv_expand_slice" -> punica_bgmv_expand_slice_forward_cpu (CPU).
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
72+
IPEX_OP_REGISTER_DISPATCH(
73+
"punica_bgmv_expand_slice",
74+
torch_ipex::cpu::punica_bgmv_expand_slice_forward_cpu,
75+
c10::DispatchKey::CPU);
76+
}
77+
78+
} // namespace

csrc/cpu/aten/Punica.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include <dyndisp/DispatchStub.h>
5+
6+
namespace torch_ipex {
7+
namespace cpu {
8+
9+
namespace {
10+
11+
void punica_bgmv_shrink(
12+
at::Tensor& out,
13+
at::Tensor& input,
14+
at::Tensor& weights,
15+
at::Tensor& indicies,
16+
const double scale);
17+
18+
void punica_bgmv_expand(
19+
at::Tensor& out,
20+
at::Tensor& input,
21+
at::Tensor& weights,
22+
at::Tensor& indicies,
23+
bool add_inputs);
24+
25+
void punica_bgmv_expand_slice(
26+
at::Tensor& out,
27+
at::Tensor& input,
28+
at::Tensor& weights,
29+
at::Tensor& indicies,
30+
int64_t slice_offset,
31+
int64_t slice_size,
32+
bool add_inputs);
33+
} // namespace
34+
35+
// Kernel function-pointer type paired with `punica_bgmv_shrink_kernel_stub`
// through IPEX_DECLARE_DISPATCH. Argument order is given by the inline names.
using punica_bgmv_shrink_fn = void (*)(
    at::Tensor& /*out*/,
    at::Tensor& /*input*/,
    at::Tensor& /*weights*/,
    at::Tensor& /*indicies*/,
    const double /*scale*/);
41+
42+
// Kernel function-pointer type paired with `punica_bgmv_expand_kernel_stub`
// through IPEX_DECLARE_DISPATCH. Argument order is given by the inline names.
using punica_bgmv_expand_fn = void (*)(
    at::Tensor& /*out*/,
    at::Tensor& /*input*/,
    at::Tensor& /*weights*/,
    at::Tensor& /*indicies*/,
    bool /*add_inputs*/);
48+
49+
// Kernel function-pointer type paired with
// `punica_bgmv_expand_slice_kernel_stub` through IPEX_DECLARE_DISPATCH.
// Argument order is given by the inline names.
using punica_bgmv_expand_slice_fn = void (*)(
    at::Tensor& /*out*/,
    at::Tensor& /*input*/,
    at::Tensor& /*weights*/,
    at::Tensor& /*indicies*/,
    int64_t /*slice_offset*/,
    int64_t /*slice_size*/,
    bool /*add_inputs*/);
57+
58+
// Declare one dispatch stub per kernel, each paired with its function-pointer
// type from above. Punica.cpp provides the matching IPEX_DEFINE_DISPATCH
// definitions.
IPEX_DECLARE_DISPATCH(punica_bgmv_shrink_fn, punica_bgmv_shrink_kernel_stub);
59+
60+
IPEX_DECLARE_DISPATCH(punica_bgmv_expand_fn, punica_bgmv_expand_kernel_stub);
61+
62+
IPEX_DECLARE_DISPATCH(
63+
punica_bgmv_expand_slice_fn,
64+
punica_bgmv_expand_slice_kernel_stub);
65+
66+
} // namespace cpu
67+
} // namespace torch_ipex

0 commit comments

Comments
 (0)