using Symbolics, SparseArrays, LinearAlgebra, Test
using ReferenceTests
using Symbolics: value
using SymbolicUtils.Code: DestructuredArgs, Func, NameState, Let, cse
# Symbolic variables shared by most of the tests below.
@variables a b c1 c2 c3 d e g
# `nanmath = true`: the tests expect out-of-domain inputs (sqrt(-1), sin(Inf))
# to yield NaN from both the out-of-place (`oop`) and in-place (`iip`) functions.
oop, iip = Symbolics.build_function([sqrt(a), sin(b)], [a, b], nanmath = true)
@test all(isnan, eval(oop)([-1, Inf]))
# `[0, 0.0]` promotes to a Float64 vector, so it can hold the NaN results.
out = [0, 0.0]
eval(iip)(out, [-1, Inf])
@test all(isnan, out)

# Multiple argument matrix: a 3-element vector expression whose variables are
# spread across six separate argument vectors. The literal `0` entry is what the
# skipzeros/fillzeros tests below rely on.
h = [a + b + c1 + c2,
     c3 + d + e + g,
     0] # same six argument groups as `h_julia`/`h_julia!` below
"""
    h_julia(a, b, c, d, e, g)

Hand-written reference for `h`. Returns a new 3-element vector
`[a[1] + b[1] + c[1] + c[2], c[3] + d[1] + e[1] + g[1], 0]`.
"""
function h_julia(a, b, c, d, e, g)
    first_entry = a[1] + b[1] + c[1] + c[2]
    second_entry = c[3] + d[1] + e[1] + g[1]
    return [first_entry, second_entry, 0]
end
"""
    h_julia!(out, a, b, c, d, e, g)

In-place reference for `h`. Broadcast-assigns
`[a[1] + b[1] + c[1] + c[2], c[3] + d[1] + e[1] + g[1], 0]` into `out` and
returns `out`.
"""
function h_julia!(out, a, b, c, d, e, g)
    rhs = [a[1] + b[1] + c[1] + c[2],
           c[3] + d[1] + e[1] + g[1],
           0]
    out .= rhs
    return out
end

# Building the same expression twice must give identical code. For an array
# expression, `build_function` returns a pair: [1] out-of-place, [2] in-place.
h_str = Symbolics.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g])
h_str2 = Symbolics.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g])
@test h_str[1] == h_str2[1]
@test h_str[2] == h_str2[2]

# Variants of `h` built with different options:
# - h_str_par / h_par_rgf: multithreaded form, as expressions and (with
#   `expression=false`) as directly callable functions;
# - h_str_3 / h_str_4: `iip_config` disables one of the two functions (the
#   tests below expect the disabled one to throw ArgumentError);
# - h_ip_skip! / h_ip_skip_par!: `skipzeros=true, fillzeros=false`, so the
#   literal-zero entry of `h` is not written into the output buffer.
h_oop = eval(h_str[1])
h_str_par = Symbolics.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g], parallel=Symbolics.MultithreadedForm())
h_str_3 = Symbolics.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g], iip_config = (false, true))
h_str_4 = Symbolics.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g], iip_config = (true, false))

# The multithreaded form's generated code is expected to contain `schedule`.
@test contains(repr(h_str_par[1]), "schedule")
@test contains(repr(h_str_par[2]), "schedule")
h_oop_par = eval(h_str_par[1])
h_par_rgf = Symbolics.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g], parallel=Symbolics.MultithreadedForm(), expression=false)
h_ip! = eval(h_str[2])
h_ip_skip! = eval(Symbolics.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g], skipzeros=true, fillzeros=false)[2])
h_ip_skip_par! = eval(Symbolics.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g], skipzeros=true, parallel=Symbolics.MultithreadedForm(), fillzeros=false)[2])
h3_oop = eval(h_str_3[1])
h3_ip = eval(h_str_3[2])
h4_oop = eval(h_str_4[1])
h4_ip = eval(h_str_4[2])
# One input vector per argument group: a, b, (c1, c2, c3), d, e, g.
inputs = ([1], [2], [3, 4, 5], [6], [7], [8])

