@@ -175,6 +175,45 @@ void xa_opt_quantized_conv2d_nchw(
175175 bool conv1d = input.dim () == 3 ;
176176 constexpr int kNnlibMaxDim = 4 ;
177177
178+ WORD32 input_height = conv1d ? 1 : input.size (2 );
179+ WORD32 input_width = conv1d ? input.size (2 ) : input.size (3 );
180+ WORD32 input_channels = input.size (1 );
181+ WORD32 kernel_height = conv1d ? 1 : weight.size (2 );
182+ WORD32 kernel_width = conv1d ? weight.size (2 ) : weight.size (3 );
183+ WORD32 kernel_channels = weight.size (1 );
184+ WORD32 out_channels = weight.size (0 );
185+ WORD32 out_height = conv1d ? 1 : out.size (2 );
186+ WORD32 out_width = conv1d ? out.size (2 ) : out.size (3 );
187+ WORD32 batches = input.size (0 );
188+
189+ WORD32 x_stride = stride[1 ];
190+ WORD32 y_stride = stride[0 ];
191+ WORD32 x_padding = padding[1 ];
192+ WORD32 y_padding = padding[0 ];
193+ WORD32 dilation_width = dilation[1 ];
194+ WORD32 dilation_height = dilation[0 ];
195+
196+ WORD32 input_zero_bias = -in_zero_point;
197+ WORD32 kernel_zero_bias = -weight_zero_point;
198+
199+ WORD32 out_multiplier32[out_channels];
200+ WORD32 out_shift32[out_channels];
201+
202+ float out_scale = 1 . / output_scale;
203+
204+ for (int i = 0 ; i < out_channels; i++) {
205+ out_multiplier32[i] = bias_scale * out_scale * 2147483648 ;
206+ out_shift32[i] = 0 ;
207+ }
208+
209+ WORD32 out_zero_bias = output_zero_point;
210+ WORD32 inp_precision = 8 ;
211+ WORD32 kernel_precision = 8 ;
212+ pVOID p_scratch = nullptr ;
213+ WORD32 * ptr_scratch;
214+
215+ WORD32 scratch_size = 0 ;
216+
178217 if (input.scalar_type () == ScalarType::Char) {
179218 WORD8 * __restrict__ p_out =
180219 (WORD8 * __restrict__)out.mutable_data_ptr <int8_t >();
@@ -185,48 +224,6 @@ void xa_opt_quantized_conv2d_nchw(
185224 WORD32 * __restrict__ p_bias =
186225 (WORD32 * __restrict__)bias.const_data_ptr <int32_t >();
187226
188- WORD32 input_height = conv1d ? 1 : input.size (2 );
189- WORD32 input_width = conv1d ? input.size (2 ) : input.size (3 );
190- WORD32 input_channels = input.size (1 );
191- WORD32 kernel_height = conv1d ? 1 : weight.size (2 );
192- WORD32 kernel_width = conv1d ? weight.size (2 ) : weight.size (3 );
193- WORD32 kernel_channels = weight.size (1 );
194- WORD32 out_channels = weight.size (0 );
195- WORD32 out_height = conv1d ? 1 : out.size (2 );
196- WORD32 out_width = conv1d ? out.size (2 ) : out.size (3 );
197- WORD32 batches = input.size (0 );
198-
199- WORD32 x_stride = stride[1 ];
200- WORD32 y_stride = stride[0 ];
201- WORD32 x_padding = padding[1 ];
202- WORD32 y_padding = padding[0 ];
203- WORD32 dilation_width = dilation[1 ];
204- WORD32 dilation_height = dilation[0 ];
205-
206- // WORD32* kernel_bias_ptr =
207- // (WORD32*)weight_zero_point.const_data_ptr<int32_t>();
208-
209- WORD32 input_zero_bias = -in_zero_point;
210- WORD32 kernel_zero_bias = -weight_zero_point;
211-
212- WORD32 out_multiplier32[out_channels];
213- WORD32 out_shift32[out_channels];
214-
215- float out_scale = 1 . / output_scale;
216-
217- for (int i = 0 ; i < out_channels; i++) {
218- out_multiplier32[i] = bias_scale * out_scale * 2147483648 ;
219- out_shift32[i] = 0 ;
220- }
221-
222- WORD32 out_zero_bias = output_zero_point;
223- WORD32 inp_precision = 8 ;
224- WORD32 kernel_precision = 8 ;
225- pVOID p_scratch = nullptr ;
226- WORD32 * ptr_scratch;
227-
228- WORD32 scratch_size = 0 ;
229-
230227 if (groups == 1 ) {
231228 WORD32 out_data_format = 1 ;
232229
@@ -245,13 +242,13 @@ void xa_opt_quantized_conv2d_nchw(
245242 WORD8 * pkernel = (WORD8 *)ALIGN_PTR (ptr2, 8 );
246243
247244 WORD32 p_inp_shape[kNnlibMaxDim ];
248- p_inp_shape[0 ] = input. size ( 0 ) ;
245+ p_inp_shape[0 ] = batches ;
249246 p_inp_shape[1 ] = input_channels;
250247 p_inp_shape[2 ] = input_height;
251248 p_inp_shape[3 ] = input_width;
252249
253250 WORD32 p_out_shape[kNnlibMaxDim ];
254- p_out_shape[0 ] = input. size ( 0 ) ;
251+ p_out_shape[0 ] = batches ;
255252 p_out_shape[1 ] = input_height;
256253 p_out_shape[2 ] = input_width;
257254 p_out_shape[3 ] = input_channels;
@@ -439,6 +436,231 @@ void xa_opt_quantized_conv2d_nchw(
439436 return ;
440437 }
441438 }
439+
440+ if (input.scalar_type () == ScalarType::Byte) {
441+ UWORD8 * __restrict__ p_out =
442+ (UWORD8 * __restrict__)out.mutable_data_ptr <uint8_t >();
443+ UWORD8 * __restrict__ p_inp =
444+ (UWORD8 * __restrict__)input.const_data_ptr <uint8_t >();
445+ UWORD8 * __restrict__ p_kernel =
446+ (UWORD8 * __restrict__)weight.const_data_ptr <uint8_t >();
447+ WORD32 * __restrict__ p_bias =
448+ (WORD32 * __restrict__)bias.const_data_ptr <int32_t >();
449+
450+ WORD32 out_multiplier = out_multiplier32[0 ];
451+ WORD32 out_shift = out_shift32[0 ];
452+
453+ if (groups == 1 ) {
454+ WORD32 out_data_format = 1 ;
455+
456+ UWORD8 * ptr1 = (UWORD8 *)kernels::allocate_temp_memory (
457+ ctx,
458+ ((batches * input_channels * input_height * input_width) + 8 ) *
459+ sizeof (UWORD8 ));
460+
461+ UWORD8 * ptr2 = (UWORD8 *)kernels::allocate_temp_memory (
462+ ctx,
463+ ((out_channels * kernel_channels * kernel_height * kernel_width) +
464+ 8 ) *
465+ sizeof (UWORD8 ));
466+
467+ UWORD8 * pin = (UWORD8 *)ALIGN_PTR (ptr1, 8 );
468+ UWORD8 * pkernel = (UWORD8 *)ALIGN_PTR (ptr2, 8 );
469+
470+ WORD32 p_inp_shape[kNnlibMaxDim ];
471+ p_inp_shape[0 ] = batches;
472+ p_inp_shape[1 ] = input_channels;
473+ p_inp_shape[2 ] = input_height;
474+ p_inp_shape[3 ] = input_width;
475+
476+ WORD32 p_out_shape[kNnlibMaxDim ];
477+ p_out_shape[0 ] = batches;
478+ p_out_shape[1 ] = input_height;
479+ p_out_shape[2 ] = input_width;
480+ p_out_shape[3 ] = input_channels;
481+
482+ WORD32 p_permute_vec[kNnlibMaxDim ] = {0 , 2 , 3 , 1 };
483+
484+ xa_nn_transpose_8_8 (
485+ (WORD8 *)pin,
486+ p_out_shape,
487+ (WORD8 *)p_inp,
488+ p_inp_shape,
489+ p_permute_vec,
490+ kNnlibMaxDim ,
491+ kNnlibMaxDim );
492+
493+ WORD32 p_inp_shape1[kNnlibMaxDim ];
494+ p_inp_shape1[0 ] = out_channels;
495+ p_inp_shape1[1 ] = kernel_channels;
496+ p_inp_shape1[2 ] = kernel_height;
497+ p_inp_shape1[3 ] = kernel_width;
498+
499+ WORD32 p_out_shape1[kNnlibMaxDim ];
500+ p_out_shape1[0 ] = out_channels;
501+ p_out_shape1[1 ] = kernel_height;
502+ p_out_shape1[2 ] = kernel_width;
503+ p_out_shape1[3 ] = kernel_channels;
504+
505+ xa_nn_transpose_8_8 (
506+ (WORD8 *)pkernel,
507+ p_out_shape1,
508+ (WORD8 *)p_kernel,
509+ p_inp_shape1,
510+ p_permute_vec,
511+ kNnlibMaxDim ,
512+ kNnlibMaxDim );
513+
514+ scratch_size = xa_nn_conv2d_getsize (
515+ input_height,
516+ input_width,
517+ input_channels,
518+ kernel_height,
519+ kernel_width,
520+ kernel_channels,
521+ dilation_height,
522+ dilation_width,
523+ y_stride,
524+ y_padding,
525+ x_stride,
526+ x_padding,
527+ out_height,
528+ out_width,
529+ out_channels,
530+ inp_precision,
531+ kernel_precision,
532+ out_data_format);
533+
534+ scratch_size = scratch_size < 0 ? 0 : scratch_size;
535+
536+ ptr_scratch = (WORD32 *)kernels::allocate_temp_memory (ctx, scratch_size);
537+
538+ p_scratch = (pVOID)ALIGN_PTR (ptr_scratch, 8 );
539+
540+ for (int _n = 0 ; _n < batches; _n++) {
541+ UWORD8 * in_batch =
542+ pin + _n * input_channels * input_height * input_width;
543+ UWORD8 * out_batch = p_out + _n * out_channels * out_height * out_width;
544+
545+ xa_nn_conv2d_std_asym8uxasym8u (
546+ out_batch,
547+ in_batch,
548+ pkernel,
549+ p_bias,
550+ input_height,
551+ input_width,
552+ input_channels,
553+ kernel_height,
554+ kernel_width,
555+ out_channels,
556+ x_stride,
557+ y_stride,
558+ x_padding,
559+ y_padding,
560+ out_height,
561+ out_width,
562+ input_zero_bias,
563+ kernel_zero_bias,
564+ out_multiplier,
565+ out_shift,
566+ out_zero_bias,
567+ out_data_format,
568+ p_scratch);
569+ }
570+ return ;
571+ }
572+
573+ if (groups == input_channels) {
574+ WORD32 channels_multiplier = out_channels / input_channels;
575+
576+ scratch_size = xa_nn_conv2d_depthwise_getsize (
577+ input_height,
578+ input_width,
579+ input_channels,
580+ kernel_height,
581+ kernel_width,
582+ channels_multiplier,
583+ x_stride,
584+ y_stride,
585+ x_padding,
586+ y_padding,
587+ out_height,
588+ out_width,
589+ inp_precision,
590+ 1 ); // NCHW
591+
592+ scratch_size = scratch_size < 0 ? 0 : scratch_size;
593+
594+ ptr_scratch = (WORD32 *)kernels::allocate_temp_memory (ctx, scratch_size);
595+
596+ p_scratch = (pVOID)ALIGN_PTR (ptr_scratch, 8 );
597+
598+ UWORD8 * ptr1 = (UWORD8 *)kernels::allocate_temp_memory (
599+ ctx,
600+ ((batches * out_channels * out_height * out_width) + 8 ) *
601+ sizeof (UWORD8 ));
602+
603+ UWORD8 * p_out_temp = (UWORD8 *)ALIGN_PTR (ptr1, 8 );
604+
605+ for (int _n = 0 ; _n < batches; _n++) {
606+ UWORD8 * in_batch =
607+ p_inp + _n * input_channels * input_height * input_width;
608+ UWORD8 * out_batch =
609+ p_out_temp + _n * out_channels * out_height * out_width;
610+
611+ xa_nn_conv2d_depthwise_asym8uxasym8u (
612+ out_batch,
613+ p_kernel,
614+ in_batch,
615+ p_bias,
616+ input_height,
617+ input_width,
618+ input_channels,
619+ kernel_height,
620+ kernel_width,
621+ channels_multiplier,
622+ x_stride,
623+ y_stride,
624+ x_padding,
625+ y_padding,
626+ out_height,
627+ out_width,
628+ input_zero_bias,
629+ kernel_zero_bias,
630+ out_multiplier,
631+ out_shift,
632+ out_zero_bias,
633+ 1 , // NCHW
634+ 0 , // NHWC
635+ p_scratch);
636+ }
637+
638+ WORD32 p_inp_shape[kNnlibMaxDim ];
639+ p_inp_shape[0 ] = batches;
640+ p_inp_shape[1 ] = out_height;
641+ p_inp_shape[2 ] = out_width;
642+ p_inp_shape[3 ] = out_channels;
643+
644+ WORD32 p_out_shape[kNnlibMaxDim ];
645+ p_out_shape[0 ] = batches;
646+ p_out_shape[1 ] = out_channels;
647+ p_out_shape[2 ] = out_height;
648+ p_out_shape[3 ] = out_width;
649+
650+ WORD32 p_permute_vec[kNnlibMaxDim ] = {0 , 3 , 1 , 2 };
651+
652+ xa_nn_transpose_8_8 (
653+ (WORD8 *)p_out,
654+ p_out_shape,
655+ (WORD8 *)p_out_temp,
656+ p_inp_shape,
657+ p_permute_vec,
658+ kNnlibMaxDim ,
659+ kNnlibMaxDim );
660+
661+ return ;
662+ }
663+ }
442664}
443665
444666// The quantized convolution kernel. in_scale and weight_scale are implicit in
0 commit comments