use SimpleChains to train/run simple NNs #212

Closed
anandijain opened this issue Feb 8, 2023 · 8 comments

@anandijain

https://github.com/PumasAI/SimpleChains.jl

@anandijain (Author)

I didn't see a way to add comments to items on the project board without converting them to an issue first.

I just wanted to add the context that SimpleChains doesn't have GPU support, which means switching to it would excise the huge CUDA stack that currently gets compiled and loaded every time. Another alternative is waiting on FluxML/Flux.jl#2132.
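
For a sense of what the replacement would look like, here is a minimal sketch following the SimpleChains README API; the layer widths and the dummy data are made up, not taken from TGLFNN/EPEDNN:

using SimpleChains

# Hypothetical small dense surrogate; widths here are illustrative only
model = SimpleChain(
    static(4),                   # 4 input features
    TurboDense(tanh, 32),
    TurboDense(tanh, 32),
    TurboDense(identity, 2)      # 2 outputs
)

x = rand(Float32, 4, 1000)       # dummy training inputs
y = rand(Float32, 2, 1000)       # dummy training targets

p = SimpleChains.init_params(model)
loss = SimpleChains.add_loss(model, SquaredLoss(y))
G = SimpleChains.alloc_threaded_grad(model)

# CPU-only, multithreaded training; no CUDA stack involved anywhere
SimpleChains.train_unbatched!(G, p, loss, x, SimpleChains.ADAM(), 10_000)

model(x, p)                      # inference with the trained parameters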

@orso82 (Member) commented Feb 9, 2023

@anandijain do you know if Flux is the only package that has the CUDA dependency?

@anandijain (Author) commented Feb 9, 2023

I'm not sure; I can check before our next meeting.

The way to do it would be to use my code to generate the registry graph, then enumerate all paths from FUSE.jl to CUDA and see whether Flux is always an intermediate. However, just looking at the SVG, I did see a GA package that has CUDA as a direct dependency.

@anandijain (Author) commented Feb 14, 2023

Okay, using

using MyPkgGraph, Catlab, Catlab.Graphs, Catlab.Graphics, Graphs

# Render the dependency graph of package `s` in registry `r` with Graphviz
draw(r, s; dir=:out) = to_graphviz(MyPkgGraph.my_depgraph(r, s; dir); node_labels=:label)
draw(s; dir=:out) = draw(MyPkgGraph.GENERAL_REGISTRY, s; dir)

# Merge the GA registry with General so FUSE's full dependency closure is visible
r1 = MyPkgGraph.REGISTRIES[findfirst(reg -> reg.name == "GAregistry", MyPkgGraph.REGISTRIES)]
r2 = MyPkgGraph.GENERAL_REGISTRY
r3 = MyPkgGraph.merge_registries(r1, r2)
g = MyPkgGraph.registry_graph(r3)
draw(r3, "FUSE")

# Convert the registry graph to a plain Graphs.jl digraph
dg = Graphs.SimpleDiGraph(g)

# all_simple_paths is vendored from https://github.com/JuliaGraphs/Graphs.jl/pull/20/files
include(joinpath(@__DIR__, "all_simple_paths.jl"))
using Catlab.Theories, Catlab.CategoricalAlgebra

# Look up the vertex ids of the FUSE and CUDA nodes by label
fuse = incident(g, "FUSE", :label)
cuda = incident(g, "CUDA", :label)

# Enumerate every simple path from FUSE to CUDA, mapping vertex ids back to package names
ps = collect(all_simple_paths(dg, fuse, cuda))
map(p -> g[p, :label], ps)

This enumerates all paths from FUSE to CUDA, giving:

14-element Vector{SubArray{String, 1, Catlab.Columns.ColumnView{Int64, String, Catlab.ColumnImplementations.DenseInjectiveColumn{String, Vector{String}}, UnitRange{Int64}, Nothing}, Tuple{Vector{Int64}}, false}}:
 ["FUSE", "TAUENN", "TGLFNN", "CUDA"]
 ["FUSE", "TAUENN", "TGLFNN", "Flux", "CUDA"]
 ["FUSE", "TAUENN", "TGLFNN", "Flux", "NNlibCUDA", "cuDNN", "CUDA"]
 ["FUSE", "TAUENN", "TGLFNN", "Flux", "NNlibCUDA", "CUDA"]
 ["FUSE", "TAUENN", "EPEDNN", "Flux", "CUDA"]
 ["FUSE", "TAUENN", "EPEDNN", "Flux", "NNlibCUDA", "cuDNN", "CUDA"]
 ["FUSE", "TAUENN", "EPEDNN", "Flux", "NNlibCUDA", "CUDA"]
 ["FUSE", "TGLFNN", "CUDA"]
 ["FUSE", "TGLFNN", "Flux", "CUDA"]
 ["FUSE", "TGLFNN", "Flux", "NNlibCUDA", "cuDNN", "CUDA"]
 ["FUSE", "TGLFNN", "Flux", "NNlibCUDA", "CUDA"]
 ["FUSE", "EPEDNN", "Flux", "CUDA"]
 ["FUSE", "EPEDNN", "Flux", "NNlibCUDA", "cuDNN", "CUDA"]
 ["FUSE", "EPEDNN", "Flux", "NNlibCUDA", "CUDA"]

So it isn't the case that the only path to CUDA is via Flux; the exception is TGLFNN, which depends on CUDA directly.

Here is the SVG of the subgraph of the nodes contained in the above lists:
[image: cuda_subgraph]
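
For reference, the all_simple_paths primitive used above (vendored from the Graphs.jl PR linked in the code, and later merged upstream) behaves like this on a made-up toy graph mirroring the finding:

using Graphs

# Toy digraph: 1 = FUSE, 2 = TGLFNN, 3 = Flux, 4 = CUDA (labels are illustrative)
g = SimpleDiGraph(4)
add_edge!(g, 1, 2)  # FUSE -> TGLFNN
add_edge!(g, 2, 3)  # TGLFNN -> Flux
add_edge!(g, 3, 4)  # Flux -> CUDA
add_edge!(g, 2, 4)  # TGLFNN -> CUDA (the direct dependency)

# Both routes to CUDA show up: through Flux, and direct from TGLFNN
collect(all_simple_paths(g, 1, 4))  # [[1, 2, 3, 4], [1, 2, 4]]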

@orso82 (Member) commented Feb 14, 2023

Ok, but TGLFNN depends on CUDA only because of this function, which is used only during training with Flux:

# Return the Flux device-transfer function to use for training
function on_device(args)
    if CUDA.functional() && args.use_cuda
        @debug "Training on CUDA GPU"
        CUDA.allowscalar(false)  # disallow slow scalar indexing on the GPU
        return Flux.gpu
    else
        @debug "Training on CPU"
        return Flux.cpu
    end
end

The direct CUDA dependency can probably be removed by calling Flux.CUDA.functional() and Flux.CUDA.allowscalar instead.
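
Something like this might work (an untested sketch; it assumes Flux still carries CUDA as a direct dependency reachable as Flux.CUDA, which was the case before Flux moved CUDA support into a package extension):

# Hypothetical refactor: reach CUDA through Flux so TGLFNN/EPEDNN
# do not need to list CUDA in their own Project.toml
function on_device(args)
    if Flux.CUDA.functional() && args.use_cuda
        @debug "Training on CUDA GPU"
        Flux.CUDA.allowscalar(false)
        return Flux.gpu
    else
        @debug "Training on CPU"
        return Flux.cpu
    end
end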

Meaning that if we removed the Flux dependence for the TGLFNN and EPEDNN training, we could drop the CUDA dependence entirely.

@anandijain (Author)

Yes, that's correct.

@anandijain (Author)

I looked at the size of your NNs, and they do seem like the right size for SimpleChains to be fast. I was talking to Chris (the dev of SimpleChains), and he cautioned that it works but is a bit limited in features and documentation.

@orso82 (Member) commented Jul 18, 2023

After discussing with @ChrisRackauckas, we have decided this is not worth the effort at this time.

orso82 closed this as completed Jul 18, 2023