# Every generated variant of `h` must agree with the hand-written reference.
# With `inputs` the expected value is [1+2+3+4, 5+6+7+8, 0] == [10, 26, 0].
@test h_oop(inputs...) == h_julia(inputs...)
@test h_oop_par(inputs...) == h_julia(inputs...)
@test h_par_rgf[1](inputs...) == h_julia(inputs...)
# In-place buffers start from a sentinel (-1) rather than uninitialized memory,
# so each comparison proves every entry was actually written by the call.
out_1 = fill!(similar(h, Int), -1)
out_2 = similar(out_1)
h_ip!(out_1, inputs...)
h_julia!(out_2, inputs...)
@test out_1 == out_2
# iip_config = (false, true): out-of-place disabled, in-place enabled.
@test_throws ArgumentError h3_oop(inputs...)
# Reset first: `out_1` already equals `out_2` from the `h_ip!` call above, so
# without this the check would pass even if `h3_ip` wrote nothing.
fill!(out_1, -1)
h3_ip(out_1, inputs...)
@test out_1 == out_2
# iip_config = (true, false): in-place disabled, out-of-place enabled.
@test_throws ArgumentError h4_ip(out_1, inputs...)
@test h4_oop(inputs...) == h_julia(inputs...)
out_1 = fill!(similar(h, Int), -1)
h_par_rgf[2](out_1, inputs...)
@test out_1 == out_2
# skipzeros=true, fillzeros=false: the literal-zero entry h[3] must be left
# untouched (it keeps the fill value 10); all other entries must be written.
fill!(out_1, 10)
h_ip_skip!(out_1, inputs...)
@test out_1[3] == 10
out_1[3] = 0
@test out_1 == out_2

# Same check for the multithreaded skipzeros variant.
fill!(out_1, 10)
h_ip_skip_par!(out_1, inputs...)
@test out_1[3] == 10
out_1[3] = 0
@test out_1 == out_2

# Multiple input matrix, some unused arguments: `h_skip` does not depend on `d`
# or `e`, so those argument groups are passed as empty vectors below.
h_skip = [a + b + c1; c2 + c3 + g] # skip d, e
"""
    h_julia_skip(a, b, c, d, e, g)

Reference implementation of `h_skip`:
`vcat(a[1] + b[1] + c[1], c[2] + c[3] + g[1])`. The arguments `d` and `e` exist
only to match the generated function's signature and are never read.
"""
function h_julia_skip(a, b, c, d, e, g)
    upper = a[1] + b[1] + c[1]
    lower = c[2] + c[3] + g[1]
    return vcat(upper, lower)
end
"""
    h_julia_skip!(out, a, b, c, d, e, g)

In-place reference for `h_skip`. Broadcast-assigns
`vcat(a[1] + b[1] + c[1], c[2] + c[3] + g[1])` into `out` and returns `out`;
`d` and `e` are ignored.
"""
function h_julia_skip!(out, a, b, c, d, e, g)
    rhs = vcat(a[1] + b[1] + c[1], c[2] + c[3] + g[1])
    out .= rhs
    return out
end

# The empty `[]` argument groups stand in for the unused `d` and `e`. Both
# builds pass `checkbounds=true`; the `_cse` build additionally enables
# common-subexpression elimination.
h_str_skip = Symbolics.build_function(h_skip, [a], [b], [c1, c2, c3], [], [], [g], checkbounds=true)
h_str_skip_cse = Symbolics.build_function(h_skip, [a], [b], [c1, c2, c3], [], [], [g], checkbounds=true, cse=true)
h_oop_skip = eval(h_str_skip[1])
h_ip!_skip = eval(h_str_skip[2])
h_oop_skip_cse = eval(h_str_skip_cse[1])
h_ip!_skip_cse = eval(h_str_skip_cse[2])
inputs_skip = ([1], [2], [3, 4, 5], [], [], [8])

