From 0f6752f715cf42fd920ae01d73a2fba3c552cbb6 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 16 Apr 2024 10:43:49 +0200 Subject: [PATCH 1/3] Generalize `rrule` for `svdvals` --- src/rulesets/LinearAlgebra/factorization.jl | 16 +++ src/rulesets/LinearAlgebra/symmetric.jl | 22 ---- test/rulesets/LinearAlgebra/factorization.jl | 127 ++++++++++++------- test/rulesets/LinearAlgebra/symmetric.jl | 22 ---- 4 files changed, 100 insertions(+), 87 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 910dd744b..18cd96643 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -276,6 +276,22 @@ function svd_rev(USV::SVD, Ū, s̄, V̄) return Ā end +##### +##### `svdvals` +##### + +function rrule(::typeof(svdvals), A::AbstractMatrix{<:Number}) + F = svd(A) + U = F.U + Vt = F.Vt + project_A = ProjectTo(A) + function svdvals_pullback(s̄) + S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(unthunk(s̄)) + (NoTangent(), project_A(U * S̄ * Vt)) + end + return F.S, svdvals_pullback +end + ##### ##### `eigen` ##### diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 67693575e..cdadca6c3 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -277,28 +277,6 @@ function _svd_eigvals_sign!(c, U, V) return c end -##### -##### `svdvals` -##### - -# NOTE: rrule defined because `svdvals` calls mutating `svdvals!` internally. -# can be removed when mutation is supported by reverse-mode AD packages -function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}) - λ, back = rrule(eigvals, A) - S = abs.(λ) - p = sortperm(S; rev=true) - permute!(S, p) - function svdvals_pullback(ΔS) - ∂λ = real.(ΔS) - invpermute!(∂λ, p) - ∂λ .*= sign.(λ) - _, ∂A = back(∂λ) - return NoTangent(), unthunk(∂A) - end - svdvals_pullback(ΔS::AbstractZero) = (NoTangent(), ΔS) - return S, svdvals_pullback -end - ##### ##### matrix functions ##### diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 60e2e74be..1dd1aeb37 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -102,61 +102,102 @@ end end end end - @testset "svd" begin - for n in [4, 6, 10], m in [3, 5, 9] - @testset "($n x $m) svd" begin - X = randn(n, m) - test_rrule(svd, X; atol=1e-6, rtol=1e-6) - end - end - for n in [4, 6, 10], m in [3, 5, 10] - @testset "($n x $m) getproperty" begin - X = randn(n, m) - F = svd(X) - rand_adj = adjoint(rand(reverse(size(F.V))...)) + @testset "singular value decomposition" begin + @testset "svd" begin + for n in [4, 6, 10], m in [3, 5, 9] + @testset "($n x $m) svd" begin + X = randn(n, m) + test_rrule(svd, X; atol=1e-6, rtol=1e-6) + end + end - test_rrule(getproperty, F, :U; check_inferred=false) - test_rrule(getproperty, F, :S; check_inferred=false) - test_rrule(getproperty, F, :Vt; check_inferred=false) - test_rrule(getproperty, F, :V; check_inferred=false, output_tangent=rand_adj) + for n in [4, 6, 10], m in [3, 5, 10] + @testset "($n x $m) getproperty" begin + X = randn(n, m) + F = svd(X) + rand_adj = adjoint(rand(reverse(size(F.V))...)) + + test_rrule(getproperty, F, :U; check_inferred=false) + test_rrule(getproperty, F, :S; check_inferred=false) + test_rrule(getproperty, F, :Vt; check_inferred=false) + test_rrule( + getproperty, F, :V; check_inferred=false, output_tangent=rand_adj + ) + end end - end - @testset "Thunked inputs" begin - X = randn(4, 3) - F, dX_pullback = rrule(svd, X) - for p in [:U, :S, :V, :Vt] - Y, dF_pullback = rrule(getproperty, F, p) - Ȳ = randn(size(Y)...) + @testset "Thunked inputs" begin + X = randn(4, 3) + F, dX_pullback = rrule(svd, X) + for p in [:U, :S, :V, :Vt] + Y, dF_pullback = rrule(getproperty, F, p) + Ȳ = randn(size(Y)...) + + _, dF_unthunked, _ = dF_pullback(Ȳ) - _, dF_unthunked, _ = dF_pullback(Ȳ) + # helper to let us check how things are stored. + p_access = p == :V ? :Vt : p + backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p_access) + @assert !(backing_field(dF_unthunked, p) isa AbstractThunk) - # helper to let us check how things are stored. - p_access = p == :V ? :Vt : p - backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p_access) - @assert !(backing_field(dF_unthunked, p) isa AbstractThunk) + dF_thunked = map(f -> Thunk(() -> f), dF_unthunked) + @assert backing_field(dF_thunked, p) isa AbstractThunk + + dself_thunked, dX_thunked = dX_pullback(dF_thunked) + dself_unthunked, dX_unthunked = dX_pullback(dF_unthunked) + @test dself_thunked == dself_unthunked + @test dX_thunked == dX_unthunked + end + end - dF_thunked = map(f->Thunk(()->f), dF_unthunked) - @assert backing_field(dF_thunked, p) isa AbstractThunk + @testset "Helper functions" begin + X = randn(10, 10) + Y = randn(10, 10) + @test ChainRules._mulsubtrans!!(copy(X), Y) ≈ Y .* (X - X') + @test ChainRules._eyesubx!(copy(X)) ≈ I - X - dself_thunked, dX_thunked = dX_pullback(dF_thunked) - dself_unthunked, dX_unthunked = dX_pullback(dF_unthunked) - @test dself_thunked == dself_unthunked - @test dX_thunked == dX_unthunked + Z = randn(Float32, 10, 10) + result = ChainRules._mulsubtrans!!(copy(Z), Y) + @test result ≈ Y .* (Z - Z') + @test eltype(result) == Float64 end end - @testset "Helper functions" begin - X = randn(10, 10) - Y = randn(10, 10) - @test ChainRules._mulsubtrans!!(copy(X), Y) ≈ Y .* (X - X') - @test ChainRules._eyesubx!(copy(X)) ≈ I - X + @testset "svdvals" begin + for n in [4, 6, 10] + for m in [3, 5, 9] + @testset "($n x $m) svdvals" begin + X = randn(n, m) + test_rrule(svdvals, X; atol=1e-6, rtol=1e-6) + end + end + + @testset "rrule for svdvals(::$SymHerm{$T}) ($n x $n, uplo=$uplo)" for SymHerm in + ( + Symmetric, Hermitian + ), + T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), + uplo in (:L, :U) + + A, ΔS = randn(T, n, n), randn(n) + symA = SymHerm(A, uplo) - Z = randn(Float32, 10, 10) - result = ChainRules._mulsubtrans!!(copy(Z), Y) - @test result ≈ Y .* (Z - Z') - @test eltype(result) == Float64 + S = svdvals(symA) + S_ad, back = @inferred rrule(svdvals, symA) + @test S_ad ≈ S # inexact because rrule uses svd not svdvals + ∂self, ∂symA = @inferred back(ΔS) + @test ∂self === NoTangent() + @test ∂symA isa typeof(symA) + @test ∂symA.uplo == symA.uplo + + # pull the cotangent back to A to test against finite differences + ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] + @test ∂A ≈ j′vp(_fdm, A -> svdvals(SymHerm(A, uplo)), ΔS, A)[1] + + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) + end + end end end diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index c00bd1522..593b82148 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -275,28 +275,6 @@ @test @maybe_inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) @test @maybe_inferred(back(CT())) == (NoTangent(), ZeroTangent()) end - - @testset "rrule for svdvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), - T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), - uplo in (:L, :U) - - A, ΔS = randn(T, n, n), randn(n) - symA = SymHerm(A, uplo) - - S = svdvals(symA) - S_ad, back = @maybe_inferred rrule(svdvals, symA) - @test S_ad ≈ S # inexact because rrule uses svd not svdvals - ∂self, ∂symA = @maybe_inferred back(ΔS) - @test ∂self === NoTangent() - @test ∂symA isa typeof(symA) - @test ∂symA.uplo == symA.uplo - - # pull the cotangent back to A to test against finite differences - ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] - @test ∂A ≈ j′vp(_fdm, A -> svdvals(SymHerm(A, uplo)), ΔS, A)[1] - - @test @maybe_inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) - end end @testset "Symmetric/Hermitian matrix functions" begin From 006e8f0ed7cbb62de72c5cb528491f56273b09ed Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 16 Apr 2024 10:45:21 +0200 Subject: [PATCH 2/3] Bump version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b72993248..b95447f44 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.63.0" +version = "1.64.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 8f74dead235de32d7e75c1fb76a84dfadaa39d12 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 16 Apr 2024 10:47:21 +0200 Subject: [PATCH 3/3] Fix format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rulesets/LinearAlgebra/factorization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 18cd96643..1391e6aef 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -287,7 +287,7 @@ function rrule(::typeof(svdvals), A::AbstractMatrix{<:Number}) project_A = ProjectTo(A) function svdvals_pullback(s̄) S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(unthunk(s̄)) - (NoTangent(), project_A(U * S̄ * Vt)) + return (NoTangent(), project_A(U * S̄ * Vt)) end return F.S, svdvals_pullback end