Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/pythonapi/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ Techniques are enabled by calling methods on the ``mcdc.simulation`` singleton:
- ``mcdc.simulation.global_weight_roulette(weight_threshold=0.0, weight_target=1.0)``
- ``mcdc.simulation.population_control(active=True)``
- ``mcdc.simulation.weighted_emission(active=True, weight_target=1.0)``
- ``mcdc.simulation.weight_windows(weight_windows, mesh=None, energy=None)``
- ``mcdc.simulation.weight_windows(weight_windows, mesh=None, energy=None, time=None, mu=None, azimuthal=None)``

Running
-------
Expand Down
53 changes: 53 additions & 0 deletions mcdc/code_factory/numba_objects_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,29 @@ def generate_mcdc_access(targets):
text_setter += _accessor_4d_element(
object_name, attribute_name, shape[1], shape[2], shape[3], True
)
elif len(shape) == 7:
text_getter += _accessor_7d_element(
object_name,
attribute_name,
shape[1],
shape[2],
shape[3],
shape[4],
shape[5],
shape[6],
)

text_setter += _accessor_7d_element(
object_name,
attribute_name,
shape[1],
shape[2],
shape[3],
shape[4],
shape[5],
shape[6],
True,
)
text_getter += _accessor_chunk(object_name, attribute_name)
text_setter += _accessor_chunk(object_name, attribute_name, True)

