diff --git a/src/cpu/x64/jit_brgemm_1x1_conv.cpp b/src/cpu/x64/jit_brgemm_1x1_conv.cpp index 9dd6adb5fdd..e24a69a2e98 100644 --- a/src/cpu/x64/jit_brgemm_1x1_conv.cpp +++ b/src/cpu/x64/jit_brgemm_1x1_conv.cpp @@ -158,7 +158,8 @@ status_t brgemm_1x1_convolution_fwd_t::pd_t::init(engine_t *engine) { brgemm_convolution_utils::set_amx_wsp_per_thread(jcp_); auto scratchpad = scratchpad_registry().registrar(); - brgemm_convolution_utils::init_scratchpad(scratchpad, jcp_); + CHECK(brgemm_convolution_utils::init_scratchpad( + scratchpad, jcp_, *src_md(), *weights_md(), *dst_md())); return status::success; } @@ -441,7 +442,7 @@ void brgemm_1x1_convolution_fwd_t::exec_ker( const bool is_ic_tail = jcp.is_reduced_rtus ? is_last_os : (icc == pd()->ic_chunks_ - 1 - && ((jcp.ic - ic) % jcp.ic_block != 0)); + && ((jcp.ic - ic) % jcp.ic_block != 0)); // Using blk_off to offset batch is motivated input\output striding aligment // See `blk_off` definition. @@ -498,9 +499,9 @@ void brgemm_1x1_convolution_fwd_t::exec_ker( for (int k = 0; k < n_ic_blocks; k++) { const size_t ic_off = jcp.is_reduced_rtus - ? (brgemm_is_ic_tail - ? jcp.ic_without_padding - jcp.rtus_ic_size - : 0) + ? (brgemm_is_ic_tail ? jcp.ic_without_padding + - jcp.rtus_ic_size + : 0) : (ic_block_s + k) * jcp.ic_block; const size_t src_ic = ic_off; const auto wei_ic = ic + ic_off; @@ -763,7 +764,7 @@ status_t brgemm_1x1_convolution_fwd_t::execute_forward_all( brgemm_batch_element_t *const brg_batch_global = (jcp.brg_type != brgemm_strd) ? scratchpad.template get( - key_brgemm_primitive_batch) + key_brgemm_primitive_batch) : nullptr; char *const c_buffer_global = (jcp.use_buffer) ? scratchpad.template get(key_brgemm_primitive_buffer) diff --git a/src/cpu/x64/jit_brgemm_conv.cpp b/src/cpu/x64/jit_brgemm_conv.cpp index 1f8622611ab..ef6dd033069 100644 --- a/src/cpu/x64/jit_brgemm_conv.cpp +++ b/src/cpu/x64/jit_brgemm_conv.cpp @@ -47,9 +47,9 @@ int brgemm_convolution_fwd_t::pd_t::get_brg_idx(int m, int kd_e, int kh_b, int kh_e) const { const auto brg_idx = jcp_.use_uker ? brg_indices.find({m, is_N_tail, is_K_tail, do_initialization, - kd_b, kd_e, kh_b, kh_e}) + kd_b, kd_e, kh_b, kh_e}) : brg_indices.find({m, is_N_tail, is_K_tail, do_initialization, 0, - jcp_.kd, 0, jcp_.kh}); + jcp_.kd, 0, jcp_.kh}); if (brg_idx == brg_indices.end()) return -1; return brg_idx->second; } @@ -676,7 +676,8 @@ status_t brgemm_convolution_fwd_t::pd_t::init(engine_t *engine) { brgemm_convolution_utils::set_amx_wsp_per_thread(jcp_); auto scratchpad = scratchpad_registry().registrar(); - brgemm_convolution_utils::init_scratchpad(scratchpad, jcp_); + CHECK(brgemm_convolution_utils::init_scratchpad( + scratchpad, jcp_, *src_md(), *weights_md(), *dst_md())); return status::success; } @@ -1358,7 +1359,7 @@ status_t brgemm_convolution_fwd_t::execute(const exec_ctx_t &ctx) const { = brgemm_convolution_utils::uses_batch_elements( jcp.brg_type, jcp.exec_type) ? scratchpad.template get( - key_brgemm_primitive_batch) + key_brgemm_primitive_batch) : nullptr; char *const __restrict c_buffer_global = (jcp.use_buffer) ? scratchpad.template get(key_brgemm_primitive_buffer) @@ -1372,12 +1373,12 @@ status_t brgemm_convolution_fwd_t::execute(const exec_ctx_t &ctx) const { : nullptr; int32_t *src_zp_comp_base = jcp.src_zero_point ? (jcp.req_cal_comp_pad ? scratchpad.template get( - key_brgemm_primitive_zp_comp_a) + key_brgemm_primitive_zp_comp_a) : zp_compensation) : nullptr; int32_t *s8s8_comp_base = jcp.s8s8_compensation_required ? (jcp.req_cal_comp_pad ? scratchpad.template get( - key_brgemm_primitive_buffer_comp) + key_brgemm_primitive_buffer_comp) : s8s8_compensation) : nullptr; @@ -1584,13 +1585,13 @@ status_t brgemm_convolution_fwd_t::cal_compensation( ? div_up(jcp.oc_block, inp_oc_block) : jcp.nb_oc; const auto wei_offs = is_relo_with_relo_weights - ? (jcp.is_relo_wi() - ? ((((g * nb_oc + wei_ocb) * KD) + kd_b) - * KH - + kh_b) - * KW * jcp.ic * inp_oc_block - : (((g * nb_oc + wei_ocb) * KH * KW) + kh_b) - * jcp.ic * inp_oc_block) + ? (jcp.is_relo_wi() ? ((((g * nb_oc + wei_ocb) * KD) + kd_b) + * KH + + kh_b) + * KW * jcp.ic * inp_oc_block + : (((g * nb_oc + wei_ocb) * KH * KW) + + kh_b) + * jcp.ic * inp_oc_block) : g * _pd->wei_g_stride + wei_ocb * _pd->wei_ocb_stride + kd_b * _pd->wei_kd_stride + kh_b * _pd->wei_kh_stride @@ -1674,21 +1675,21 @@ void brgemm_convolution_fwd_t::perform_outwork( p.apply_comp = has_postcomp; p.a_zp_compensation = has_postcomp && jcp.src_zero_point ? &btc.src_zp_comp_ptr[comp_ker_offs - + (ow_pw_s - ow) * comp_ow_sz] + + (ow_pw_s - ow) * comp_ow_sz] : btc.src_zp_comp_ptr; p.s8s8_compensation = has_postcomp && jcp.s8s8_compensation_required ? &btc.s8s8_comp_ptr[comp_ker_offs - + (ow_pw_s - ow) * comp_ow_sz] + + (ow_pw_s - ow) * comp_ow_sz] : btc.s8s8_comp_ptr; p.ptr_out = dst_base + dst_dsz * (btc.od * dst_h_sz + btc.oh * dst_w_sz + ow_pw_s * jcp.oc_without_padding); - p.ptr_in = static_cast( - jcp.use_buffer ? ( - btc.c_buffer + acc_dsz * (ow_pw_s - ow) * jcp.LDC) - : p.ptr_out); + p.ptr_in = static_cast(jcp.use_buffer + ? (btc.c_buffer + + acc_dsz * (ow_pw_s - ow) * jcp.LDC) + : p.ptr_out); } else { p.apply_comp = has_postcomp; char *const ptr_Cz = jcp.use_buffer @@ -1838,7 +1839,7 @@ void brgemm_convolution_fwd_t::maybe_conv_inp(brgemm_thread_ctx_t &btc, const auto icb = btc.icc * jcp.nb_ic_blocking; #define bmask(icb, odb, ohb, owb) \ - btc.inp_buffer_mask[(((icb)*jcp.nb_od + (odb)) * jcp.nb_oh + (ohb)) \ + btc.inp_buffer_mask[(((icb) * jcp.nb_od + (odb)) * jcp.nb_oh + (ohb)) \ * jcp.nb_ow \ + (owb)] @@ -2168,7 +2169,7 @@ void brgemm_convolution_fwd_t::ker_base(brgemm_thread_ctx_t &btc) const { if (ow_l > 0) { const size_t comp_ker_offs = do_postwork ? get_comp_offset(btc.g, btc.ocb, 0, ow_b, kd_s, kd_f, kh_s, - kh_f, 0, KW) + kh_f, 0, KW) : 0; if (nb_ic_b > 0) { @@ -2312,7 +2313,7 @@ void brgemm_convolution_fwd_t::ker_trans(brgemm_thread_ctx_t &btc) const { + src_dsz * ((jcp.copy_block_only ? 0 : ((icb + ic_block_s) - * _pd->pbuf_d_sz))) + * _pd->pbuf_d_sz))) + (jcp.is_relo_whi() ? src_dsz * btc.ohb * ((jcp.oh_block - 1) * _pd->pbuf_w_sz + jcp.stride_h @@ -2345,7 +2346,7 @@ void brgemm_convolution_fwd_t::ker_trans(brgemm_thread_ctx_t &btc) const { const auto comp_ker_offs = do_postwork ? get_comp_offset(btc.g, btc.ocb, btc.oh, ow_b, kd_s, kd_f, - comp_kh_s, comp_kh_f, 0, KW) + comp_kh_s, comp_kh_f, 0, KW) : 0; if (nb_ic_b > 0) { diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp index b2317596c0e..479fe926f9b 100644 --- a/src/cpu/x64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp @@ -893,7 +893,7 @@ float brg_blocking_t::est_eff() { const dim_t max_job = (loop_order == loop_ndhwgc) ? grid_coverage(thread_job, oc, ngroups, oc_block, sp, sp_block) : grid_coverage(thread_job, sp, static_cast(nb_od) * nb_oh, - sp_block, oc, oc_block); + sp_block, oc, oc_block); const dim_t sum_job = static_cast(mb) * od * oh * ow * ngroups * oc; const float job_eff = max_job == 0 @@ -1211,8 +1211,10 @@ void brg_blocking_t::iterate_ker_block(brg_blocking_t &best_brgb, int kd_block_, // WA coredump: ws:2, pl: 1, will result r_pad=-1 and incorrect post ops jit kernel. // case: onednn_verbose,exec,cpu,convolution,brgconv:avx512_core,forward_inference,src_f32::blocked:acb:f0 wei_f32:p:blocked:Acb16a:f0 bia_undef::undef::f0 dst_f32::blocked:acb:f0,attr-post-ops:sum:1:0:f32+eltwise_hardswish+binary_min:f32:0+binary_max:f32:0+binary_mul:f32:0+binary_add:f32:0+eltwise_round_half_to_even+binary_mul:f32:0+binary_add:f32:0 ,alg:convolution_direct,mb2_ic112oc6_iw7ow4kw1sw2dw0pw1,63700.4 if (exec_type == exec_base) - use_buffer = use_buffer || (maybe_use_buffer - && (iwp != iw || (l_pad + nstl::max(0, r_pad)) > 0)); + use_buffer = use_buffer + || (maybe_use_buffer + && (iwp != iw + || (l_pad + nstl::max(0, r_pad)) > 0)); const status_t st = estimate_brgemm_ur(); if (st != status::success) continue; @@ -1323,7 +1325,7 @@ float brg_blocking_t::est_eff_1x1() { const auto amx_fac = maskrcnn_cond ? (div_up(M + M_tail, 16) / (M_n_sp_blks + M_tail_n_sp_blks)) : (static_cast(div_up(M + M_tail, 16)) - / (M_n_sp_blks + M_tail_n_sp_blks)); + / (M_n_sp_blks + M_tail_n_sp_blks)); const auto brgemm_microkernel_eff = is_amx(isa) ? amx_fac * (static_cast(ocb_ave) * spb_ave) @@ -1380,15 +1382,15 @@ float brg_blocking_t::est_eff_1x1() { if (is_os_blocking) { max_job = (loop_order == loop_ndhwgc) ? grid_coverage(thread_job, oc, ngroups, oc_block, os, - static_cast(nb_os_blocking) * sp_block) + static_cast(nb_os_blocking) * sp_block) : grid_coverage(thread_job, os, 1, - static_cast(nb_os_blocking) * sp_block, oc, - oc_block); + static_cast(nb_os_blocking) * sp_block, oc, + oc_block); } else { max_job = (loop_order == loop_ndhwgc) ? grid_coverage(thread_job, oc, ngroups, oc_block, sp, sp_block) : grid_coverage(thread_job, sp, static_cast(od) * oh, - sp_block, oc, oc_block); + sp_block, oc, oc_block); } const dim_t sum_job = static_cast(mb) * od * oh * ow * ngroups * oc; @@ -1567,9 +1569,9 @@ void brg_blocking_t::calc_blocks_1x1() { const auto max_os_block_thr = (src_dsz * ic >= 1024 && src_dsz * ic < 4096) ? nstl::max(nstl::min(16, os), - div_up(os, div_up(nthr, mb * div_up(oc, oc_block)))) + div_up(os, div_up(nthr, mb * div_up(oc, oc_block)))) : nstl::max(div_up(2048, oc_block), - static_cast(div_up(mb * ngroups * os, nthr))); + static_cast(div_up(mb * ngroups * os, nthr))); const auto max_os_block_L2 = max_sp_block_L2; auto max_os_block_aliasing = 1000000 / nthr; @@ -2369,12 +2371,13 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, dim_t ds = jcp.copy_block_only ? (brg_blocking_t::get_inp_size(jcp.idp, jcp.od_block, jcp.kd, jcp.stride_d, jcp.dilate_d) - + nstl::max(0, jcp.f_pad) + nstl::max(0, jcp.back_pad)) + + nstl::max(0, jcp.f_pad) + + nstl::max(0, jcp.back_pad)) : jcp.idp; dim_t hs = jcp.copy_block_only ? (brg_blocking_t::get_inp_size(jcp.ihp, jcp.oh_block, jcp.kh, jcp.stride_h, jcp.dilate_h) - + nstl::max(0, jcp.t_pad) + nstl::max(0, jcp.b_pad)) + + nstl::max(0, jcp.t_pad) + nstl::max(0, jcp.b_pad)) : jcp.ihp; if (jcp.is_os_blocking) hs = div_up(rnd_up(hs * jcp.iwp, jcp.brgM), jcp.iwp) @@ -2707,8 +2710,13 @@ void set_amx_wsp_per_thread(jit_brgemm_conv_conf_t &jcp) { = utils::rnd_up(jcp.amx_buf_size_per_thread + 1, P4K); } -void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_brgemm_conv_conf_t &jcp) { +status_t init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_brgemm_conv_conf_t &jcp, const memory_desc_t &src_md, + const memory_desc_t &weights_md, const memory_desc_t &dst_md) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper dst_d(&dst_md); + if (uses_batch_elements(jcp.brg_type, jcp.exec_type)) { scratchpad.book(key_brgemm_primitive_batch, static_cast(jcp.nthr) * jcp.adjusted_batch_size, @@ -2767,6 +2775,20 @@ void init_scratchpad(memory_tracking::registrar_t &scratchpad, scratchpad.book(key_conv_dst_scales, static_cast(jcp.nthr) * sizeof(float), P4K); } + + // Check scratchpad size to avoid allocating huge buffers + if (jcp.exec_type == exec_trans) { + constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32 + << 30; // 32Gb - TODO: may it's too large? + const size_t scratchpad_limit_by_tensor_sizes = (size_t)64 * jcp.nthr + * (src_d.size() + weights_d.size() + dst_d.size()); + const size_t scratchpad_limit + = nstl::min(scratchpad_limit_by_absolute_value, + scratchpad_limit_by_tensor_sizes); + if (scratchpad.size() > scratchpad_limit) return status::unimplemented; + } + + return status::success; } void balance_bwd_w(jit_brgemm_conv_conf_t &jcp) { diff --git a/src/cpu/x64/jit_brgemm_conv_utils.hpp b/src/cpu/x64/jit_brgemm_conv_utils.hpp index d4cdc380dc1..ac0f264edda 100644 --- a/src/cpu/x64/jit_brgemm_conv_utils.hpp +++ b/src/cpu/x64/jit_brgemm_conv_utils.hpp @@ -53,8 +53,9 @@ status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, void set_amx_wsp_per_thread(jit_brgemm_conv_conf_t &jcp); -void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_brgemm_conv_conf_t &jcp); +status_t init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_brgemm_conv_conf_t &jcp, const memory_desc_t &src_md, + const memory_desc_t &weights_md, const memory_desc_t &dst_md); status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp, const convolution_desc_t &cd, memory_desc_t &src_md,