Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor refactoring of the FFTs #996

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ include("SymOp.jl")

export Smearing
export Model
export FFTGrid
export MonkhorstPack, ExplicitKpoints
export PlaneWaveBasis
export compute_fft_size
Expand All @@ -85,8 +86,9 @@ include("Smearing.jl")
include("Model.jl")
include("structure.jl")
include("bzmesh.jl")
include("PlaneWaveBasis.jl")
include("Kpoint.jl")
include("fft.jl")
include("PlaneWaveBasis.jl")
include("orbitals.jl")
include("input_output.jl")

Expand Down
94 changes: 94 additions & 0 deletions src/Kpoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
G_vectors(fft_size::Tuple)


of given sizes.
"""
function G_vectors(fft_size::Union{Tuple,AbstractVector})
# Note that a collect(G_vectors_generator(fft_size)) is 100-fold slower
# than this implementation, hence the code duplication.
start = .- cld.(fft_size .- 1, 2)
stop = fld.(fft_size .- 1, 2)
axes = [[collect(0:stop[i]); collect(start[i]:-1)] for i = 1:3]
[Vec3{Int}(i, j, k) for i in axes[1], j in axes[2], k in axes[3]]
end

function G_vectors_generator(fft_size::Union{Tuple,AbstractVector})
# The generator version is used mainly in symmetry.jl for lowpass_for_symmetry! and
# accumulate_over_symmetries!, which are 100-fold slower with G_vector(fft_size).
start = .- cld.(fft_size .- 1, 2)
stop = fld.(fft_size .- 1, 2)
axes = [[collect(0:stop[i]); collect(start[i]:-1)] for i = 1:3]
(Vec3{Int}(i, j, k) for i in axes[1], j in axes[2], k in axes[3])
end

"""
Discretization information for ``k``-point-dependent quantities such as orbitals.
More generally, a ``k``-point is a block of the Hamiltonian;
e.g. collinear spin is treated by doubling the number of ``k``-points.
"""
struct Kpoint{T <: Real, GT <: AbstractVector{Vec3{Int}}}
spin::Int # Spin component can be 1 or 2 as index into what is
# # returned by the `spin_components` function
coordinate::Vec3{T} # Fractional coordinate of k-point
G_vectors::GT # Wave vectors in integer coordinates (vector of Vec3{Int})
# # ({G, 1/2 |k+G|^2 ≤ Ecut})
# This is not assumed to be in any particular order
mapping::Vector{Int} # Index of G_vectors[i] on the FFT grid:
# # G_vectors(basis)[kpt.mapping[i]] == G_vectors(basis, kpt)[i]
mapping_inv::Dict{Int, Int} # Inverse of `mapping`:
# # G_vectors(basis)[i] == G_vectors(basis, kpt)[mapping_inv[i]]
end

function Kpoint(spin::Integer, coordinate::AbstractVector{<:Real},
recip_lattice::AbstractMatrix{T}, fft_size, Ecut;
variational=true, architecture::AbstractArchitecture) where {T}
mapping = Int[]
Gvecs_k = Vec3{Int}[]
k = Vec3{T}(coordinate)
# provide a rough hint so that the arrays don't have to be resized so much
n_guess = div(prod(fft_size), 8)
sizehint!(mapping, n_guess)
sizehint!(Gvecs_k, n_guess)
for (i, G) in enumerate(G_vectors(fft_size))
if !variational || norm2(recip_lattice * (G + k)) / 2 ≤ Ecut
push!(mapping, i)
push!(Gvecs_k, G)
end
end
Gvecs_k = to_device(architecture, Gvecs_k)

mapping_inv = Dict(ifull => iball for (iball, ifull) in enumerate(mapping))
Kpoint(spin, k, Gvecs_k, mapping, mapping_inv)
end

# Construct the kpoint with coordinate equivalent_kpt.coordinate + ΔG.
# Equivalent to (but faster than) Kpoint(equivalent_kpt.coordinate + ΔG).
function construct_from_equivalent_kpt(fft_size, equivalent_kpt, coordinate, ΔG)
linear = LinearIndices(fft_size)
# Mapping is the same as if created from scratch, although it is not ordered.
mapping = map(CartesianIndices(fft_size)[equivalent_kpt.mapping]) do G
linear[CartesianIndex(mod1.(Tuple(G + CartesianIndex(ΔG...)), fft_size))]
end
mapping_inv = Dict(ifull => iball for (iball, ifull) in enumerate(mapping))
Kpoint(equivalent_kpt.spin, Vec3(coordinate), equivalent_kpt.G_vectors .+ Ref(ΔG),
mapping, mapping_inv)
end

@timing function build_kpoints(model::Model{T}, fft_size, kcoords, Ecut;
variational=true,
architecture::AbstractArchitecture) where {T}
# Build all k-points for the first spin

Check warning on line 81 in src/Kpoint.jl

View check run for this annotation

Codecov / codecov/patch

src/Kpoint.jl#L78-L81

Added lines #L78 - L81 were not covered by tests
kpoints_spin_1 = [Kpoint(1, k, model.recip_lattice, fft_size, Ecut;
variational, architecture)
for k in kcoords]
all_kpoints = similar(kpoints_spin_1, 0)
for iσ = 1:model.n_spin_components
for kpt in kpoints_spin_1
push!(all_kpoints, Kpoint(iσ, kpt.coordinate,
kpt.G_vectors, kpt.mapping, kpt.mapping_inv))
end
end
all_kpoints
end

174 changes: 43 additions & 131 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,6 @@ abstract type AbstractBasis{T <: Real} end
# products of orbitals. This also defines the real-space grid
# (as the dual of the cubic basis set).

"""
Discretization information for ``k``-point-dependent quantities such as orbitals.
More generally, a ``k``-point is a block of the Hamiltonian;
e.g. collinear spin is treated by doubling the number of ``k``-points.
"""
struct Kpoint{T <: Real, GT <: AbstractVector{Vec3{Int}}}
spin::Int # Spin component can be 1 or 2 as index into what is
# # returned by the `spin_components` function
coordinate::Vec3{T} # Fractional coordinate of k-point
G_vectors::GT # Wave vectors in integer coordinates (vector of Vec3{Int})
# # ({G, 1/2 |k+G|^2 ≤ Ecut})
# This is not assumed to be in any particular order
mapping::Vector{Int} # Index of G_vectors[i] on the FFT grid:
# # G_vectors(basis)[kpt.mapping[i]] == G_vectors(basis, kpt)[i]
mapping_inv::Dict{Int, Int} # Inverse of `mapping`:
# # G_vectors(basis)[i] == G_vectors(basis, kpt)[mapping_inv[i]]
end

@doc raw"""
A plane-wave discretized `Model`.
Normalization conventions:
Expand Down Expand Up @@ -63,20 +45,8 @@ struct PlaneWaveBasis{T,
variational::Bool # Is the k-point specific basis variationally consistent with
# the basis used for the density / potential?

## Plans for forward and backward FFT
# All these plans are *completely unnormalized* (eg FFT * BFFT != I)
# The normalizations are performed in ifft/fft according to
# the DFTK conventions (see above)
opFFT # out-of-place FFT plan
ipFFT # in-place FFT plan
opBFFT # inverse plans (unnormalized plan; backward in FFTW terminology)
ipBFFT
fft_normalization::T # fft = fft_normalization * FFT
ifft_normalization::T # ifft = ifft_normalization * BFFT

# "cubic" basis in reciprocal and real space, on which potentials and densities are stored
G_vectors::T_G_vectors
r_vectors::T_r_vectors
# A FFTGrid containing all necessary data for FFT opertations related to this basis
fft_grid::FFTGrid{T, VT, T_G_vectors, T_r_vectors}

## MPI-local information of the kpoints this processor treats
# Irreducible kpoints. In the case of collinear spin,
Expand Down Expand Up @@ -127,43 +97,12 @@ Base.Broadcast.broadcastable(basis::PlaneWaveBasis) = Ref(basis)
Base.eltype(::PlaneWaveBasis{T}) where {T} = T


function Kpoint(spin::Integer, coordinate::AbstractVector{<:Real},
recip_lattice::AbstractMatrix{T}, fft_size, Ecut;
variational=true, architecture::AbstractArchitecture) where {T}
mapping = Int[]
Gvecs_k = Vec3{Int}[]
k = Vec3{T}(coordinate)
# provide a rough hint so that the arrays don't have to be resized so much
n_guess = div(prod(fft_size), 8)
sizehint!(mapping, n_guess)
sizehint!(Gvecs_k, n_guess)
for (i, G) in enumerate(G_vectors(fft_size))
if !variational || norm2(recip_lattice * (G + k)) / 2 ≤ Ecut
push!(mapping, i)
push!(Gvecs_k, G)
end
end
Gvecs_k = to_device(architecture, Gvecs_k)

mapping_inv = Dict(ifull => iball for (iball, ifull) in enumerate(mapping))
Kpoint(spin, k, Gvecs_k, mapping, mapping_inv)
end
function Kpoint(basis::PlaneWaveBasis, coordinate::AbstractVector, spin::Int)
Kpoint(spin, coordinate, basis.model.recip_lattice, basis.fft_size, basis.Ecut;
basis.variational, basis.architecture)
end
# Construct the kpoint with coordinate equivalent_kpt.coordinate + ΔG.
# Equivalent to (but faster than) Kpoint(equivalent_kpt.coordinate + ΔG).
function construct_from_equivalent_kpt(basis, equivalent_kpt, coordinate, ΔG)
linear = LinearIndices(basis.fft_size)
# Mapping is the same as if created from scratch, although it is not ordered.
mapping = map(CartesianIndices(basis.fft_size)[equivalent_kpt.mapping]) do G
linear[CartesianIndex(mod1.(Tuple(G + CartesianIndex(ΔG...)), basis.fft_size))]
end
mapping_inv = Dict(ifull => iball for (iball, ifull) in enumerate(mapping))
Kpoint(equivalent_kpt.spin, Vec3(coordinate), equivalent_kpt.G_vectors .+ Ref(ΔG),
mapping, mapping_inv)
end


# Returns the kpoint at given coordinate. If outside the Brillouin zone, it is created
# from an equivalent kpoint in the basis (also returned)
function get_kpoint(basis::PlaneWaveBasis{T}, kcoord, spin) where {T}
Expand All @@ -172,28 +111,11 @@ function get_kpoint(basis::PlaneWaveBasis{T}, kcoord, spin) where {T}
if iszero(ΔG)
kpt = equivalent_kpt
else
kpt = construct_from_equivalent_kpt(basis, equivalent_kpt, kcoord, ΔG)
kpt = construct_from_equivalent_kpt(basis.fft_size, equivalent_kpt, kcoord, ΔG)
end
(; kpt, equivalent_kpt)
end

@timing function build_kpoints(model::Model{T}, fft_size, kcoords, Ecut;
variational=true,
architecture::AbstractArchitecture) where {T}
# Build all k-points for the first spin
kpoints_spin_1 = [Kpoint(1, k, model.recip_lattice, fft_size, Ecut;
variational, architecture)
for k in kcoords]
all_kpoints = similar(kpoints_spin_1, 0)
for iσ = 1:model.n_spin_components
for kpt in kpoints_spin_1
push!(all_kpoints, Kpoint(iσ, kpt.coordinate,
kpt.G_vectors, kpt.mapping, kpt.mapping_inv))
end
end
all_kpoints
end

# Lowest-level constructor, should not be called directly.
# All given parameters must be the same on all processors
# and are stored in PlaneWaveBasis for easy reconstruction.
Expand Down Expand Up @@ -249,17 +171,7 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Real, fft_size::Tuple{Int, Int, I
kweights_global = convert(Vector{T}, kdata.kweights)

# Setup FFT plans
Gs = to_device(architecture, G_vectors(fft_size))
(ipFFT, opFFT, ipBFFT, opBFFT) = build_fft_plans!(similar(Gs, Complex{T}, fft_size))

# Normalization constants
# fft = fft_normalization * FFT
# The convention we want is
# ψ(r) = sum_G c_G e^iGr / sqrt(Ω)
# so that the ifft has to normalized by 1/sqrt(Ω).
# The other constant is chosen because FFT * BFFT = N
ifft_normalization = 1/sqrt(model.unit_cell_volume)
fft_normalization = sqrt(model.unit_cell_volume) / length(ipFFT)
fft_grid = FFTGrid(fft_size, model.unit_cell_volume, architecture)

# Compute k-point information and spread them across processors
# Right now we split only the kcoords: both spin channels have to be handled
Expand Down Expand Up @@ -325,20 +237,14 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Real, fft_size::Tuple{Int, Int, I
error("Can't mix multi-threading and GPU computations yet.")
end

VT = value_type(T)
dvol = model.unit_cell_volume ./ prod(fft_size)
r_vectors = [(Vec3{VT}(idx.I) .- (1, 1, 1)) ./ VT.(fft_size)
for idx in CartesianIndices(fft_size)]
r_vectors = to_device(architecture, r_vectors)
terms = Vector{Any}(undef, length(model.term_types)) # Dummy terms array, filled below

basis = PlaneWaveBasis{T, value_type(T), Arch, typeof(Gs), typeof(r_vectors),
typeof(kpoints[1].G_vectors)}(
basis = PlaneWaveBasis{T, value_type(T), Arch, typeof(fft_grid.G_vectors),
typeof(fft_grid.r_vectors), typeof(kpoints[1].G_vectors)}(
model, fft_size, dvol,
Ecut, variational,
opFFT, ipFFT, opBFFT, ipBFFT,
fft_normalization, ifft_normalization,
Gs, r_vectors,
fft_grid,
kpoints, kweights, kgrid,
kcoords_global, kweights_global,
comm_kpts, krange_thisproc, krange_allprocs, krange_thisproc_allspin,
Expand Down Expand Up @@ -413,30 +319,6 @@ e.g. an [`MonkhorstPack`](@ref) or a [`ExplicitKpoints`](@ref) grid.
basis.comm_kpts, basis.architecture)
end

"""
G_vectors(fft_size::Tuple)

The wave vectors `G` in reduced (integer) coordinates for a cubic basis set
of given sizes.
"""
function G_vectors(fft_size::Union{Tuple,AbstractVector})
# Note that a collect(G_vectors_generator(fft_size)) is 100-fold slower
# than this implementation, hence the code duplication.
start = .- cld.(fft_size .- 1, 2)
stop = fld.(fft_size .- 1, 2)
axes = [[collect(0:stop[i]); collect(start[i]:-1)] for i = 1:3]
[Vec3{Int}(i, j, k) for i in axes[1], j in axes[2], k in axes[3]]
end

function G_vectors_generator(fft_size::Union{Tuple,AbstractVector})
# The generator version is used mainly in symmetry.jl for lowpass_for_symmetry! and
# accumulate_over_symmetries!, which are 100-fold slower with G_vector(fft_size).
start = .- cld.(fft_size .- 1, 2)
stop = fld.(fft_size .- 1, 2)
axes = [[collect(0:stop[i]); collect(start[i]:-1)] for i = 1:3]
(Vec3{Int}(i, j, k) for i in axes[1], j in axes[2], k in axes[3])
end


@doc raw"""
G_vectors(basis::PlaneWaveBasis)
Expand All @@ -445,11 +327,9 @@ end
The list of wave vectors ``G`` in reduced (integer) coordinates of a `basis`
or a ``k``-point `kpt`.
"""
G_vectors(basis::PlaneWaveBasis) = basis.G_vectors
G_vectors(basis::PlaneWaveBasis) = basis.fft_grid.G_vectors
G_vectors(::PlaneWaveBasis, kpt::Kpoint) = kpt.G_vectors



@doc raw"""
G_vectors_cart(basis::PlaneWaveBasis)
G_vectors_cart(basis::PlaneWaveBasis, kpt::Kpoint)
Expand Down Expand Up @@ -487,7 +367,7 @@ end

The list of ``r`` vectors, in reduced coordinates. By convention, this is in [0,1)^3.
"""
r_vectors(basis::PlaneWaveBasis) = basis.r_vectors
r_vectors(basis::PlaneWaveBasis) = basis.fft_grid.r_vectors

@doc raw"""
r_vectors_cart(basis::PlaneWaveBasis)
Expand Down Expand Up @@ -665,3 +545,35 @@ function scatter_kpts_block(basis::PlaneWaveBasis, data::Union{Nothing,AbstractA
splitted
end
end

"""
Forward FFT calls to the PlaneWaveBasis fft_grid field
"""
ifft!(f_real::AbstractArray3, basis::PlaneWaveBasis, f_fourier::AbstractArray3) =
ifft!(f_real, basis.fft_grid, f_fourier)

ifft!(f_real::AbstractArray3, basis::PlaneWaveBasis, kpt::Kpoint,
f_fourier::AbstractVector; normalize=true) =
ifft!(f_real, basis.fft_grid, kpt, f_fourier; normalize=normalize)

ifft(basis::PlaneWaveBasis, f_fourier::AbstractArray) = ifft(basis.fft_grid, f_fourier)

ifft(basis::PlaneWaveBasis, kpt::Kpoint, f_fourier::AbstractVector; kwargs...) =
ifft(basis.fft_grid, kpt, f_fourier; kwargs ...)

irfft(basis::PlaneWaveBasis, f_fourier::AbstractArray) = irfft(basis.fft_grid, f_fourier)

fft!(f_fourier::AbstractArray3, basis::PlaneWaveBasis, f_real::AbstractArray3) =
fft!(f_fourier, basis.fft_grid, f_real)

fft!(f_fourier::AbstractVector, basis::PlaneWaveBasis, kpt::Kpoint,
f_real::AbstractArray3; normalize=true) =
fft!(f_fourier, basis.fft_grid, kpt, f_real; normalize=normalize)

fft(basis::PlaneWaveBasis, f_real::AbstractArray) = fft(basis.fft_grid, f_real)

fft(basis::PlaneWaveBasis, kpt::Kpoint, f_real::AbstractArray3; kwargs...) =
fft(basis.fft_grid, kpt, f_real; kwargs...)

ifft_matrix(basis::PlaneWaveBasis) = ifft_matrix(basis.fft_grid)
fft_matrix(basis::PlaneWaveBasis) = fft_matrix(basis.fft_grid)
Loading
Loading