Skip to content

Commit 8ab65b3

Browse files
authored
Migrate quantized_conv2d tests to graph builder
Differential Revision: D93112638. Pull Request resolved: #17451.
1 parent 034c3a7 commit 8ab65b3

1 file changed

Lines changed: 266 additions & 44 deletions

File tree

backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp

Lines changed: 266 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -175,6 +175,45 @@ void xa_opt_quantized_conv2d_nchw(
175175
bool conv1d = input.dim() == 3;
176176
constexpr int kNnlibMaxDim = 4;
177177

178+
WORD32 input_height = conv1d ? 1 : input.size(2);
179+
WORD32 input_width = conv1d ? input.size(2) : input.size(3);
180+
WORD32 input_channels = input.size(1);
181+
WORD32 kernel_height = conv1d ? 1 : weight.size(2);
182+
WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3);
183+
WORD32 kernel_channels = weight.size(1);
184+
WORD32 out_channels = weight.size(0);
185+
WORD32 out_height = conv1d ? 1 : out.size(2);
186+
WORD32 out_width = conv1d ? out.size(2) : out.size(3);
187+
WORD32 batches = input.size(0);
188+
189+
WORD32 x_stride = stride[1];
190+
WORD32 y_stride = stride[0];
191+
WORD32 x_padding = padding[1];
192+
WORD32 y_padding = padding[0];
193+
WORD32 dilation_width = dilation[1];
194+
WORD32 dilation_height = dilation[0];
195+
196+
WORD32 input_zero_bias = -in_zero_point;
197+
WORD32 kernel_zero_bias = -weight_zero_point;
198+
199+
WORD32 out_multiplier32[out_channels];
200+
WORD32 out_shift32[out_channels];
201+
202+
float out_scale = 1. / output_scale;
203+
204+
for (int i = 0; i < out_channels; i++) {
205+
out_multiplier32[i] = bias_scale * out_scale * 2147483648;
206+
out_shift32[i] = 0;
207+
}
208+
209+
WORD32 out_zero_bias = output_zero_point;
210+
WORD32 inp_precision = 8;
211+
WORD32 kernel_precision = 8;
212+
pVOID p_scratch = nullptr;
213+
WORD32* ptr_scratch;
214+
215+
WORD32 scratch_size = 0;
216+
178217
if (input.scalar_type() == ScalarType::Char) {
179218
WORD8* __restrict__ p_out =
180219
(WORD8* __restrict__)out.mutable_data_ptr<int8_t>();
@@ -185,48 +224,6 @@ void xa_opt_quantized_conv2d_nchw(
185224
WORD32* __restrict__ p_bias =
186225
(WORD32* __restrict__)bias.const_data_ptr<int32_t>();
187226

188-
WORD32 input_height = conv1d ? 1 : input.size(2);
189-
WORD32 input_width = conv1d ? input.size(2) : input.size(3);
190-
WORD32 input_channels = input.size(1);
191-
WORD32 kernel_height = conv1d ? 1 : weight.size(2);
192-
WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3);
193-
WORD32 kernel_channels = weight.size(1);
194-
WORD32 out_channels = weight.size(0);
195-
WORD32 out_height = conv1d ? 1 : out.size(2);
196-
WORD32 out_width = conv1d ? out.size(2) : out.size(3);
197-
WORD32 batches = input.size(0);
198-
199-
WORD32 x_stride = stride[1];
200-
WORD32 y_stride = stride[0];
201-
WORD32 x_padding = padding[1];
202-
WORD32 y_padding = padding[0];
203-
WORD32 dilation_width = dilation[1];
204-
WORD32 dilation_height = dilation[0];
205-
206-
// WORD32* kernel_bias_ptr =
207-
// (WORD32*)weight_zero_point.const_data_ptr<int32_t>();
208-
209-
WORD32 input_zero_bias = -in_zero_point;
210-
WORD32 kernel_zero_bias = -weight_zero_point;
211-
212-
WORD32 out_multiplier32[out_channels];
213-
WORD32 out_shift32[out_channels];
214-
215-
float out_scale = 1. / output_scale;
216-
217-
for (int i = 0; i < out_channels; i++) {
218-
out_multiplier32[i] = bias_scale * out_scale * 2147483648;
219-
out_shift32[i] = 0;
220-
}
221-
222-
WORD32 out_zero_bias = output_zero_point;
223-
WORD32 inp_precision = 8;
224-
WORD32 kernel_precision = 8;
225-
pVOID p_scratch = nullptr;
226-
WORD32* ptr_scratch;
227-
228-
WORD32 scratch_size = 0;
229-
230227
if (groups == 1) {
231228
WORD32 out_data_format = 1;
232229

@@ -245,13 +242,13 @@ void xa_opt_quantized_conv2d_nchw(
245242
WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8);
246243

247244
WORD32 p_inp_shape[kNnlibMaxDim];
248-
p_inp_shape[0] = input.size(0);
245+
p_inp_shape[0] = batches;
249246
p_inp_shape[1] = input_channels;
250247
p_inp_shape[2] = input_height;
251248
p_inp_shape[3] = input_width;
252249

253250
WORD32 p_out_shape[kNnlibMaxDim];
254-
p_out_shape[0] = input.size(0);
251+
p_out_shape[0] = batches;
255252
p_out_shape[1] = input_height;
256253
p_out_shape[2] = input_width;
257254
p_out_shape[3] = input_channels;
@@ -439,6 +436,231 @@ void xa_opt_quantized_conv2d_nchw(
439436
return;
440437
}
441438
}
439+
440+
if (input.scalar_type() == ScalarType::Byte) {
441+
UWORD8* __restrict__ p_out =
442+
(UWORD8* __restrict__)out.mutable_data_ptr<uint8_t>();
443+
UWORD8* __restrict__ p_inp =
444+
(UWORD8* __restrict__)input.const_data_ptr<uint8_t>();
445+
UWORD8* __restrict__ p_kernel =
446+
(UWORD8* __restrict__)weight.const_data_ptr<uint8_t>();
447+
WORD32* __restrict__ p_bias =
448+
(WORD32* __restrict__)bias.const_data_ptr<int32_t>();
449+
450+
WORD32 out_multiplier = out_multiplier32[0];
451+
WORD32 out_shift = out_shift32[0];
452+
453+
if (groups == 1) {
454+
WORD32 out_data_format = 1;
455+
456+
UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory(
457+
ctx,
458+
((batches * input_channels * input_height * input_width) + 8) *
459+
sizeof(UWORD8));
460+
461+
UWORD8* ptr2 = (UWORD8*)kernels::allocate_temp_memory(
462+
ctx,
463+
((out_channels * kernel_channels * kernel_height * kernel_width) +
464+
8) *
465+
sizeof(UWORD8));
466+
467+
UWORD8* pin = (UWORD8*)ALIGN_PTR(ptr1, 8);
468+
UWORD8* pkernel = (UWORD8*)ALIGN_PTR(ptr2, 8);
469+
470+
WORD32 p_inp_shape[kNnlibMaxDim];
471+
p_inp_shape[0] = batches;
472+
p_inp_shape[1] = input_channels;
473+
p_inp_shape[2] = input_height;
474+
p_inp_shape[3] = input_width;
475+
476+
WORD32 p_out_shape[kNnlibMaxDim];
477+
p_out_shape[0] = batches;
478+
p_out_shape[1] = input_height;
479+
p_out_shape[2] = input_width;
480+
p_out_shape[3] = input_channels;
481+
482+
WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1};
483+
484+
xa_nn_transpose_8_8(
485+
(WORD8*)pin,
486+
p_out_shape,
487+
(WORD8*)p_inp,
488+
p_inp_shape,
489+
p_permute_vec,
490+
kNnlibMaxDim,
491+
kNnlibMaxDim);
492+
493+
WORD32 p_inp_shape1[kNnlibMaxDim];
494+
p_inp_shape1[0] = out_channels;
495+
p_inp_shape1[1] = kernel_channels;
496+
p_inp_shape1[2] = kernel_height;
497+
p_inp_shape1[3] = kernel_width;
498+
499+
WORD32 p_out_shape1[kNnlibMaxDim];
500+
p_out_shape1[0] = out_channels;
501+
p_out_shape1[1] = kernel_height;
502+
p_out_shape1[2] = kernel_width;
503+
p_out_shape1[3] = kernel_channels;
504+
505+
xa_nn_transpose_8_8(
506+
(WORD8*)pkernel,
507+
p_out_shape1,
508+
(WORD8*)p_kernel,
509+
p_inp_shape1,
510+
p_permute_vec,
511+
kNnlibMaxDim,
512+
kNnlibMaxDim);
513+
514+
scratch_size = xa_nn_conv2d_getsize(
515+
input_height,
516+
input_width,
517+
input_channels,
518+
kernel_height,
519+
kernel_width,
520+
kernel_channels,
521+
dilation_height,
522+
dilation_width,
523+
y_stride,
524+
y_padding,
525+
x_stride,
526+
x_padding,
527+
out_height,
528+
out_width,
529+
out_channels,
530+
inp_precision,
531+
kernel_precision,
532+
out_data_format);
533+
534+
scratch_size = scratch_size < 0 ? 0 : scratch_size;
535+
536+
ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size);
537+
538+
p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8);
539+
540+
for (int _n = 0; _n < batches; _n++) {
541+
UWORD8* in_batch =
542+
pin + _n * input_channels * input_height * input_width;
543+
UWORD8* out_batch = p_out + _n * out_channels * out_height * out_width;
544+
545+
xa_nn_conv2d_std_asym8uxasym8u(
546+
out_batch,
547+
in_batch,
548+
pkernel,
549+
p_bias,
550+
input_height,
551+
input_width,
552+
input_channels,
553+
kernel_height,
554+
kernel_width,
555+
out_channels,
556+
x_stride,
557+
y_stride,
558+
x_padding,
559+
y_padding,
560+
out_height,
561+
out_width,
562+
input_zero_bias,
563+
kernel_zero_bias,
564+
out_multiplier,
565+
out_shift,
566+
out_zero_bias,
567+
out_data_format,
568+
p_scratch);
569+
}
570+
return;
571+
}
572+
573+
if (groups == input_channels) {
574+
WORD32 channels_multiplier = out_channels / input_channels;
575+
576+
scratch_size = xa_nn_conv2d_depthwise_getsize(
577+
input_height,
578+
input_width,
579+
input_channels,
580+
kernel_height,
581+
kernel_width,
582+
channels_multiplier,
583+
x_stride,
584+
y_stride,
585+
x_padding,
586+
y_padding,
587+
out_height,
588+
out_width,
589+
inp_precision,
590+
1); // NCHW
591+
592+
scratch_size = scratch_size < 0 ? 0 : scratch_size;
593+
594+
ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size);
595+
596+
p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8);
597+
598+
UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory(
599+
ctx,
600+
((batches * out_channels * out_height * out_width) + 8) *
601+
sizeof(UWORD8));
602+
603+
UWORD8* p_out_temp = (UWORD8*)ALIGN_PTR(ptr1, 8);
604+
605+
for (int _n = 0; _n < batches; _n++) {
606+
UWORD8* in_batch =
607+
p_inp + _n * input_channels * input_height * input_width;
608+
UWORD8* out_batch =
609+
p_out_temp + _n * out_channels * out_height * out_width;
610+
611+
xa_nn_conv2d_depthwise_asym8uxasym8u(
612+
out_batch,
613+
p_kernel,
614+
in_batch,
615+
p_bias,
616+
input_height,
617+
input_width,
618+
input_channels,
619+
kernel_height,
620+
kernel_width,
621+
channels_multiplier,
622+
x_stride,
623+
y_stride,
624+
x_padding,
625+
y_padding,
626+
out_height,
627+
out_width,
628+
input_zero_bias,
629+
kernel_zero_bias,
630+
out_multiplier,
631+
out_shift,
632+
out_zero_bias,
633+
1, // NCHW
634+
0, // NHWC
635+
p_scratch);
636+
}
637+
638+
WORD32 p_inp_shape[kNnlibMaxDim];
639+
p_inp_shape[0] = batches;
640+
p_inp_shape[1] = out_height;
641+
p_inp_shape[2] = out_width;
642+
p_inp_shape[3] = out_channels;
643+
644+
WORD32 p_out_shape[kNnlibMaxDim];
645+
p_out_shape[0] = batches;
646+
p_out_shape[1] = out_channels;
647+
p_out_shape[2] = out_height;
648+
p_out_shape[3] = out_width;
649+
650+
WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2};
651+
652+
xa_nn_transpose_8_8(
653+
(WORD8*)p_out,
654+
p_out_shape,
655+
(WORD8*)p_out_temp,
656+
p_inp_shape,
657+
p_permute_vec,
658+
kNnlibMaxDim,
659+
kNnlibMaxDim);
660+
661+
return;
662+
}
663+
}
442664
}
443665

444666
// The quantized convolution kernel. in_scale and weight_scale are implicit in

0 commit comments

Comments
 (0)