Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/cpu/x64/jit_brgemm_1x1_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ status_t brgemm_1x1_convolution_fwd_t<isa>::pd_t::init(engine_t *engine) {

brgemm_convolution_utils::set_amx_wsp_per_thread(jcp_);
auto scratchpad = scratchpad_registry().registrar();
brgemm_convolution_utils::init_scratchpad(scratchpad, jcp_);
CHECK(brgemm_convolution_utils::init_scratchpad(
scratchpad, jcp_, *src_md(), *weights_md(), *dst_md()));

return status::success;
}
Expand Down Expand Up @@ -441,7 +442,7 @@ void brgemm_1x1_convolution_fwd_t<isa>::exec_ker(
const bool is_ic_tail = jcp.is_reduced_rtus
? is_last_os
: (icc == pd()->ic_chunks_ - 1
&& ((jcp.ic - ic) % jcp.ic_block != 0));
&& ((jcp.ic - ic) % jcp.ic_block != 0));

// Using blk_off to offset the batch is motivated by input/output striding alignment.
// See the `blk_off` definition.
Expand Down Expand Up @@ -498,9 +499,9 @@ void brgemm_1x1_convolution_fwd_t<isa>::exec_ker(

for (int k = 0; k < n_ic_blocks; k++) {
const size_t ic_off = jcp.is_reduced_rtus
? (brgemm_is_ic_tail
? jcp.ic_without_padding - jcp.rtus_ic_size
: 0)
? (brgemm_is_ic_tail ? jcp.ic_without_padding
- jcp.rtus_ic_size
: 0)
: (ic_block_s + k) * jcp.ic_block;
const size_t src_ic = ic_off;
const auto wei_ic = ic + ic_off;
Expand Down Expand Up @@ -763,7 +764,7 @@ status_t brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(
brgemm_batch_element_t *const brg_batch_global
= (jcp.brg_type != brgemm_strd)
? scratchpad.template get<brgemm_batch_element_t>(
key_brgemm_primitive_batch)
key_brgemm_primitive_batch)
: nullptr;
char *const c_buffer_global = (jcp.use_buffer)
? scratchpad.template get<char>(key_brgemm_primitive_buffer)
Expand Down
47 changes: 24 additions & 23 deletions src/cpu/x64/jit_brgemm_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ int brgemm_convolution_fwd_t<isa>::pd_t::get_brg_idx(int m,
int kd_e, int kh_b, int kh_e) const {
const auto brg_idx = jcp_.use_uker
? brg_indices.find({m, is_N_tail, is_K_tail, do_initialization,
kd_b, kd_e, kh_b, kh_e})
kd_b, kd_e, kh_b, kh_e})
: brg_indices.find({m, is_N_tail, is_K_tail, do_initialization, 0,
jcp_.kd, 0, jcp_.kh});
jcp_.kd, 0, jcp_.kh});
if (brg_idx == brg_indices.end()) return -1;
return brg_idx->second;
}
Expand Down Expand Up @@ -676,7 +676,8 @@ status_t brgemm_convolution_fwd_t<isa>::pd_t::init(engine_t *engine) {

brgemm_convolution_utils::set_amx_wsp_per_thread(jcp_);
auto scratchpad = scratchpad_registry().registrar();
brgemm_convolution_utils::init_scratchpad(scratchpad, jcp_);
CHECK(brgemm_convolution_utils::init_scratchpad(
scratchpad, jcp_, *src_md(), *weights_md(), *dst_md()));

return status::success;
}
Expand Down Expand Up @@ -1358,7 +1359,7 @@ status_t brgemm_convolution_fwd_t<isa>::execute(const exec_ctx_t &ctx) const {
= brgemm_convolution_utils::uses_batch_elements(
jcp.brg_type, jcp.exec_type)
? scratchpad.template get<brgemm_batch_element_t>(
key_brgemm_primitive_batch)
key_brgemm_primitive_batch)
: nullptr;
char *const __restrict c_buffer_global = (jcp.use_buffer)
? scratchpad.template get<char>(key_brgemm_primitive_buffer)
Expand All @@ -1372,12 +1373,12 @@ status_t brgemm_convolution_fwd_t<isa>::execute(const exec_ctx_t &ctx) const {
: nullptr;
int32_t *src_zp_comp_base = jcp.src_zero_point
? (jcp.req_cal_comp_pad ? scratchpad.template get<int32_t>(
key_brgemm_primitive_zp_comp_a)
key_brgemm_primitive_zp_comp_a)
: zp_compensation)
: nullptr;
int32_t *s8s8_comp_base = jcp.s8s8_compensation_required
? (jcp.req_cal_comp_pad ? scratchpad.template get<int32_t>(
key_brgemm_primitive_buffer_comp)
key_brgemm_primitive_buffer_comp)
: s8s8_compensation)
: nullptr;