Expand Down Expand Up @@ -1184,6 +1207,36 @@ def _accessor_4d_element(
return text


def _accessor_7d_element(
object_name,
attribute_name,
stride_2,
stride_3,
stride_4,
stride_5,
stride_6,
stride_7,
setter=False,
):
text = f"@njit\n"
if setter:
text += f"def {attribute_name}(index_1, index_2, index_3, index_4, index_5, index_6, index_7, {object_name}, data, value):\n"
else:
text += f"def {attribute_name}(index_1, index_2, index_3, index_4, index_5, index_6, index_7, {object_name}, data):\n"
text += f' offset = {object_name}["{attribute_name}_offset"]\n'
text += f' stride_2 = {object_name}["{stride_2}"]\n'
text += f' stride_3 = {object_name}["{stride_3}"]\n'
text += f' stride_4 = {object_name}["{stride_4}"]\n'
text += f' stride_5 = {object_name}["{stride_5}"]\n'
text += f' stride_6 = {object_name}["{stride_6}"]\n'
text += f' stride_7 = {object_name}["{stride_7}"]\n'
if setter:
text += f" data[offset + index_1 * stride_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_4 * stride_5 * stride_6 * stride_7 + index_5 * stride_6 * stride_7 + index_6 * stride_7 + index_7] = value\n\n\n"
else:
text += f" return data[offset + index_1 * stride_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_4 * stride_5 * stride_6 * stride_7 + index_5 * stride_6 * stride_7 + index_6 * stride_7 + index_7]\n\n\n"
return text


# ======================================================================================
# Misc.
# ======================================================================================
Expand Down
126 changes: 111 additions & 15 deletions mcdc/mcdc_get/weight_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,35 @@
from numba import njit


@njit
def time_bounds(index, weight_windows, data):
offset = weight_windows["time_bounds_offset"]
return data[offset + index]


@njit
def time_bounds_all(weight_windows, data):
start = weight_windows["time_bounds_offset"]
size = weight_windows["time_bounds_length"]
end = start + size
return data[start:end]


@njit
def time_bounds_last(weight_windows, data):
start = weight_windows["time_bounds_offset"]
size = weight_windows["time_bounds_length"]
end = start + size
return data[end - 1]


@njit
def time_bounds_chunk(start, length, weight_windows, data):
start += weight_windows["time_bounds_offset"]
end = start + length
return data[start:end]


@njit
def energy_bounds(index, weight_windows, data):
offset = weight_windows["energy_bounds_offset"]
Expand Down Expand Up @@ -33,12 +62,73 @@ def energy_bounds_chunk(start, length, weight_windows, data):


@njit
def lower_weights(index_1, index_2, index_3, index_4, weight_windows, data):
def mu_bounds(index, weight_windows, data):
offset = weight_windows["mu_bounds_offset"]
return data[offset + index]


@njit
def mu_bounds_all(weight_windows, data):
start = weight_windows["mu_bounds_offset"]
size = weight_windows["mu_bounds_length"]
end = start + size
return data[start:end]


@njit
def mu_bounds_last(weight_windows, data):
start = weight_windows["mu_bounds_offset"]
size = weight_windows["mu_bounds_length"]
end = start + size
return data[end - 1]


@njit
def mu_bounds_chunk(start, length, weight_windows, data):
start += weight_windows["mu_bounds_offset"]
end = start + length
return data[start:end]


@njit
def azi_bounds(index, weight_windows, data):
offset = weight_windows["azi_bounds_offset"]
return data[offset + index]


@njit
def azi_bounds_all(weight_windows, data):
start = weight_windows["azi_bounds_offset"]
size = weight_windows["azi_bounds_length"]
end = start + size
return data[start:end]


@njit
def azi_bounds_last(weight_windows, data):
start = weight_windows["azi_bounds_offset"]
size = weight_windows["azi_bounds_length"]
end = start + size
return data[end - 1]


@njit
def azi_bounds_chunk(start, length, weight_windows, data):
start += weight_windows["azi_bounds_offset"]
end = start + length
return data[start:end]


@njit
def lower_weights(index_1, index_2, index_3, index_4, index_5, index_6, index_7, weight_windows, data):
offset = weight_windows["lower_weights_offset"]
stride_2 = weight_windows["Nx"]
stride_3 = weight_windows["Ny"]
stride_4 = weight_windows["Nz"]
return data[offset + index_1 * stride_2 * stride_3 * stride_4 + index_2 * stride_3 * stride_4 + index_3 * stride_4 + index_4]
stride_2 = weight_windows["Ne"]
stride_3 = weight_windows["Nmu"]
stride_4 = weight_windows["Na"]
stride_5 = weight_windows["Nx"]
stride_6 = weight_windows["Ny"]
stride_7 = weight_windows["Nz"]
return data[offset + index_1 * stride_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_4 * stride_5 * stride_6 * stride_7 + index_5 * stride_6 * stride_7 + index_6 * stride_7 + index_7]


@njit
Expand All @@ -49,12 +139,15 @@ def lower_weights_chunk(start, length, weight_windows, data):


@njit
def target_weights(index_1, index_2, index_3, index_4, weight_windows, data):
def target_weights(index_1, index_2, index_3, index_4, index_5, index_6, index_7, weight_windows, data):
offset = weight_windows["target_weights_offset"]
stride_2 = weight_windows["Nx"]
stride_3 = weight_windows["Ny"]
stride_4 = weight_windows["Nz"]
return data[offset + index_1 * stride_2 * stride_3 * stride_4 + index_2 * stride_3 * stride_4 + index_3 * stride_4 + index_4]
stride_2 = weight_windows["Ne"]
stride_3 = weight_windows["Nmu"]
stride_4 = weight_windows["Na"]
stride_5 = weight_windows["Nx"]
stride_6 = weight_windows["Ny"]
stride_7 = weight_windows["Nz"]
return data[offset + index_1 * stride_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_4 * stride_5 * stride_6 * stride_7 + index_5 * stride_6 * stride_7 + index_6 * stride_7 + index_7]


@njit
Expand All @@ -65,12 +158,15 @@ def target_weights_chunk(start, length, weight_windows, data):


@njit
def upper_weights(index_1, index_2, index_3, index_4, weight_windows, data):
def upper_weights(index_1, index_2, index_3, index_4, index_5, index_6, index_7, weight_windows, data):
offset = weight_windows["upper_weights_offset"]
stride_2 = weight_windows["Nx"]
stride_3 = weight_windows["Ny"]
stride_4 = weight_windows["Nz"]
return data[offset + index_1 * stride_2 * stride_3 * stride_4 + index_2 * stride_3 * stride_4 + index_3 * stride_4 + index_4]
stride_2 = weight_windows["Ne"]
stride_3 = weight_windows["Nmu"]
stride_4 = weight_windows["Na"]
stride_5 = weight_windows["Nx"]
stride_6 = weight_windows["Ny"]
stride_7 = weight_windows["Nz"]
return data[offset + index_1 * stride_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_4 * stride_5 * stride_6 * stride_7 + index_5 * stride_6 * stride_7 + index_6 * stride_7 + index_7]


@njit
Expand Down
126 changes: 111 additions & 15 deletions mcdc/mcdc_set/weight_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,35 @@
from numba import njit


@njit
def time_bounds(index, weight_windows, data, value):
offset = weight_windows["time_bounds_offset"]
data[offset + index] = value


@njit
def time_bounds_all(weight_windows, data, value):
start = weight_windows["time_bounds_offset"]
size = weight_windows["time_bounds_length"]
end = start + size
data[start:end] = value


@njit
def time_bounds_last(weight_windows, data, value):
start = weight_windows["time_bounds_offset"]
size = weight_windows["time_bounds_length"]
end = start + size
data[end - 1] = value


@njit
def time_bounds_chunk(start, length, weight_windows, data, value):
start += weight_windows["time_bounds_offset"]
end = start + length
data[start:end] = value


@njit
def energy_bounds(index, weight_windows, data, value):
offset = weight_windows["energy_bounds_offset"]
Expand Down Expand Up @@ -33,12 +62,73 @@ def energy_bounds_chunk(start, length, weight_windows, data, value):


@njit
def lower_weights(index_1, index_2, index_3, index_4, weight_windows, data, value):
def mu_bounds(index, weight_windows, data, value):
offset = weight_windows["mu_bounds_offset"]
data[offset + index] = value


@njit
def mu_bounds_all(weight_windows, data, value):
start = weight_windows["mu_bounds_offset"]
size = weight_windows["mu_bounds_length"]
end = start + size
data[start:end] = value


@njit
def mu_bounds_last(weight_windows, data, value):
start = weight_windows["mu_bounds_offset"]
size = weight_windows["mu_bounds_length"]
end = start + size
data[end - 1] = value


@njit
def mu_bounds_chunk(start, length, weight_windows, data, value):
start += weight_windows["mu_bounds_offset"]
end = start + length
data[start:end] = value


@njit
def azi_bounds(index, weight_windows, data, value):
offset = weight_windows["azi_bounds_offset"]
data[offset + index] = value


@njit
def azi_bounds_all(weight_windows, data, value):
start = weight_windows["azi_bounds_offset"]
size = weight_windows["azi_bounds_length"]
end = start + size
data[start:end] = value


@njit
def azi_bounds_last(weight_windows, data, value):
start = weight_windows["azi_bounds_offset"]
size = weight_windows["azi_bounds_length"]
end = start + size
data[end - 1] = value


@njit
def azi_bounds_chunk(start, length, weight_windows, data, value):
start += weight_windows["azi_bounds_offset"]
end = start + length
data[start:end] = value


@njit
def lower_weights(index_1, index_2, index_3, index_4, index_5, index_6, index_7, weight_windows, data, value):
offset = weight_windows["lower_weights_offset"]
stride_2 = weight_windows["Nx"]
stride_3 = weight_windows["Ny"]
stride_4 = weight_windows["Nz"]
data[offset + index_1 * stride_2 * stride_3 * stride_4 + index_2 * stride_3 * stride_4 + index_3 * stride_4 + index_4] = value
stride_2 = weight_windows["Ne"]
stride_3 = weight_windows["Nmu"]
stride_4 = weight_windows["Na"]
stride_5 = weight_windows["Nx"]
stride_6 = weight_windows["Ny"]
stride_7 = weight_windows["Nz"]
data[offset + index_1 * stride_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_4 * stride_5 * stride_6 * stride_7 + index_5 * stride_6 * stride_7 + index_6 * stride_7 + index_7] = value


@njit
Expand All @@ -49,12 +139,15 @@ def lower_weights_chunk(start, length, weight_windows, data, value):


@njit
def target_weights(index_1, index_2, index_3, index_4, weight_windows, data, value):
def target_weights(index_1, index_2, index_3, index_4, index_5, index_6, index_7, weight_windows, data, value):
offset = weight_windows["target_weights_offset"]
stride_2 = weight_windows["Nx"]
stride_3 = weight_windows["Ny"]
stride_4 = weight_windows["Nz"]
data[offset + index_1 * stride_2 * stride_3 * stride_4 + index_2 * stride_3 * stride_4 + index_3 * stride_4 + index_4] = value
stride_2 = weight_windows["Ne"]
stride_3 = weight_windows["Nmu"]
stride_4 = weight_windows["Na"]
stride_5 = weight_windows["Nx"]
stride_6 = weight_windows["Ny"]
stride_7 = weight_windows["Nz"]
data[offset + index_1 * stride_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_4 * stride_5 * stride_6 * stride_7 + index_5 * stride_6 * stride_7 + index_6 * stride_7 + index_7] = value


@njit
Expand All @@ -65,12 +158,15 @@ def target_weights_chunk(start, length, weight_windows, data, value):


@njit
def upper_weights(index_1, index_2, index_3, index_4, weight_windows, data, value):
def upper_weights(index_1, index_2, index_3, index_4, index_5, index_6, index_7, weight_windows, data, value):
offset = weight_windows["upper_weights_offset"]
stride_2 = weight_windows["Nx"]
stride_3 = weight_windows["Ny"]
stride_4 = weight_windows["Nz"]
data[offset + index_1 * stride_2 * stride_3 * stride_4 + index_2 * stride_3 * stride_4 + index_3 * stride_4 + index_4] = value
stride_2 = weight_windows["Ne"]
stride_3 = weight_windows["Nmu"]
stride_4 = weight_windows["Na"]
stride_5 = weight_windows["Nx"]
stride_6 = weight_windows["Ny"]
stride_7 = weight_windows["Nz"]
data[offset + index_1 * stride_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_2 * stride_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_3 * stride_4 * stride_5 * stride_6 * stride_7 + index_4 * stride_5 * stride_6 * stride_7 + index_5 * stride_6 * stride_7 + index_6 * stride_7 + index_7] = value


@njit
Expand Down
Loading
Loading