Skip to content

Commit a68a361

Browse files
committed
Move AD to package extensions
1 parent 8c6ac8c commit a68a361

11 files changed

Lines changed: 125 additions & 76 deletions

File tree

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
33
version = "0.11.0"
44

55
[deps]
6-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
76
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
87
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
98
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
@@ -18,15 +17,18 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1817
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1918
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
2019
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
21-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2220

2321
[weakdeps]
22+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2423
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
2524
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
25+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2626

2727
[extensions]
28+
KernelFunctionsChainRulesCoreExt = "ChainRulesCore"
2829
KernelFunctionsKroneckerExt = "Kronecker"
2930
KernelFunctionsPDMatsExt = "PDMats"
31+
KernelFunctionsZygoteRulesExt = "ZygoteRules"
3032

3133
[compat]
3234
ChainRulesCore = "1"

examples/train-kernel-parameters/script.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
# We load KernelFunctions and some other packages. Note that while we use `Zygote` for automatic differentiation and `Flux.optimise` for optimization, you should be able to replace them with your favourite autodiff framework or optimizer.
77

8+
# !!! note
9+
# Zygote is not expected to work on Julia ≥ 1.12. Use a different AD package for
10+
# Julia ≥ 1.12, or use Julia 1.11 to run this example.
11+
812
using KernelFunctions
913
using LinearAlgebra
1014
using Distributions
Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
1+
module KernelFunctionsChainRulesCoreExt
2+
3+
using ChainRulesCore:
4+
ChainRulesCore, Tangent, ZeroTangent, NoTangent, @thunk, ProjectTo, unthunk
5+
using Distances: Distances, Euclidean, SqEuclidean
6+
using IrrationalConstants: twoπ
7+
using KernelFunctions:
8+
KernelFunctions, Delta, DotProduct, Sinus, ColVecs, RowVecs
9+
using LinearAlgebra: dot
10+
111
## Forward Rules
212

