diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 476c3079795..bb30d119d89 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4487,7 +4487,7 @@ struct ggml_tensor * ggml_conv_1d( int s0, int p0, int d0) { - struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K] + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, a->type == GGML_TYPE_F16 ? GGML_TYPE_F16 : GGML_TYPE_F32); // [N, OL, IC * K] struct ggml_tensor * result = ggml_mul_mat(ctx, @@ -4521,7 +4521,7 @@ struct ggml_tensor * ggml_conv_1d_dw( int d0) { struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]); - struct ggml_tensor * im2col = ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); + struct ggml_tensor * im2col = ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, a->type == GGML_TYPE_F16 ? GGML_TYPE_F16 : GGML_TYPE_F32); struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a); @@ -4592,7 +4592,7 @@ struct ggml_tensor * ggml_conv_2d( int p1, int d0, int d1) { - struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW] + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type == GGML_TYPE_F16 ? GGML_TYPE_F16 : GGML_TYPE_F32); // [N, OH, OW, IC * KH * KW] struct ggml_tensor * result = ggml_mul_mat(ctx, @@ -4674,7 +4674,7 @@ struct ggml_tensor * ggml_conv_3d( int d1, // dilation height int d2 // dilation depth ) { - struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW] + struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type == GGML_TYPE_F16 ? GGML_TYPE_F16 : GGML_TYPE_F32); // [N*OD, OH, OW, IC * KD * KH * KW] int64_t OC = a->ne[3] / IC; int64_t N = b->ne[3] / IC; @@ -4724,7 +4724,7 @@ struct ggml_tensor * ggml_conv_2d_dw( struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), - s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW] + s0, s1, p0, p1, d0, d1, true, a->type == GGML_TYPE_F16 ? GGML_TYPE_F16 : GGML_TYPE_F32); // [N * IC, OH, OW, KH * KW] struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]