Expand Down Expand Up @@ -1584,13 +1585,13 @@ status_t brgemm_convolution_fwd_t<isa>::cal_compensation(
? div_up(jcp.oc_block, inp_oc_block)
: jcp.nb_oc;
const auto wei_offs = is_relo_with_relo_weights
? (jcp.is_relo_wi()
? ((((g * nb_oc + wei_ocb) * KD) + kd_b)
* KH
+ kh_b)
* KW * jcp.ic * inp_oc_block
: (((g * nb_oc + wei_ocb) * KH * KW) + kh_b)
* jcp.ic * inp_oc_block)
? (jcp.is_relo_wi() ? ((((g * nb_oc + wei_ocb) * KD) + kd_b)
* KH
+ kh_b)
* KW * jcp.ic * inp_oc_block
: (((g * nb_oc + wei_ocb) * KH * KW)
+ kh_b)
* jcp.ic * inp_oc_block)
: g * _pd->wei_g_stride + wei_ocb * _pd->wei_ocb_stride
+ kd_b * _pd->wei_kd_stride
+ kh_b * _pd->wei_kh_stride
Expand Down Expand Up @@ -1674,21 +1675,21 @@ void brgemm_convolution_fwd_t<isa>::perform_outwork(
p.apply_comp = has_postcomp;
p.a_zp_compensation = has_postcomp && jcp.src_zero_point
? &btc.src_zp_comp_ptr[comp_ker_offs
+ (ow_pw_s - ow) * comp_ow_sz]
+ (ow_pw_s - ow) * comp_ow_sz]
: btc.src_zp_comp_ptr;
p.s8s8_compensation = has_postcomp && jcp.s8s8_compensation_required
? &btc.s8s8_comp_ptr[comp_ker_offs
+ (ow_pw_s - ow) * comp_ow_sz]
+ (ow_pw_s - ow) * comp_ow_sz]
: btc.s8s8_comp_ptr;

