Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ jobs:
- "pre"
group:
- Core
- Enzyme
exclude:
- version: "pre"
group: Enzyme
uses: "SciML/.github/.github/workflows/tests.yml@v1"
with:
julia-version: "${{ matrix.version }}"
Expand Down
6 changes: 2 additions & 4 deletions ext/FastPowerEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ import FastPower: fastpower
using Enzyme
using Enzyme.EnzymeRules: FwdConfig

Enzyme.EnzymeRules.@easy_rule(
FastPower.fastpower(x, y),
( y * fastpower(x, y - 1), Ω * log(x) )
)
Enzyme.EnzymeRules.@easy_rule(FastPower.fastpower(x, y),
(y * fastpower(x, y - 1), Ω * log(x)))

end
95 changes: 52 additions & 43 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,58 +1,67 @@
using FastPower
using FastPower: fastlog2, fastpower
using Enzyme, EnzymeTestUtils
using ForwardDiff, ReverseDiff, Tracker, Mooncake
using Test

@testset "Fast log2" begin
for x in 0.001:0.001:1.2 # (0, 1+something] is the domain which a controller uses
@test log2(x)≈fastlog2(Float32(x)) atol=1e-3
const GROUP = get(ENV, "GROUP", "All")

if GROUP == "All" || GROUP == "Core"
@testset "Fast log2" begin
for x in 0.001:0.001:1.2 # (0, 1+something] is the domain which a controller uses
@test log2(x)≈fastlog2(Float32(x)) atol=1e-3
end
end
end

@testset "Fast pow" begin
@test fastpower(1, 1) isa Float64
@test fastpower(1.0, 1.0) isa Float64
errors = [abs(^(x, y) - fastpower(x, y)) for x in 0.001:0.001:1, y in 0.08:0.001:0.5]
@test maximum(errors) < 1e-4
@testset "Fast pow" begin
@test fastpower(1, 1) isa Float64
@test fastpower(1.0, 1.0) isa Float64
errors = [abs(^(x, y) - fastpower(x, y)) for x in 0.001:0.001:1, y in 0.08:0.001:0.5]
@test maximum(errors) < 1e-4

errors = [abs(^(x, y) - fastpower(x, y)) for x in 0.001:0.001:1, y in 0.08:0.001:1000.0]
@test maximum(errors) < 1e-3
errors = [abs(^(x, y) - fastpower(x, y)) for x in 0.001:0.001:1, y in 0.08:0.001:1000.0]
@test maximum(errors) < 1e-3

errors = [abs(^(x, y) - fastpower(x, y)) for x in 0.001:0.001:100, y in 0.08:0.001:1.0]
@test maximum(errors) < 1e-2
end
errors = [abs(^(x, y) - fastpower(x, y)) for x in 0.001:0.001:100, y in 0.08:0.001:1.0]
@test maximum(errors) < 1e-2
end

using ForwardDiff, ReverseDiff, Tracker, Mooncake

@testset "Fast pow - Enzyme forward rule" begin
@testset for RT in (Duplicated, DuplicatedNoNeed),
Tx in (Const, Duplicated),
Ty in (Const, Duplicated)
x = 1.0
y = 0.5
test_forward(fastpower, RT, (x, Tx), (y, Ty), atol = 1e-4, rtol = 1e-3)
function mooncake_derivative(f, x)
Mooncake.value_and_gradient!!(Mooncake.build_rrule(f, x), f, x)[2][2]
end
@testset "Fast pow - Other AD Engines" begin
x = 1.5123233245141
y = 0.22352354326
@test ForwardDiff.derivative(x -> fastpower(x, x + y), x) ≈
ForwardDiff.derivative(x -> ^(x, x + y), x)
@test Tracker.gradient(x -> fastpower(x, x + y), x)[1] ≈
Tracker.gradient(x -> ^(x, x + y), x)[1]
@test ReverseDiff.gradient(x -> fastpower(x[1], x[1] + y), [x])[1] ≈
ReverseDiff.gradient(x -> ^(x[1], x[1] + y), [x])[1]
@test mooncake_derivative(x -> fastpower(x, x + y), x) ≈
mooncake_derivative(x -> ^(x, x + y), x)
end
end

@testset "Fast pow - Enzyme reverse rule" begin
@testset for RT in (Active,), Tx in (Active, Const), Ty in (Active, Const)
x = 1.0
y = 0.5
test_reverse(fastpower, RT, (x, Tx), (y, Ty), atol = 1e-4, rtol = 1e-3)
if GROUP == "All" || GROUP == "Enzyme"
using Enzyme, EnzymeTestUtils

@testset "Fast pow - Enzyme forward rule" begin
@testset for RT in (Duplicated, DuplicatedNoNeed),
Tx in (Const, Duplicated),
Ty in (Const, Duplicated)

x = 1.0
y = 0.5
test_forward(fastpower, RT, (x, Tx), (y, Ty), atol = 1e-4, rtol = 1e-3)
end
end
end

function mooncake_derivative(f, x)
Mooncake.value_and_gradient!!(Mooncake.build_rrule(f, x), f, x)[2][2]
end
@testset "Fast pow - Other AD Engines" begin
x = 1.5123233245141
y = 0.22352354326
@test ForwardDiff.derivative(x -> fastpower(x, x + y), x) ≈
ForwardDiff.derivative(x -> ^(x, x + y), x)
@test Tracker.gradient(x -> fastpower(x, x + y), x)[1] ≈
Tracker.gradient(x -> ^(x, x + y), x)[1]
@test ReverseDiff.gradient(x -> fastpower(x[1], x[1] + y), [x])[1] ≈
ReverseDiff.gradient(x -> ^(x[1], x[1] + y), [x])[1]
@test mooncake_derivative(x -> fastpower(x, x + y), x) ≈
mooncake_derivative(x -> ^(x, x + y), x)
@testset "Fast pow - Enzyme reverse rule" begin
@testset for RT in (Active,), Tx in (Active, Const), Ty in (Active, Const)
x = 1.0
y = 0.5
test_reverse(fastpower, RT, (x, Tx), (y, Ty), atol = 1e-4, rtol = 1e-3)
end
end
end
Loading