313
# Note that this is type piracy as the derivative should be NaN for x == y.
414
function ChainRulesCore.frule(
515
(_, Δx, Δy)::Tuple{<:Any,<:Any,<:Any},
6-
d::Distances.Euclidean,
16+
d::Euclidean,
717
x::AbstractVector,
818
y::AbstractVector,
919
)
@@ -116,7 +126,7 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
116126
gradx = π .* sinpi.(2 .* d) ./ s.r .^ 2
117127
function evaluate_pullback(Δ::Any)
118128
r̄ = -2Δ .* abs2_sind_r ./ s.r
119-
s̄ = ChainRulesCore.Tangent{typeof(s)}(; r=r̄)
129+
s̄ = Tangent{typeof(s)}(; r=r̄)
120130
return s̄, Δ * gradx, -Δ * gradx
121131
end
122132
return val, evaluate_pullback
@@ -150,7 +160,7 @@ function ChainRulesCore.rrule(
150160
x̄[:, j] .-= ds
151161
end
152162
end
153-
d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
163+
d̄ = Tangent{typeof(d)}(; r=r̄)
154164
return NoTangent(), d̄, @thunk(project_x(x̄))
155165
end
156166
return Distances.pairwise(d, x; dims), pairwise_pullback
@@ -166,7 +176,7 @@ function ChainRulesCore.rrule(
166176
n = size(x, dims)
167177
m = size(y, dims)
168178
x̄ = collect(zero(x))
169-
= collect(zero(y))
179+
ȳ = collect(zero(y))
170180
r̄ = zero(d.r)
171181
if dims == 1
172182
for j in 1:m, i in 1:n
@@ -175,7 +185,7 @@ function ChainRulesCore.rrule(
175185
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2
176186
r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
177187
x̄[i, :] .+= ds
178-
[j, :] .-= ds
188+
ȳ[j, :] .-= ds
179189
end
180190
elseif dims == 2
181191
for j in 1:m, i in 1:n
@@ -184,11 +194,11 @@ function ChainRulesCore.rrule(
184194
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2
185195
r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
186196
x̄[:, i] .+= ds
187-
[:, j] .-= ds
197+
ȳ[:, j] .-= ds
188198
end
189199
end
190-
d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
191-
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y())
200+
d̄ = Tangent{typeof(d)}(; r=r̄)
201+
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
192202
end
193203
return Distances.pairwise(d, x, y; dims), pairwise_pullback
194204
end
@@ -202,18 +212,18 @@ function ChainRulesCore.rrule(
202212
Δ = unthunk(z̄)
203213
n = size(x, 2)
204214
x̄ = collect(zero(x))
205-
= collect(zero(y))
215+
ȳ = collect(zero(y))
206216
r̄ = zero(d.r)
207217
for i in 1:n
208218
xi = view(x, :, i)
209219
yi = view(y, :, i)
210220
ds = π .* Δ[i] .* sinpi.(2 .* (xi .- yi)) ./ d.r .^ 2
211221
r̄ .-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3
212222
x̄[:, i] .+= ds
213-
[:, i] .-= ds
223+
ȳ[:, i] .-= ds
214224
end
215-
d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
216-
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y())
225+
d̄ = Tangent{typeof(d)}(; r=r̄)
226+
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
217227
end
218228
return Distances.colwise(d, x, y), colwise_pullback
219229
end
@@ -247,3 +257,5 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix)
247257
end
248258
return RowVecs(X), RowVecs_pullback
249259
end
260+
261+
end
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
module KernelFunctionsZygoteRulesExt
2+
3+
using KernelFunctions: KernelFunctions, Transform, ColVecs, RowVecs, _map
4+
using ZygoteRules:
5+
ZygoteRules, AContext, literal_getproperty, literal_getfield
6+
17
ZygoteRules.@adjoint function Base.map(t::Transform, X::ColVecs)
28
return ZygoteRules.pullback(_map, t, X)
39
end
@@ -11,3 +17,5 @@ function ZygoteRules._pullback(
1117
) where {f}
1218
return ZygoteRules._pullback(cx, literal_getfield, x, Val{f}())
1319
end
20+
21+
end

src/KernelFunctions.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ export IndependentMOKernel,
5050
export tensor, ⊗, compose
5151

5252
using Compat
53-
using ChainRulesCore: ChainRulesCore, Tangent, ZeroTangent, NoTangent
54-
using ChainRulesCore: @thunk, InplaceableThunk, ProjectTo, unthunk
5553
using CompositionsBase
5654
using Distances
5755
using FillArrays
@@ -62,7 +60,6 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2
6260
using LogExpFunctions: softplus
6361
using StatsBase
6462
using TensorCore
65-
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield
6663

6764
# Hack to work around Zygote type inference problems.
6865
const Distances_pairwise = Distances.pairwise
@@ -123,9 +120,6 @@ include("mokernels/slfm.jl")
123120
include("mokernels/intrinsiccoregion.jl")
124121
include("mokernels/lmm.jl")
125122

126-
include("chainrules.jl")
127-
include("zygoterules.jl")
128-
129123
include("TestUtils.jl")
130124

131125
# Kronecker extension stubs

test/basekernels/fbm.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
# Tests failing for ForwardDiff and Zygote@0.6.
1919
# Related to: https://github.com/FluxML/Zygote.jl/issues/1036
20-
f(x, y) = x^y
21-
@test_broken !isinf(
22-
Zygote.gradient((x, y) -> sum(f.(x, y)), zeros(1), fill(0.9, 1))[1][1]
23-
)
20+
if _TEST_ZYGOTE
21+
f(x, y) = x^y
22+
@test_broken !isinf(
23+
Zygote.gradient((x, y) -> sum(f.(x, y)), zeros(1), fill(0.9, 1))[1][1]
24+
)
25+
end
2426

2527
test_params(k, ([h],))
2628

test/chainrules.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,24 @@
44
y = rand(rng, 5)
55
r = rand(rng, 5)
66

7-
compare_gradient(:Zygote, [x, y]) do xy
8-
Euclidean()(xy[1], xy[2])
9-
end
10-
compare_gradient(:Zygote, [x, y]) do xy
11-
SqEuclidean()(xy[1], xy[2])
12-
end
13-
compare_gradient(:Zygote, [x, y]) do xy
14-
KernelFunctions.DotProduct()(xy[1], xy[2])
15-
end
16-
compare_gradient(:Zygote, [x, y]) do xy
17-
KernelFunctions.Delta()(xy[1], xy[2])
18-
end
19-
compare_gradient(:Zygote, [x, y]) do xy
20-
KernelFunctions.Sinus(r)(xy[1], xy[2])
7+
if _TEST_ZYGOTE
8+
compare_gradient(:Zygote, [x, y]) do xy
9+
Euclidean()(xy[1], xy[2])
10+
end
11+
compare_gradient(:Zygote, [x, y]) do xy
12+
SqEuclidean()(xy[1], xy[2])
13+
end
14+
compare_gradient(:Zygote, [x, y]) do xy
15+
KernelFunctions.DotProduct()(xy[1], xy[2])
16+
end
17+
compare_gradient(:Zygote, [x, y]) do xy
18+
KernelFunctions.Delta()(xy[1], xy[2])
19+
end
20+
compare_gradient(:Zygote, [x, y]) do xy
21+
KernelFunctions.Sinus(r)(xy[1], xy[2])
22+
end
23+
else
24+
@test_broken false # Zygote not supported on Julia >= 1.12
2125
end
2226
@testset "rrules for Sinus(r=$r)" for r in (rand(3),)
2327
dist = KernelFunctions.Sinus(r)

test/kernels/transformedkernel.jl

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,30 @@
3636

3737
# Test implicit gradients
3838
@testset "Implicit gradients" begin
39-
k = SqExponentialKernel() ∘ ScaleTransform(2.0)
40-
ps = params(k)
41-
X = rand(10, 1)
42-
x = vec(X)
43-
A = rand(10, 10)
44-
# Implicit
45-
g1 = Zygote.gradient(ps) do
46-
tr(kernelmatrix(k, X; obsdim=1) * A)
47-
end
48-
# Explicit
49-
g2 = Zygote.gradient(k) do k
50-
tr(kernelmatrix(k, X; obsdim=1) * A)
51-
end
39+
if _TEST_ZYGOTE
40+
k = SqExponentialKernel() ∘ ScaleTransform(2.0)
41+
ps = params(k)
42+
X = rand(10, 1)
43+
x = vec(X)
44+
A = rand(10, 10)
45+
# Implicit
46+
g1 = Zygote.gradient(ps) do
47+
tr(kernelmatrix(k, X; obsdim=1) * A)
48+
end
49+
# Explicit
50+
g2 = Zygote.gradient(k) do k
51+
tr(kernelmatrix(k, X; obsdim=1) * A)
52+
end
5253

53-
# Implicit for a vector
54-
g3 = Zygote.gradient(ps) do
55-
tr(kernelmatrix(k, x) * A)
54+
# Implicit for a vector
55+
g3 = Zygote.gradient(ps) do
56+
tr(kernelmatrix(k, x) * A)
57+
end
58+
@test g1[first(ps)] ≈ first(g2).transform.s
59+
@test g1[first(ps)] ≈ g3[first(ps)]
60+
else
61+
@test_broken false # Zygote not supported on Julia >= 1.12
5662
end
57-
@test g1[first(ps)] ≈ first(g2).transform.s
58-
@test g1[first(ps)] ≈ g3[first(ps)]
5963
end
6064

6165
@testset "Parameters" begin

test/test_utils.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ end
3131

3232
# AD utilities
3333

34+
const _TEST_ZYGOTE = VERSION < v"1.12"
35+
3436
# Type to work around some performance issues that can happen on the reverse-pass of Zygote.
3537
# This context doesn't allow any globals. Don't use this if you use globals in your
3638
# programme.
@@ -42,6 +44,9 @@ Zygote.accum_param(::NoContext, x, Δ) = Δ
4244

4345
const FDM = FiniteDifferences.central_fdm(5, 1)
4446

47+
const _DEFAULT_ADS = _TEST_ZYGOTE ? [:Zygote, :ForwardDiff, :ReverseDiff] :
48+
[:ForwardDiff, :ReverseDiff]
49+
4550
gradient(f, s::Symbol, args) = gradient(f, Val(s), args)
4651

4752
function gradient(f, ::Val{:Zygote}, args)
@@ -90,7 +95,7 @@ testdiagfunction(k::MOKernel, A) = sum(kernelmatrix_diag(k, A))
9095
testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B))
9196