@test h_oop_skip(inputs_skip...) == h_julia_skip(inputs_skip...) == h_oop_skip_cse(inputs_skip...)
out_1_skip = Array{Int64}(undef, 2)
out_2_skip = similar(out_1_skip)
h_ip!_skip(out_1_skip, inputs_skip...)
h_julia_skip!(out_2_skip, inputs_skip...)
@test out_1_skip == out_2_skip
# The CSE in-place variant is built above but was never exercised; check that it
# agrees with the reference as well.
out_1_skip_cse = similar(out_1_skip)
h_ip!_skip_cse(out_1_skip_cse, inputs_skip...)
@test out_1_skip_cse == out_2_skip

# Same as above, but the unused `e` slot receives a NamedTuple instead of an
# array, checking that arguments the expression never reads may be non-array values.
inputs_skip_2 = ([1], [2], [3, 4, 5], [], (a = 1, b = 2), [8])
@test h_oop_skip(inputs_skip_2...) == h_julia_skip(inputs_skip_2...)
out_1_skip_2 = Array{Int64}(undef, 2)
out_2_skip_2 = similar(out_1_skip_2)
h_ip!_skip(out_1_skip_2, inputs_skip_2...)
h_julia_skip!(out_2_skip_2, inputs_skip_2...)
@test out_1_skip_2 == out_2_skip_2

# Multiple input scalar: a single scalar expression over all eight variables.
# For a scalar, `build_function` returns one expression (no oop/iip pair).
h_scalar = a + b + c1 + c2 + c3 + d + e + g
"""
    h_julia_scalar(a, b, c, d, e, g)

Reference implementation of `h_scalar`. Left-to-right sum of
`a[1], b[1], c[1], c[2], c[3], d[1], e[1], g[1]`.
"""
function h_julia_scalar(a, b, c, d, e, g)
    terms = (a[1], b[1], c[1], c[2], c[3], d[1], e[1], g[1])
    return foldl(+, terms)
end
# Scalar case: codegen must be identical across two builds, and the plain and
# CSE-enabled functions must both match the hand-written reference.
h_str_scalar = Symbolics.build_function(h_scalar, [a], [b], [c1, c2, c3], [d], [e], [g])
h_str_scalar2 = Symbolics.build_function(h_scalar, [a], [b], [c1, c2, c3], [d], [e], [g])
h_str_scalar_cse = Symbolics.build_function(h_scalar, [a], [b], [c1, c2, c3], [d], [e], [g], cse=true)
@test h_str_scalar == h_str_scalar2

h_oop_scalar = eval(h_str_scalar)
h_oop_scalar_cse = eval(h_str_scalar_cse)
@test h_oop_scalar(inputs...) == h_julia_scalar(inputs...) == h_oop_scalar_cse(inputs...)

# NOTE(review): this symbolic `z` is never used; `z` is rebound to a plain
# Float64 vector further down.
@variables z[1:100]
@variables t x(t) y(t) k
# Scalar output: a single function of the argument vector.
f = eval(build_function((x+y)/k, [x,y,k]))
@test f([1,1,2]) == 1

# Vector output, out-of-place function ([1]).
f = eval(build_function([(x+y)/k], [x,y,k])[1])
@test f([1,1,2]) == [1]

# Vector output, in-place function ([2]) writing into `z`.
f = eval(build_function([(x+y)/k], [x,y,k])[2])
z = [0.0]
f(z, [1,1,2])
@test z == [1]

# Sparse 10x10 output with a single symbolic entry at (1, 1).
f = eval(build_function(sparse([1],[1], [(x+y)/k], 10,10), [x,y,k])[1])

# The result keeps the 10x10 shape, holds 1.0 at (1, 1), and is zero elsewhere.
@test size(f([1.,1.,2])) == (10,10)
@test f([1.,1.,2])[1,1] == 1.0
@test sum(f([1.,1.,2])) == 1.0

# Reshaped SparseMatrix optimization: building from a reshaped sparse matrix
# should return a `Base.ReshapedArray` whose parent is a `SparseMatrixCSC` with
# the same row structure as the symbolic parent. Building from a view of it
# should return a plain `SparseMatrixCSC`.
let
    @variables a b c

    x = reshape(sparse([0 a 0; 0 b c]), 3, 2)
    f1,f2=build_function(x, [a,b,c], expression=Val{false})
    y = f1([1,2,3])
    @test y isa Base.ReshapedArray
    @test y.parent isa SparseMatrixCSC
    @test y.parent.rowval == x.parent.rowval
    @test y == [0 2; 0 0; 1 3]

    f1,f2=build_function(@views(x[2:3,1:2]), [a,b,c], expression=Val{false})
    y = f1([1,2,3])
    @test y isa SparseMatrixCSC
    @test y == [0 0; 1 3]
