diff --git a/src/nutils/units/core.py b/src/nutils/units/core.py index 8502aa7..6c04e12 100644 --- a/src/nutils/units/core.py +++ b/src/nutils/units/core.py @@ -103,6 +103,7 @@ def __hash__(self): @__table.register("numpy.positive") @__table.register("numpy.ptp") @__table.register("numpy.real") + @__table.register("numpy.repeat") @__table.register("numpy.reshape") @__table.register("numpy.sum") @__table.register("numpy.take") @@ -113,9 +114,10 @@ def __hash__(self): @__table.register("_operator.neg") @__table.register("_operator.pos") def __unary(op, *args, **kwargs): - cls, (arg0, dim0) = _unwrap(args[0]) + cls = _get_monomial_class(args[0]) + arg0, dim0 = unwrap(args[0]) val = op(arg0, *args[1:], **kwargs) - return cls(val, dim0) + return wrap(cls, val, dim0) @__table.register("numpy.add") @__table.register("numpy.hypot") @@ -126,20 +128,24 @@ def __unary(op, *args, **kwargs): @__table.register("_operator.mod") @__table.register("_operator.sub") def __add_like(op, *args, **kwargs): - cls, (arg0, dim0), (arg1, dim1) = _unwrap(args[0], args[1]) + cls = _get_monomial_class(args[0], args[1]) + arg0, dim0 = unwrap(args[0]) + arg1, dim1 = unwrap(args[1]) if dim0 != dim1: raise DimensionError( f"incompatible dimensions for {op.__name__}: {dim0}, {dim1}" ) val = op(arg0, arg1, *args[2:], **kwargs) - return cls(val, dim0) + return wrap(cls, val, dim0) @__table.register("numpy.matmul") @__table.register("numpy.multiply") @__table.register("_operator.matmul") @__table.register("_operator.mul") def __mul_like(op, *args, **kwargs): - cls, (arg0, dim0), (arg1, dim1) = _unwrap(args[0], args[1]) + cls = _get_monomial_class(args[0], args[1]) + arg0, dim0 = unwrap(args[0]) + arg1, dim1 = unwrap(args[1]) val = op(arg0, arg1, *args[2:], **kwargs) return wrap(cls, val, dim0 + dim1) @@ -150,35 +156,43 @@ def __mul_like(op, *args, **kwargs): @__table.register("numpy.divide") @__table.register("_operator.truediv") def __div_like(op, *args, **kwargs): - cls, (arg0, dim0), (arg1, dim1) = _unwrap(args[0], args[1]) + cls = _get_monomial_class(args[0], args[1]) + arg0, dim0 = unwrap(args[0]) + arg1, dim1 = unwrap(args[1]) val = op(arg0, arg1, *args[2:], **kwargs) return wrap(cls, val, dim0 - dim1) @__table.register("nutils.function.laplace") def __laplace(op, *args, **kwargs): - cls, (arg0, dim0), (arg1, dim1) = _unwrap(args[0], args[1]) + cls = _get_monomial_class(args[0], args[1]) + arg0, dim0 = unwrap(args[0]) + arg1, dim1 = unwrap(args[1]) val = op(arg0, arg1, *args[2:], **kwargs) return wrap(cls, val, dim0 - dim1 * 2) @__table.register("numpy.sqrt") def __sqrt(op, *args, **kwargs): - cls, (arg0, dim0) = _unwrap(args[0]) + cls = _get_monomial_class(args[0]) + arg0, dim0 = unwrap(args[0]) val = op(arg0, *args[1:], **kwargs) - return cls(val, dim0 / 2) + return wrap(cls, val, dim0 / 2) @__table.register("_operator.setitem") def __setitem(op, *args, **kwargs): - cls, (arg0, dim0), (arg2, dim2) = _unwrap(args[0], args[2]) + cls = _get_monomial_class(args[0], args[2]) + arg0, dim0 = unwrap(args[0]) + arg2, dim2 = unwrap(args[2]) if dim0 != dim2: raise DimensionError(f"cannot assign {dim2} to {dim0}") val = op(arg0, args[1], arg2, *args[3:], **kwargs) - return cls(val, dim0) + return wrap(cls, val, dim0) @__table.register("nutils.function.jacobian") @__table.register("numpy.power") @__table.register("_operator.pow") def __pow_like(op, *args, **kwargs): - cls, (arg0, dim0) = _unwrap(args[0]) + cls = _get_monomial_class(args[0]) + arg0, dim0 = unwrap(args[0]) val = op(arg0, *args[1:], **kwargs) return wrap(cls, val, dim0 * args[1]) @@ -207,7 +221,8 @@ def __unary_op(op, *args, **kwargs): @__table.register("_operator.lt") @__table.register("_operator.ne") def __binary_op(op, *args, **kwargs): - _, (arg0, dim0), (arg1, dim1) = _unwrap(args[0], args[1]) + arg0, dim0 = unwrap(args[0]) + arg1, dim1 = unwrap(args[1]) if dim0 != dim1 and not (isinstance(args[1], (int, float)) and args[1] == 0): raise DimensionError( f"incompatible dimensions for {op.__name__}: {dim0}, {dim1}" @@ -217,7 +232,8 @@ def __binary_op(op, *args, **kwargs): @__table.register("numpy.stack") @__table.register("numpy.concatenate") def __stack_like(op, *args, **kwargs): - cls, arg0, dims = _unwrap_many(args[0]) + cls = _get_monomial_class(*args[0]) + arg0, dims = zip(*map(unwrap, args[0])) dim = dims[0] if any(d != dim for d in dims[1:]): raise DimensionError( @@ -225,42 +241,48 @@ def __stack_like(op, *args, **kwargs): + ", ".join(map(str, dims)) ) val = op(arg0, *args[1:], **kwargs) - return cls(val, dim) + return wrap(cls, val, dim) @__table.register("nutils.function.curvature") - def __evaluate(op, *args, **kwargs): - cls, (arg0, dim0) = _unwrap(args[0]) + def __curvature(op, *args, **kwargs): + cls = _get_monomial_class(*args[0]) + arg0, dim0 = unwrap(args[0]) val = op(*args, **kwargs) - return cls(val, -dim0) + return wrap(cls, val, -dim0) @__table.register("nutils.function.evaluate") def __evaluate(op, *args, **kwargs): - cls, args, dims = _unwrap_many(args) + cls = _get_monomial_class(*args) + args, dims = zip(*map(unwrap, args)) vals = op(*args, **kwargs) - return tuple(map(cls, vals, dims)) + return tuple(wrap(cls, val, dim) for val, dim in zip(vals, dims)) @__table.register("nutils.function.field") def __field(op, *args, **kwargs): - cls, args, dims = _unwrap_many(args) + cls = _get_monomial_class(*args) + args, dims = zip(*map(unwrap, args)) val = op(*args, **kwargs) dim = reduce(operator.add, dims) # we abuse the fact that unpack str returns dimensionless - return cls(val, dim) + return wrap(cls, val, dim) @__table.register("nutils.function.arguments_for") def __attribute(op, *args, **kwargs): - _, args, _ = _unwrap_many(args) + args = [unwrap(arg)[0] for arg in args] return op(*args, **kwargs) @__table.register("numpy.interp") def __interp(op, x, xp, fp, *args, **kwargs): - cls, (x, dimx), (xp, dimxp), (fp, dimfp) = _unwrap(x, xp, fp) + cls = _get_monomial_class(x, xp, fp) + x, dimx = unwrap(x) + xp, dimxp = unwrap(xp) + fp, dimfp = unwrap(fp) if dimx != dimxp: raise DimensionError( f"incompatible dimensions for {op.__name__}: {dimx}, {dimxp}" ) val = op(x, xp, fp, *args, **kwargs) - return cls(val, dimfp) + return wrap(cls, val, dimfp) @__table.register("nutils.topology.Topology.locate") def __locate( @@ -273,13 +295,10 @@ def __locate( maxdist=None, **kwargs, ): - ( - cls, - (geom, dimgeom), - (coords, dimcoords), - (tol, dimtol), - (maxdist, dimmaxdist), - ) = _unwrap(geom, coords, tol, maxdist) + geom, dimgeom = unwrap(geom) + coords, dimcoords = unwrap(coords) + tol, dimtol = unwrap(tol) + maxdist, dimmaxdist = unwrap(maxdist) if dimgeom != dimcoords: raise DimensionError( f"incompatible dimensions for locate: {dimgeom}, {dimcoords}" @@ -304,16 +323,18 @@ def __locate( @__table.register("nutils.sample.Sample.bind") @__table.register("nutils.sample.Sample.integral") def __sample(op, sample, func): - cls, (func, dim) = _unwrap(func) + cls = _get_monomial_class(func) + func, dim = unwrap(func) val = op(sample, func) - return cls(val, dim) + return wrap(cls, val, dim) @__table.register("numpy.linalg.det") def __det(op, arg): - cls, (arg, dim) = _unwrap(arg) + cls = _get_monomial_class(arg) + arg, dim = unwrap(arg) val = op(arg) dim = dim * arg.ndim - return cls(val, dim) + return wrap(cls, val, dim) ## DEFINE OPERATORS @@ -357,7 +378,14 @@ def unwrap(obj): return obj.unwrap() if isinstance(obj, Monomial) else (obj, Powers({})) -def _unwrap(*args): +def _get_monomial_class(*args): + '''Return common Monomial base class. + + This helper function returns the highest subclass of which all monomial + arguments are an instance. Concretely, if one argument is a Monomial and + the other a UMonomial, then Monomial is returned. If all are UMonomial then + the return value is UMonomial.''' + types = {type(arg) for arg in args if isinstance(arg, Monomial)} bases = _collect_bases(types.pop()) for cls in types: @@ -367,12 +395,7 @@ def _unwrap(*args): assert bases[0] is Monomial while bases[-1].__init__ != Monomial.__init__: bases = bases[:-1] - return bases[-1], *map(unwrap, args) - - -def _unwrap_many(args): - cls, *args_dims = _unwrap(*args) - return cls, *zip(*args_dims) + return bases[-1] def _collect_bases(cls, sentinel=object):