Skip to content

Commit

Permalink
pseudobulk unit tests and bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rasmushenningsson committed Dec 12, 2023
1 parent 46af98f commit a96cfec
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 5 deletions.
10 changes: 7 additions & 3 deletions src/pseudobulk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct PseudoBulkModel <: ProjectionModel
end
end
PseudoBulkModel(obs_col, args...; merged_id="id", delim='_', var=:copy) =
PseudoBulkModel(collect(obs_col, args...), merged_id, delim, var)
PseudoBulkModel(collect((obs_col, args...)), merged_id, delim, var)

projection_isequal(m1::PseudoBulkModel, m2::PseudoBulkModel) =
m1.obs_id_cols == m2.obs_id_cols &&
Expand Down Expand Up @@ -61,11 +61,15 @@ function project_impl(data::DataMatrix, model::PseudoBulkModel; verbose=true)
dropmissing!(obs) # This is the new obs annotation

# Find out which group each cell belongs to
obs_ind = table_indexin(data.obs, obs)
obs_ind = table_indexin(data.obs, obs; matchmissing=:equal)

# Create sparse matrix mapping cells to groups
N = size(data,2)
S = sparse(1:N, obs_ind, 1.0, N, size(obs,1))

mask = obs_ind .!== nothing
I = (1:N)[mask]
J = identity.(obs_ind[mask])
S = sparse(I, J, 1.0, N, size(obs,1))

# Make each column sum to one (so that we take mean for each group)
rmul!(S, Diagonal(1.0 ./ vec(sum(S; dims=1))))
Expand Down
4 changes: 2 additions & 2 deletions src/table_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ end
table_cols_equal(a, b; cols=names(b)) =
isequal(select(a, cols; copycols=false), select(b, cols; copycols=false))

function table_indexin(a, b; cols=names(b))
function table_indexin(a, b; cols=names(b), kwargs...)
b = select(b, cols; copycols=false)
a = select(a, cols; copycols=false)
b.__index__ .= 1:size(b,1)
leftjoin!(a,b; on=cols)
leftjoin!(a,b; on=cols, kwargs...)
coalesce.(a.__index__, nothing)
end
59 changes: 59 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,4 +414,63 @@

test_show(c; obs=vcat(names(counts.obs), ["A","B","C"]), models="VarCountsFractionModel")
end

@testset "pseudobulk $name" for (name,data,data_proj) in (("counts",counts,counts_proj), ("transformed",transformed,transformed_proj))
# TODO: Make pseudobulk work with data.matrix::Factorization
# @testset "pseudobulk $name" for (name,data,data_proj) in (("counts",counts,counts_proj), ("transformed",transformed,transformed_proj), ("reduced",reduced,reduced_proj))
d = copy(data)
d.obs.group2 = replace(d.obs.group, "C"=>missing)
d.obs.group3 = rand(StableRNG(276), ("a","b"), size(data,2))
d.obs.twogroup = replace(d.obs.group, "C"=>"A")
X = materialize(d.matrix)

@testset "$annot" for annot in ("group","group2","group3","twogroup")
unique_groups = collect(skipmissing(unique!(sort(d.obs[!,annot]))))

pb = pseudobulk(d, annot)
@test names(pb.obs) == ["id", annot]
@test unique!(sort(pb.obs.id)) == unique_groups
@test unique!(sort(pb.obs[!,annot])) == unique_groups

pb_X = materialize(pb.matrix)
@test size(pb_X,1) == size(X,1)
@test size(pb_X,2) == length(unique_groups)

for g in unique_groups
x = vec(mean(X[:, isequal.(d.obs[!,annot], g)]; dims=2))
gi = findfirst(isequal(g), pb.obs.id)

@test x pb_X[:,gi]
end
end

@testset "$annot1, $annot2" for (annot1,annot2) in (("group","group3"),("group","group2"),("group","twogroup"))
groups = string.(d.obs[!,annot1],'_',d.obs[!,annot2])
mask = .!ismissing.(d.obs[!,annot1]) .& .!ismissing.(d.obs[!,annot2])
unique_groups = unique!(sort!(groups[mask]))

pb = pseudobulk(d, annot1, annot2)
@test names(pb.obs) == ["id", annot1, annot2]
@test unique!(sort(pb.obs.id)) == unique_groups
@test unique!(sort(pb.obs[!,annot1])) == unique!(sort!(d.obs[mask,annot1]))
@test unique!(sort(pb.obs[!,annot2])) == unique!(sort!(d.obs[mask,annot2]))


pb_X = materialize(pb.matrix)
@test size(pb_X,1) == size(X,1)
@test size(pb_X,2) == length(unique_groups)

for g1 in unique(d.obs[mask,annot1]), g2 in unique(d.obs[mask,annot2])
g_mask = isequal.(d.obs[!,annot1], g1) .& isequal.(d.obs[!,annot2], g2)
x = vec(mean(X[:, g_mask]; dims=2))
gi = findfirst(isequal(string(g1,'_',g2)), pb.obs.id)
if any(g_mask) # are there any observations in this group?
@test x pb_X[:,gi]
else
@test gi === nothing
end
end
end

end
end

0 comments on commit a96cfec

Please sign in to comment.