end

let # ModelingToolkit.jl#800
    # Generated code for a sparse (diagonal) matrix must not index with
    # `CartesianIndex`; the in-place version is expected to write via `.nzval`.
    @variables x
    y = sparse(1:3,1:3,x)

    f1,f2 = build_function(y,x)
    sf1, sf2 = string(f1), string(f2)
    @test !contains(sf1, "CartesianIndex")
    @test !contains(sf2, "CartesianIndex")
    @test contains(sf2, ".nzval")
end

let # Symbolics.jl#123
    # C-target codegen for an expression over indexed array variables, compared
    # against the stored reference file. Only `x` is passed as the argument
    # list, although `output_eq` itself uses `u`, `M`, and `qd`.
    ns = 6
    @variables x[1:ns]
    @variables u
    @variables M[1:36]
    @variables qd[1:6]
    output_eq = u*(qd[1]*(M[1]*qd[1] + M[7]*qd[2]))

    @test_reference "target_functions/issue123.c" build_function(output_eq, x, target=Symbolics.CTarget())
end

using Symbolics: value
using SymbolicUtils.Code: Func, toexpr
# A `Differential` term used directly as a `Func` argument: the function's first
# argument and the last expression of its body must be the same symbol,
# `var"Differential(t)(x(t))"`.
@variables t x(t)
D = Differential(t)
expr = toexpr(Func([value(D(x))], [], value(D(x))))
@test expr.args[2].args[end] == expr.args[1].args[1] # check function body and function arg
@test expr.args[2].args[end] == :(var"Differential(t)(x(t))")

## Out-of-place case with an array-expression argument:
# the argument list is `cos.(x)` itself, so calling the function with `a`
# substitutes `a` for `cos.(x)` and the result should be `sin.(a)`.
# Note: this rebinds the global `a` (previously a symbolic variable).

a = rand(4)
@variables x[1:4]
@test eval(build_function(sin.(cos.(x)), cos.(x))[1])(a) == sin.(a)

# more skipzeros: with `skipzeros=true`, the in-place function must not touch the
# output slot whose expression is the literal 0 (here `out[1]`).
@variables x,y
f = [0, x]
f_expr = build_function(f, [x,y];skipzeros=true, expression = Val{false})

out = Vector{Float64}(undef, 2)
u = [5.0, 3.1]
@test f_expr[1](u) == [0, 5]
# Remember the (uninitialized) value; `===` compares bits, so the check holds
# even if that value happens to be NaN.
old = out[1]
f_expr[2](out, u)
@test out[1] === old
@test out[2] === u[1]


let # issue#136
    # Build an N x N sparse tridiagonal matrix of monomials in x and y, then check
    # that the multithreaded in-place function (with and without CSE) fills a
    # sparse output buffer with A evaluated at x = 1, y = 2.
    N = 8
    @variables x y
    # Diagonal entry i is x^i * y^(N-i); the exponent uses N (rather than a
    # hard-coded 8) so the matrix stays consistent if N is changed.
    A = sparse(Tridiagonal([x^i for i in 1:N-1],
                           [x^i * y^(N-i) for i in 1:N],
                           [y^i for i in 1:N-1]))

    # Numeric reference: every entry of A evaluated at x = 1, y = 2.
    val = Dict(x=>1, y=>2)
    B = map(A) do e
        Num(substitute(e, val))
    end

    # Output buffers start out different from B (shifted diagonal), so the
    # in-place calls must overwrite them. `-` already allocates a new matrix,
    # so B needs no explicit copy. NOTE(review): only the diagonal differs from
    # B, so off-diagonal writes are not strictly verified.
    C = B - 100*I
    C_2 = copy(C)

    f = build_function(A,[x,y],parallel=Symbolics.MultithreadedForm())[2]
    g = eval(f)
    f_cse = build_function(A,[x,y],parallel=Symbolics.MultithreadedForm(),cse=true)[2]
    g_cse = eval(f_cse)

    g(C, [1,2])
    @test contains(repr(f), "schedule")
    @test isequal(C, B)
    g_cse(C_2, [1,2])
    @test isequal(C_2, B)
