Skip to content

Commit 37c5e3e

Browse files
committed
Flux: simplify when patch_size is 1
1 parent 203d053 commit 37c5e3e

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

flux.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,11 @@ namespace Flux {
891891
int64_t C = x->ne[2];
892892
int64_t H = x->ne[1];
893893
int64_t W = x->ne[0];
894+
if (params.patch_size == 1) {
895+
x = ggml_reshape_3d(ctx, x, H * W, C, N); // [N, C, H*W]
896+
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, H*W, C]
897+
return x;
898+
}
894899
int64_t p = params.patch_size;
895900
int64_t h = H / params.patch_size;
896901
int64_t w = W / params.patch_size;
@@ -925,6 +930,12 @@ namespace Flux {
925930
int64_t W = w * params.patch_size;
926931
int64_t p = params.patch_size;
927932

933+
if (params.patch_size == 1) {
934+
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, C, H*W]
935+
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, H, W]
936+
return x;
937+
}
938+
928939
GGML_ASSERT(C * p * p == x->ne[0]);
929940

930941
x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]

0 commit comments

Comments
 (0)