Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 66 additions & 43 deletions src/nutils/units/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __hash__(self):
@__table.register("numpy.positive")
@__table.register("numpy.ptp")
@__table.register("numpy.real")
@__table.register("numpy.repeat")
@__table.register("numpy.reshape")
@__table.register("numpy.sum")
@__table.register("numpy.take")
Expand All @@ -113,9 +114,10 @@ def __hash__(self):
@__table.register("_operator.neg")
@__table.register("_operator.pos")
def __unary(op, *args, **kwargs):
cls, (arg0, dim0) = _unwrap(args[0])
cls = _get_monomial_class(args[0])
arg0, dim0 = unwrap(args[0])
val = op(arg0, *args[1:], **kwargs)
return cls(val, dim0)
return wrap(cls, val, dim0)

@__table.register("numpy.add")
@__table.register("numpy.hypot")
Expand All @@ -126,20 +128,24 @@ def __unary(op, *args, **kwargs):
@__table.register("_operator.mod")
@__table.register("_operator.sub")
def __add_like(op, *args, **kwargs):
cls, (arg0, dim0), (arg1, dim1) = _unwrap(args[0], args[1])
cls = _get_monomial_class(args[0], args[1])
arg0, dim0 = unwrap(args[0])
arg1, dim1 = unwrap(args[1])
if dim0 != dim1:
raise DimensionError(
f"incompatible dimensions for {op.__name__}: {dim0}, {dim1}"
)
val = op(arg0, arg1, *args[2:], **kwargs)
return cls(val, dim0)
return wrap(cls, val, dim0)

@__table.register("numpy.matmul")
@__table.register("numpy.multiply")
@__table.register("_operator.matmul")
@__table.register("_operator.mul")
def __mul_like(op, *args, **kwargs):
cls, (arg0, dim0), (arg1, dim1) = _unwrap(args[0], args[1])
cls = _get_monomial_class(args[0], args[1])
arg0, dim0 = unwrap(args[0])
arg1, dim1 = unwrap(args[1])
val = op(arg0, arg1, *args[2:], **kwargs)
return wrap(cls, val, dim0 + dim1)

Expand All @@ -150,35 +156,43 @@ def __mul_like(op, *args, **kwargs):
@__table.register("numpy.divide")
@__table.register("_operator.truediv")
def __div_like(op, *args, **kwargs):
cls, (arg0, dim0), (arg1, dim1) = _unwrap(args[0], args[1])
cls = _get_monomial_class(args[0], args[1])
arg0, dim0 = unwrap(args[0])
arg1, dim1 = unwrap(args[1])
val = op(arg0, arg1, *args[2:], **kwargs)
return wrap(cls, val, dim0 - dim1)

@__table.register("nutils.function.laplace")
def __laplace(op, *args, **kwargs):
cls, (arg0, dim0), (arg1, dim1) = _unwrap(args[0], args[1])
cls = _get_monomial_class(args[0], args[1])
arg0, dim0 = unwrap(args[0])
arg1, dim1 = unwrap(args[1])
val = op(arg0, arg1, *args[2:], **kwargs)
return wrap(cls, val, dim0 - dim1 * 2)

@__table.register("numpy.sqrt")
def __sqrt(op, *args, **kwargs):
cls, (arg0, dim0) = _unwrap(args[0])
cls = _get_monomial_class(args[0])
arg0, dim0 = unwrap(args[0])
val = op(arg0, *args[1:], **kwargs)
return cls(val, dim0 / 2)
return wrap(cls, val, dim0 / 2)

@__table.register("_operator.setitem")
def __setitem(op, *args, **kwargs):
cls, (arg0, dim0), (arg2, dim2) = _unwrap(args[0], args[2])
cls = _get_monomial_class(args[0], args[2])
arg0, dim0 = unwrap(args[0])
arg2, dim2 = unwrap(args[2])
if dim0 != dim2:
raise DimensionError(f"cannot assign {dim2} to {dim0}")
val = op(arg0, args[1], arg2, *args[3:], **kwargs)
return cls(val, dim0)
return wrap(cls, val, dim0)

@__table.register("nutils.function.jacobian")
@__table.register("numpy.power")
@__table.register("_operator.pow")
def __pow_like(op, *args, **kwargs):
cls, (arg0, dim0) = _unwrap(args[0])
cls = _get_monomial_class(args[0])
arg0, dim0 = unwrap(args[0])
val = op(arg0, *args[1:], **kwargs)
return wrap(cls, val, dim0 * args[1])