end


let #issue#587
    # The out-of-place function generated from a sparse Jacobian should return a
    # `SparseMatrixCSC` with as many stored entries as the symbolic Jacobian.
    using Symbolics, SparseArrays

    N = 100 # problem size (original note: try with N = 5 and N = 100)
    # NOTE(review): `sprand` is unseeded, so the sparsity pattern differs per run.
    _S = sprand(N, N, 0.1)
    _Q = Array(sprand(N, N, 0.1))

    # Stacked system: sparse linear part on top of a dense quadratic part
    # (2N outputs in N unknowns).
    F(z) = [
            _S * z
            _Q * z.^2
           ]

    Symbolics.@variables z[1:N]

    sj = Symbolics.sparsejacobian(F(z), z)

    f_expr = build_function(sj, z)
    myf = eval(first(f_expr))
    J = myf(rand(N))

    @test typeof(J) <: SparseMatrixCSC
    @test nnz(J) == nnz(sj)
end

# test header wrapping of scalar build function
let
    @variables x p t
    ex = t + p * x^2
    integrator = gensym(:MTKIntegrator)
    # `wrap_code` receives the generated function expression and rebuilds its
    # header: the first two arguments are kept, and the remaining ones are wrapped
    # in a `DestructuredArgs` (gensym'd name `integrator`, `inds = [:p]`).
    # The `let` captures the current `integrator` name inside the closure.
    header = expr -> let integrator = integrator
        Func([expr.args[1], expr.args[2], DestructuredArgs(expr.args[3:end], integrator,
                                                           inds = [:p])], [], expr.body)
    end
    f = build_function(ex, [value(x)], value(t), [value(p)]; expression = Val{false},
                                                                wrap_code = header)
    # 1 + 2 * 3^2 == 19: `p` is taken from the `p` field of the named tuple, and
    # field `a` is ignored.
    p = (a = 10, p = [2])
    @test f([3], 1, p) == 19
end

let #658
    # Two array variables and a scalar passed as separate arguments:
    # 1.5 * ones(3) + ones(3) == [2.5, 2.5, 2.5].
    using Symbolics
    @variables a, X1[1:3], X2[1:3]
    k = eval(build_function(a * X1 + X2, X1, X2, a)[1])
    @test k(ones(3), ones(3), 1.5) == [2.5, 2.5, 2.5]
end

# `toexpr` on an array operation (a broadcasted power of an array variable)
# must not emit any warnings.
@testset "ArrayOp codegen" begin
    @variables x[1:2]
    T = value(x .^ 2)
    @test_nowarn toexpr(T, NameState())
end

# Without `similarto`, calling the generated function on a Tuple input throws a
# MethodError. With `similarto = Array`, the same call returns the squared values.
@testset "`similarto` keyword argument" begin
    @variables x[1:2]
    T = collect(value(x .^ 2))
    fn = build_function(T, collect(x); expression = false)[1]
    @test_throws MethodError fn((1.0, 2.0))
    fn = build_function(T, collect(x); similarto = Array, expression = false)[1]
    @test fn((1.0, 2.0)) ≈ [1.0, 4.0]
end

# Slices of an array variable (plain, broadcast with `.+ 0.0`, and unwrapped)
# must all produce working out-of-place and in-place functions.
@testset "`build_function` with array symbolics" begin
    @variables x[1:4]
    for var in [x[1:2], x[1:2] .+ 0.0, Symbolics.unwrap(x[1:2])]
        foop, fiip = build_function(var[1:2], x; expression = false)
        @test foop(ones(4)) ≈ ones(2)
        buf = zeros(2)
        fiip(buf, ones(4))
        @test buf ≈ ones(2)
    end
end