p.ptr_out = dst_base
+ dst_dsz
* (btc.od * dst_h_sz + btc.oh * dst_w_sz
+ ow_pw_s * jcp.oc_without_padding);
p.ptr_in = static_cast<void *>(
jcp.use_buffer ? (
btc.c_buffer + acc_dsz * (ow_pw_s - ow) * jcp.LDC)
: p.ptr_out);
p.ptr_in = static_cast<void *>(jcp.use_buffer
? (btc.c_buffer
+ acc_dsz * (ow_pw_s - ow) * jcp.LDC)
: p.ptr_out);
} else {
p.apply_comp = has_postcomp;
char *const ptr_Cz = jcp.use_buffer
Expand Down Expand Up @@ -1838,7 +1839,7 @@ void brgemm_convolution_fwd_t<isa>::maybe_conv_inp(brgemm_thread_ctx_t &btc,
const auto icb = btc.icc * jcp.nb_ic_blocking;

#define bmask(icb, odb, ohb, owb) \
btc.inp_buffer_mask[(((icb)*jcp.nb_od + (odb)) * jcp.nb_oh + (ohb)) \
btc.inp_buffer_mask[(((icb) * jcp.nb_od + (odb)) * jcp.nb_oh + (ohb)) \
* jcp.nb_ow \
+ (owb)]

Expand Down Expand Up @@ -2168,7 +2169,7 @@ void brgemm_convolution_fwd_t<isa>::ker_base(brgemm_thread_ctx_t &btc) const {
if (ow_l > 0) {
const size_t comp_ker_offs = do_postwork
? get_comp_offset(btc.g, btc.ocb, 0, ow_b, kd_s, kd_f, kh_s,
kh_f, 0, KW)
kh_f, 0, KW)
: 0;

if (nb_ic_b > 0) {
Expand Down Expand Up @@ -2312,7 +2313,7 @@ void brgemm_convolution_fwd_t<isa>::ker_trans(brgemm_thread_ctx_t &btc) const {
+ src_dsz
* ((jcp.copy_block_only ? 0
: ((icb + ic_block_s)
* _pd->pbuf_d_sz)))
* _pd->pbuf_d_sz)))
+ (jcp.is_relo_whi() ? src_dsz * btc.ohb
* ((jcp.oh_block - 1) * _pd->pbuf_w_sz
+ jcp.stride_h
Expand Down Expand Up @@ -2345,7 +2346,7 @@ void brgemm_convolution_fwd_t<isa>::ker_trans(brgemm_thread_ctx_t &btc) const {

const auto comp_ker_offs = do_postwork
? get_comp_offset(btc.g, btc.ocb, btc.oh, ow_b, kd_s, kd_f,
comp_kh_s, comp_kh_f, 0, KW)
comp_kh_s, comp_kh_f, 0, KW)
: 0;

if (nb_ic_b > 0) {
Expand Down
50 changes: 36 additions & 14 deletions src/cpu/x64/jit_brgemm_conv_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ float brg_blocking_t::est_eff() {
const dim_t max_job = (loop_order == loop_ndhwgc)
? grid_coverage(thread_job, oc, ngroups, oc_block, sp, sp_block)
: grid_coverage(thread_job, sp, static_cast<dim_t>(nb_od) * nb_oh,
sp_block, oc, oc_block);
sp_block, oc, oc_block);
const dim_t sum_job = static_cast<dim_t>(mb) * od * oh * ow * ngroups * oc;

const float job_eff = max_job == 0
Expand Down Expand Up @@ -1211,8 +1211,10 @@ void brg_blocking_t::iterate_ker_block(brg_blocking_t &best_brgb, int kd_block_,
// WA for a coredump: ws:2, pl: 1 will result in r_pad=-1 and an incorrect post-ops jit kernel.
// case: onednn_verbose,exec,cpu,convolution,brgconv:avx512_core,forward_inference,src_f32::blocked:acb:f0 wei_f32:p:blocked:Acb16a:f0 bia_undef::undef::f0 dst_f32::blocked:acb:f0,attr-post-ops:sum:1:0:f32+eltwise_hardswish+binary_min:f32:0+binary_max:f32:0+binary_mul:f32:0+binary_add:f32:0+eltwise_round_half_to_even+binary_mul:f32:0+binary_add:f32:0 ,alg:convolution_direct,mb2_ic112oc6_iw7ow4kw1sw2dw0pw1,63700.4
if (exec_type == exec_base)
use_buffer = use_buffer || (maybe_use_buffer
&& (iwp != iw || (l_pad + nstl::max(0, r_pad)) > 0));
use_buffer = use_buffer
|| (maybe_use_buffer
&& (iwp != iw
|| (l_pad + nstl::max(0, r_pad)) > 0));

const status_t st = estimate_brgemm_ur();
if (st != status::success) continue;
Expand Down Expand Up @@ -1323,7 +1325,7 @@ float brg_blocking_t::est_eff_1x1() {
const auto amx_fac = maskrcnn_cond
? (div_up(M + M_tail, 16) / (M_n_sp_blks + M_tail_n_sp_blks))
: (static_cast<float>(div_up(M + M_tail, 16))
/ (M_n_sp_blks + M_tail_n_sp_blks));
/ (M_n_sp_blks + M_tail_n_sp_blks));

const auto brgemm_microkernel_eff = is_amx(isa)
? amx_fac * (static_cast<float>(ocb_ave) * spb_ave)
Expand Down Expand Up @@ -1380,15 +1382,15 @@ float brg_blocking_t::est_eff_1x1() {
if (is_os_blocking) {
max_job = (loop_order == loop_ndhwgc)
? grid_coverage(thread_job, oc, ngroups, oc_block, os,
static_cast<dim_t>(nb_os_blocking) * sp_block)
static_cast<dim_t>(nb_os_blocking) * sp_block)
: grid_coverage(thread_job, os, 1,
static_cast<dim_t>(nb_os_blocking) * sp_block, oc,
oc_block);
static_cast<dim_t>(nb_os_blocking) * sp_block, oc,
oc_block);
} else {
max_job = (loop_order == loop_ndhwgc)
? grid_coverage(thread_job, oc, ngroups, oc_block, sp, sp_block)
: grid_coverage(thread_job, sp, static_cast<dim_t>(od) * oh,
sp_block, oc, oc_block);
sp_block, oc, oc_block);
}

const dim_t sum_job = static_cast<dim_t>(mb) * od * oh * ow * ngroups * oc;
Expand Down Expand Up @@ -1567,9 +1569,9 @@ void brg_blocking_t::calc_blocks_1x1() {
const auto max_os_block_thr
= (src_dsz * ic >= 1024 && src_dsz * ic < 4096)
? nstl::max(nstl::min(16, os),
div_up(os, div_up(nthr, mb * div_up(oc, oc_block))))
div_up(os, div_up(nthr, mb * div_up(oc, oc_block))))
: nstl::max(div_up(2048, oc_block),
static_cast<int>(div_up(mb * ngroups * os, nthr)));
static_cast<int>(div_up(mb * ngroups * os, nthr)));
const auto max_os_block_L2 = max_sp_block_L2;

auto max_os_block_aliasing = 1000000 / nthr;
Expand Down Expand Up @@ -2369,12 +2371,13 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
dim_t ds = jcp.copy_block_only
? (brg_blocking_t::get_inp_size(jcp.idp, jcp.od_block, jcp.kd,
jcp.stride_d, jcp.dilate_d)
+ nstl::max(0, jcp.f_pad) + nstl::max(0, jcp.back_pad))
+ nstl::max(0, jcp.f_pad)
+ nstl::max(0, jcp.back_pad))
: jcp.idp;
dim_t hs = jcp.copy_block_only
? (brg_blocking_t::get_inp_size(jcp.ihp, jcp.oh_block, jcp.kh,
jcp.stride_h, jcp.dilate_h)
+ nstl::max(0, jcp.t_pad) + nstl::max(0, jcp.b_pad))
+ nstl::max(0, jcp.t_pad) + nstl::max(0, jcp.b_pad))
: jcp.ihp;
if (jcp.is_os_blocking)
hs = div_up(rnd_up(hs * jcp.iwp, jcp.brgM), jcp.iwp)
Expand Down Expand Up @@ -2707,8 +2710,13 @@ void set_amx_wsp_per_thread(jit_brgemm_conv_conf_t &jcp) {
= utils::rnd_up(jcp.amx_buf_size_per_thread + 1, P4K);
}

void init_scratchpad(memory_tracking::registrar_t &scratchpad,
const jit_brgemm_conv_conf_t &jcp) {
status_t init_scratchpad(memory_tracking::registrar_t &scratchpad,
const jit_brgemm_conv_conf_t &jcp, const memory_desc_t &src_md,
const memory_desc_t &weights_md, const memory_desc_t &dst_md) {
const memory_desc_wrapper src_d(&src_md);
const memory_desc_wrapper weights_d(&weights_md);
const memory_desc_wrapper dst_d(&dst_md);

if (uses_batch_elements(jcp.brg_type, jcp.exec_type)) {
scratchpad.book(key_brgemm_primitive_batch,
static_cast<size_t>(jcp.nthr) * jcp.adjusted_batch_size,
Expand Down Expand Up @@ -2767,6 +2775,20 @@ void init_scratchpad(memory_tracking::registrar_t &scratchpad,
scratchpad.book(key_conv_dst_scales,
static_cast<size_t>(jcp.nthr) * sizeof(float), P4K);
}

// Check scratchpad size to avoid allocating huge buffers
if (jcp.exec_type == exec_trans) {
constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32
<< 30; // 32 GiB (32 * 2^30 bytes) - TODO: maybe it's too large?
const size_t scratchpad_limit_by_tensor_sizes = (size_t)64 * jcp.nthr
* (src_d.size() + weights_d.size() + dst_d.size());
const size_t scratchpad_limit
= nstl::min(scratchpad_limit_by_absolute_value,
scratchpad_limit_by_tensor_sizes);
if (scratchpad.size() > scratchpad_limit) return status::unimplemented;
}

return status::success;
}

void balance_bwd_w(jit_brgemm_conv_conf_t &jcp) {
Expand Down
5 changes: 3 additions & 2 deletions src/cpu/x64/jit_brgemm_conv_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,

void set_amx_wsp_per_thread(jit_brgemm_conv_conf_t &jcp);

void init_scratchpad(memory_tracking::registrar_t &scratchpad,
const jit_brgemm_conv_conf_t &jcp);
// Books the scratchpad buffers needed by the brgemm convolution described by
// `jcp` in `scratchpad`. For `exec_trans` execution the total booked size is
// then compared with a limit: the smaller of an absolute cap and a value
// proportional to `jcp.nthr` times the combined sizes of `src_md`,
// `weights_md` and `dst_md`.
// Returns status::unimplemented when that limit is exceeded and
// status::success otherwise, so callers are expected to check the result.
status_t init_scratchpad(memory_tracking::registrar_t &scratchpad,
const jit_brgemm_conv_conf_t &jcp, const memory_desc_t &src_md,
const memory_desc_t &weights_md, const memory_desc_t &dst_md);

status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp,
const convolution_desc_t &cd, memory_desc_t &src_md,
Expand Down
Loading