Expand Down Expand Up @@ -207,7 +221,8 @@ def __unary_op(op, *args, **kwargs):
@__table.register("_operator.lt")
@__table.register("_operator.ne")
def __binary_op(op, *args, **kwargs):
_, (arg0, dim0), (arg1, dim1) = _unwrap(args[0], args[1])
arg0, dim0 = unwrap(args[0])
arg1, dim1 = unwrap(args[1])
if dim0 != dim1 and not (isinstance(args[1], (int, float)) and args[1] == 0):
raise DimensionError(
f"incompatible dimensions for {op.__name__}: {dim0}, {dim1}"
Expand All @@ -217,50 +232,57 @@ def __binary_op(op, *args, **kwargs):
@__table.register("numpy.stack")
@__table.register("numpy.concatenate")
def __stack_like(op, *args, **kwargs):
cls, arg0, dims = _unwrap_many(args[0])
cls = _get_monomial_class(*args[0])
arg0, dims = zip(*map(unwrap, args[0]))
dim = dims[0]
if any(d != dim for d in dims[1:]):
raise DimensionError(
f"incompatible dimensions for {op.__name__}: "
+ ", ".join(map(str, dims))
)
val = op(arg0, *args[1:], **kwargs)
return cls(val, dim)
return wrap(cls, val, dim)

@__table.register("nutils.function.curvature")
def __evaluate(op, *args, **kwargs):
cls, (arg0, dim0) = _unwrap(args[0])
def __curvature(op, *args, **kwargs):
cls = _get_monomial_class(*args[0])
arg0, dim0 = unwrap(args[0])
val = op(*args, **kwargs)
return cls(val, -dim0)
return wrap(cls, val, -dim0)

@__table.register("nutils.function.evaluate")
def __evaluate(op, *args, **kwargs):
cls, args, dims = _unwrap_many(args)
cls = _get_monomial_class(*args)
args, dims = zip(*map(unwrap, args))
vals = op(*args, **kwargs)
return tuple(map(cls, vals, dims))
return tuple(wrap(cls, val, dim) for val, dim in zip(vals, dims))

@__table.register("nutils.function.field")
def __field(op, *args, **kwargs):
cls, args, dims = _unwrap_many(args)
cls = _get_monomial_class(*args)
args, dims = zip(*map(unwrap, args))
val = op(*args, **kwargs)
dim = reduce(operator.add, dims)
# we abuse the fact that unpack str returns dimensionless
return cls(val, dim)
return wrap(cls, val, dim)

@__table.register("nutils.function.arguments_for")
def __attribute(op, *args, **kwargs):
_, args, _ = _unwrap_many(args)
args = [unwrap(arg)[0] for arg in args]
return op(*args, **kwargs)

@__table.register("numpy.interp")
def __interp(op, x, xp, fp, *args, **kwargs):
cls, (x, dimx), (xp, dimxp), (fp, dimfp) = _unwrap(x, xp, fp)
cls = _get_monomial_class(x, xp, fp)
x, dimx = unwrap(x)
xp, dimxp = unwrap(xp)
fp, dimfp = unwrap(fp)
if dimx != dimxp:
raise DimensionError(
f"incompatible dimensions for {op.__name__}: {dimx}, {dimxp}"
)
val = op(x, xp, fp, *args, **kwargs)
return cls(val, dimfp)
return wrap(cls, val, dimfp)

@__table.register("nutils.topology.Topology.locate")
def __locate(
Expand All @@ -273,13 +295,10 @@ def __locate(
maxdist=None,
**kwargs,
):
(
cls,
(geom, dimgeom),
(coords, dimcoords),
(tol, dimtol),
(maxdist, dimmaxdist),
) = _unwrap(geom, coords, tol, maxdist)
geom, dimgeom = unwrap(geom)
coords, dimcoords = unwrap(coords)
tol, dimtol = unwrap(tol)
maxdist, dimmaxdist = unwrap(maxdist)
if dimgeom != dimcoords:
raise DimensionError(
f"incompatible dimensions for locate: {dimgeom}, {dimcoords}"
Expand All @@ -304,16 +323,18 @@ def __locate(
@__table.register("nutils.sample.Sample.bind")
@__table.register("nutils.sample.Sample.integral")
def __sample(op, sample, func):
cls, (func, dim) = _unwrap(func)
cls = _get_monomial_class(func)
func, dim = unwrap(func)
val = op(sample, func)
return cls(val, dim)
return wrap(cls, val, dim)

@__table.register("numpy.linalg.det")
def __det(op, arg):
cls, (arg, dim) = _unwrap(arg)
cls = _get_monomial_class(arg)
arg, dim = unwrap(arg)
val = op(arg)
dim = dim * arg.ndim
return cls(val, dim)
return wrap(cls, val, dim)

## DEFINE OPERATORS

Expand Down Expand Up @@ -357,7 +378,14 @@ def unwrap(obj):
return obj.unwrap() if isinstance(obj, Monomial) else (obj, Powers({}))


def _unwrap(*args):
def _get_monomial_class(*args):
'''Return common Monomial base class.

This helper function returns the highest subclass of which all monomial
arguments are an instance. Concretely, if one argument is a Monomial and
the other a UMonomial, then Monomial is returned. If all are UMonomial then
the return value is UMonomial.'''

types = {type(arg) for arg in args if isinstance(arg, Monomial)}
bases = _collect_bases(types.pop())
for cls in types:
Expand All @@ -367,12 +395,7 @@ def _unwrap(*args):
assert bases[0] is Monomial
while bases[-1].__init__ != Monomial.__init__:
bases = bases[:-1]
return bases[-1], *map(unwrap, args)


def _unwrap_many(args):
cls, *args_dims = _unwrap(*args)
return cls, *zip(*args_dims)
return bases[-1]


def _collect_bases(cls, sentinel=object):
Expand Down
Loading