# `cse` on an expression containing array operations and a callable symbolic
# must return a `Let` with at least one binding.
@testset "cse with arrayops" begin
    @variables x[1:3] y f(..)
    t = x .+ y
    t = t .* f(t)
    res = cse(value(t))
    @test res isa Let
    @test !isempty(res.pairs)
end

# A callable symbolic `f(..)` can be supplied at call time through
# `DestructuredArgs([f]; create_bindings = false)`. Here the call passes
# `[isodd]` for `f`, so the result is `isodd(3)`, which is true.
@testset "`CallWithMetadata` in `DestructuredArgs` with `create_bindings = false`" begin
    @variables x f(..)
    fn = build_function(f(x), DestructuredArgs([f]; create_bindings = false), x; expression = Val{false})
    @test fn([isodd], 3)
end

# `iip_config = (false, false)` disables both generated functions, including
# when they are returned as callables (`expression = Val{false}`). Calling
# either must throw ArgumentError, for scalar and array variables alike.
@testset "iip_config with RGF" begin
    @variables a b
    oop, iip = build_function([a + b, a - b], a, b; iip_config = (false, false), expression = Val{false})
    @test_throws ArgumentError oop(1, 2)
    @test_throws ArgumentError iip(ones(2), 1, 2)

    @variables a[1:2]
    oop, iip = build_function(a .* 2, a; iip_config = (false, false), expression = Val{false})
    @test_throws ArgumentError oop(ones(2))
    @test_throws ArgumentError iip(ones(2), ones(2))
end

# With `cse = true`, the repeated `a^2 + b^2` should be hoisted, so the last two
# arguments of the `create_array` call are symbols (CSE temporaries) rather than
# inline expressions.
@testset "unwrapping/CSE in array of symbolics codegen" begin
    @variables a b
    oop, _ = build_function([a^2 + b^2, a^2 + b^2], a, b; expression = Val{true}, cse = true)

    # Walk down through the last argument of each nested Expr until reaching a
    # `create_array` call (or a non-Expr value, which is returned as-is).
    function find_create_array(expr)
        while expr isa Expr && (!Meta.isexpr(expr, :call) || expr.args[1] != SymbolicUtils.Code.create_array)
            expr = expr.args[end]
        end
        return expr
    end

    expr = find_create_array(oop)
    # CSE works, we just need to test that it's happening and OOP is the easiest way to do it
    @test Meta.isexpr(expr, :call) && expr.args[1] == SymbolicUtils.Code.create_array &&
          expr.args[end] isa Symbol && expr.args[end-1] isa Symbol
end

# CSE must handle an operator application (`D(x)`) used as an argument:
# x + D(x) with [x, D(x)] = [1, 2] gives 3.
@testset "CSE with operators" begin
    @variables t x(t)
    D = Differential(t)
    f = build_function(x + D(x), [x, D(x)]; cse = true, expression = Val{false})
    @test f([1, 2]) == 3
end

# skipzeros with an `UpperTriangular` output: the in-place function must write
# the single strictly-upper entry (1, 2) = u[2] into an UpperTriangular buffer.
@testset "`build_function` with `UpperTriangular`" begin
    # Fills the 2x2 matrix J with [u1 u2; -u1 -u2].
    function f_test(J,u)
        J[1,1] = u[1]
        J[1,2] = u[2]
        J[2,1] = -u[1]
        J[2,2] = -u[2]
        return nothing
    end

    @variables u[1:2]
    J = fill!(Array{Num}(undef, 2, 2), 0)
    f_test(J, u)
    # Removing the diagonal leaves only J[1,2] = u[2] in the upper triangle.
    up_J = UpperTriangular(J - Diagonal(J))

    # `out` is the (unused) out-of-place function; only the in-place one is tested.
    out, fjac_upper_expr = build_function(up_J, u; skipzeros = true, expression = false)
    Jtmp = UpperTriangular(zeros(2, 2))
    utmp = rand(2)
    @test_nowarn fjac_upper_expr(Jtmp, utmp)
    # Linear index 3 is entry (1, 2) in column-major order.
    @test Jtmp[3] == utmp[2]
end
