Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/decimo/bigdecimal/bigdecimal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1865,6 +1865,26 @@ struct BigDecimal(
"""
return bigdecimal_special.factorial(self, precision)

def permutation(self, k: Int, precision: Int = 0) raises -> Self:
"""Returns the number of `k`-permutations of `self` items.

`P(n, k) = n! / (n - k)!`, where `n = self`. Exact when
`precision == 0`; a positive `precision` returns that many
significant digits.

Args:
k: The number of ordered positions to fill (non-negative).
precision: Significant digits for the result (`0` = exact).

Returns:
`P(self, k)`; 0 when `k > self`, and `P(self, 0) == 1`.

Raises:
ValueError: If `self` is not an integer, `self` or `k` is
negative, or `k` is larger than 10^6.
"""
return bigdecimal_special.permutation(self, k, precision)

@always_inline
def ln(self, precision: Int = PRECISION) raises -> Self:
"""Returns the natural logarithm of the BigDecimal number.
Expand Down
94 changes: 89 additions & 5 deletions src/decimo/bigdecimal/special.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ def factorial(x: BigDecimal, precision: Int = 0) raises -> BigDecimal:
# fractional part (e.g. "5.00") convert cleanly to `Int`.
var n = Int(x.truncate())
if precision <= 0:
# Exact: full-width products, no rounding.
var result = BigDecimal(1)
for i in range(2, n + 1):
result = result.multiply(BigDecimal(i))
return result^
# Exact: balanced binary splitting (same idea as BigInt) is much
# faster than a left-to-right running product for large `n`.
if n < 2:
return BigDecimal(1)
return product_range(2, n)

# Rounded: keep every product at `precision + guard` significant digits,
# where the guard also grows with the number of digits in `n`. Round the
Expand All @@ -98,3 +98,87 @@ def factorial(x: BigDecimal, precision: Int = 0) raises -> BigDecimal:
for i in range(2, n + 1):
result = result.multiply(BigDecimal(i), working_precision)
return result.multiply(BigDecimal(1), precision)


def product_range(low: Int, high: Int) raises -> BigDecimal:
"""Returns the exact product of the consecutive integers in `[low, high]`.

The range is inclusive; an empty range (`low > high`) returns 1. Uses
balanced binary splitting with exact multiplication (`precision=0`), so
each multiplication stays between operands of similar size.

Args:
low: The first integer in the range.
high: The last integer in the range.

Returns:
`low * (low + 1) * ... * high` (1 when the range is empty).
"""
if low > high:
return BigDecimal(1)
if low == high:
return BigDecimal(low)
if high == low + 1:
return BigDecimal(low).multiply(BigDecimal(high))
var mid = low + (high - low) // 2
return product_range(low, mid).multiply(product_range(mid + 1, high))


def permutation(x: BigDecimal, k: Int, precision: Int = 0) raises -> BigDecimal:
"""Calculates the number of `k`-permutations of `n = x` items.

`P(n, k) = n! / (n - k)!`. The result is exact when `precision == 0`; a
positive `precision` rounds the intermediate products and returns
`precision` significant digits.

Args:
x: The number of items `n` (a non-negative integer value).
k: The number of ordered positions to fill (non-negative).
precision: Significant digits for the result (`0` = exact).

Returns:
`P(n, k)`. Returns 0 when `k > n` (no such arrangement exists);
`P(n, 0) == 1`.

Raises:
ValueError: If `x` is not an integer, `x` or `k` is negative, or `k`
is larger than `FACTORIAL_MAX_INPUT` (10^6).
"""
if not x.is_integer():
raise ValueError(
message="Permutation is only defined for integer values of n.",
function="permutation()",
)
if x < BigDecimal(0):
raise ValueError(
message="Permutation is not defined for a negative n.",
function="permutation()",
)
if k < 0:
raise ValueError(
message="Permutation is not defined for a negative k.",
function="permutation()",
)
if k > FACTORIAL_MAX_INPUT:
raise ValueError(
message="Permutation k is too large to compute (must be <= 10^6).",
function="permutation()",
)
if x > BigDecimal(Int.MAX):
raise ValueError(
message="Permutation n is too large to fit in an Int.",
function="permutation()",
)
var n = Int(x.truncate())
if k > n:
return BigDecimal(0)
if precision <= 0:
return product_range(n - k + 1, n)

var working_precision = (
precision + String(n).byte_length() + FACTORIAL_GUARD_DIGITS
)
Comment thread
Copilot marked this conversation as resolved.
var result = BigDecimal(1)
for i in range(n - k + 1, n + 1):
result = result.multiply(BigDecimal(i), working_precision)
return result.multiply(BigDecimal(1), precision)
17 changes: 17 additions & 0 deletions src/decimo/bigint/bigint.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1482,6 +1482,23 @@ struct BigInt(
"""
return bigint_special.factorial(self)

def permutation(self, k: Int) raises -> Self:
"""Returns the number of `k`-permutations of `self` items.

`P(n, k) = n! / (n - k)!`, where `n = self`.

Args:
k: The number of ordered positions to fill (non-negative).

Returns:
`P(self, k)`; 0 when `k > self`, and `P(self, 0) == 1`.

Raises:
ValueError: If `self` or `k` is negative, or `k` is larger
than 10^6.
"""
return bigint_special.permutation(self, k)

@always_inline
def compare_magnitudes(self, other: Self) -> Int8:
"""Compares the magnitudes (absolute values) of two BigInt numbers.
Expand Down
78 changes: 74 additions & 4 deletions src/decimo/bigint/special.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,77 @@ def factorial(x: BigInt) raises -> BigInt:
)

var n = Int(x)
var result = BigInt.one()
for i in range(2, n + 1):
result *= BigInt(i)
return result^
if n < 2:
return BigInt.one()
# Balanced binary splitting multiplies similar-sized operands instead of
# the naive tiny * huge running product, which is far faster for large
# `n` (measured ~1.4x at n=1000 up to ~10x at n=100000).
return product_range(2, n)


def product_range(low: Int, high: Int) -> BigInt:
"""Returns the product of the consecutive integers in `[low, high]`.

The range is inclusive; an empty range (`low > high`) returns 1. Uses
balanced binary splitting so each multiplication stays between operands
of similar size, which is much faster than a left-to-right running
product for large ranges.

Args:
low: The first integer in the range.
high: The last integer in the range.

Returns:
`low * (low + 1) * ... * high` (1 when the range is empty).
"""
if low > high:
return BigInt.one()
if low == high:
return BigInt(low)
if high == low + 1:
return BigInt(low) * BigInt(high)
var mid = low + (high - low) // 2
return product_range(low, mid) * product_range(mid + 1, high)


def permutation(x: BigInt, k: Int) raises -> BigInt:
"""Calculates the number of `k`-permutations of `n = x` items.

`P(n, k) = n! / (n - k)! = (n - k + 1) * (n - k + 2) * ... * n`.

Args:
x: The number of items `n` (non-negative).
k: The number of ordered positions to fill (non-negative).

Returns:
`P(n, k)`. Returns 0 when `k > n` (no such arrangement exists);
`P(n, 0) == 1`.

Raises:
ValueError: If `x` or `k` is negative, or if `k` is larger than
`FACTORIAL_MAX_INPUT` (10^6, the cap on the number of factors).
"""
if x < BigInt.zero():
raise ValueError(
message="Permutation is not defined for a negative n.",
function="permutation()",
)
if k < 0:
raise ValueError(
message="Permutation is not defined for a negative k.",
function="permutation()",
)
if k > FACTORIAL_MAX_INPUT:
raise ValueError(
message="Permutation k is too large to compute (must be <= 10^6).",
function="permutation()",
)
if x > BigInt(Int.MAX):
raise ValueError(
message="Permutation n is too large to fit in an Int.",
function="permutation()",
)
var n = Int(x)
if k > n:
return BigInt.zero()
return product_range(n - k + 1, n)
12 changes: 12 additions & 0 deletions tests/bigdecimal/test_bigdecimal_special.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,17 @@ def test_factorial_too_large_raises() raises:
testing.assert_true(raised, "factorial above the cap should raise")


def test_permutation() raises:
"""Test permutation P(n, k) exact."""
testing.assert_equal(String(BigDecimal(10).permutation(3)), "720")
testing.assert_equal(String(BigDecimal(5).permutation(0)), "1")
testing.assert_equal(String(BigDecimal(5).permutation(7)), "0") # k > n


def test_permutation_rounded() raises:
"""Test permutation rounded mode."""
testing.assert_equal(String(BigDecimal(10).permutation(3, 2)), "7.2E+2")


def main() raises:
testing.TestSuite.discover_tests[__functions_in_module()]().run()
19 changes: 19 additions & 0 deletions tests/bigint/test_bigint_special.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,24 @@ def test_factorial_too_large_raises() raises:
testing.assert_true(raised, "factorial above the cap should raise")


def test_permutation() raises:
"""Test permutation P(n, k)."""
testing.assert_equal(String(BigInt(10).permutation(3)), "720")
testing.assert_equal(String(BigInt(5).permutation(5)), "120") # P(n,n)=n!
testing.assert_equal(String(BigInt(5).permutation(0)), "1")
testing.assert_equal(String(BigInt(5).permutation(7)), "0") # k > n
testing.assert_equal(String(BigInt(100).permutation(2)), "9900")


def test_permutation_negative_k_raises() raises:
"""Test that a negative k raises."""
var raised = False
try:
_ = BigInt(5).permutation(-1)
except:
raised = True
testing.assert_true(raised, "permutation with negative k should raise")


def main() raises:
testing.TestSuite.discover_tests[__functions_in_module()]().run()
Loading