9297
function test_ADs(
93-
kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3]
98+
kernelfunction, args=nothing; ADs=_DEFAULT_ADS, dims=[3, 3]
9499
)
95100
test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims)
96101
if !test_fd.anynonpass
@@ -101,14 +106,18 @@ function test_ADs(
101106
end
102107

103108
function check_zygote_type_stability(f, args...; ctx=Zygote.Context())
109+
if !_TEST_ZYGOTE
110+
@test_broken false
111+
return
112+
end
104113
@inferred f(args...)
105114
@inferred Zygote._pullback(ctx, f, args...)
106115
out, pb = Zygote._pullback(ctx, f, args...)
107116
@inferred collect(pb(out))
108117
end
109118

110119
function test_ADs(
111-
k::MOKernel; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=(in=3, out=2, obs=3)
120+
k::MOKernel; ADs=_DEFAULT_ADS, dims=(in=3, out=2, obs=3)
112121
)
113122
test_fd = test_FiniteDiff(k, dims)
114123
if !test_fd.anynonpass
@@ -372,6 +381,10 @@ function test_zygote_perf_heuristic(
372381
f, name::String, args1, args2, passes, Δ1=nothing, Δ2=nothing
373382
)
374383
@testset "$name" begin
384+
if !_TEST_ZYGOTE
385+
@test_broken false
386+
return
387+
end
375388
primal, fwd, pb = ad_constant_allocs_heuristic(f, args1, args2; Δ1, Δ2)
376389
if passes[1]
377390
@test primal[1] == primal[2]

test/transform/selecttransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@
103103
@test gx ≈ ga
104104
end
105105

106-
@testset "$(AD)" for AD in [:ReverseDiff, :Zygote]
106+
@testset "$(AD)" for AD in (_TEST_ZYGOTE ? [:ReverseDiff, :Zygote] : [:ReverseDiff])
107107
@test_broken let
108108
gx = gradient(AD, X) do x
109109
testfunction(tx_row, x, 2)

0 commit comments

Comments
 (0)