diff --git a/src/Basemesh.jl b/src/Basemesh.jl index d70fc16..513ca1a 100644 --- a/src/Basemesh.jl +++ b/src/Basemesh.jl @@ -13,7 +13,7 @@ function _compute_inverse_lattice(lattice) return inv(lattice) end -function _compute_recip_lattice(lattice) +function _compute_recip_lattice(lattice::Matrix{T}) where {T} return 2T(π) * _compute_inverse_lattice(lattice) end @@ -25,22 +25,21 @@ struct Brillouin{T,DIM} unit_cell_volume::T recip_cell_volume::T G_vector::Vector{SVector{DIM,Int}} +end - function Brillouin(; lattice::Matrix{T}, G_vector=nothing) where {T} - DIM = size(lattice, 1) - recip_lattice = _compute_recip_lattice(lattice) - inv_lattice = _compute_inverse_lattice(lattice) - inv_recip_lattice = _compute_inverse_lattice(recip_lattice) - unit_cell_volume = abs(det(lattice)) - recip_cell_volume = abs(det(recip_lattice)) - if isnothing(G_vector) - G_vector = [SVector{DIM,Int}(zeros(DIM)),] - end - - return new{T,DIM}(lattice, recip_lattice, inv_lattice, inv_recip_lattice, unit_cell_volume, recip_cell_volume, G_vector) +function Brillouin(; lattice::Matrix{T}, G_vector=nothing) where {T} + DIM = size(lattice, 1) + recip_lattice = _compute_recip_lattice(lattice) + inv_lattice = _compute_inverse_lattice(lattice) + inv_recip_lattice = _compute_inverse_lattice(recip_lattice) + unit_cell_volume = abs(det(lattice)) + recip_cell_volume = abs(det(recip_lattice)) + if isnothing(G_vector) + G_vector = [SVector{DIM,Int}(zeros(DIM)),] end -end + return Brillouin{T,DIM}(lattice, recip_lattice, inv_lattice, inv_recip_lattice, unit_cell_volume, recip_cell_volume, G_vector) +end struct UniformBZMesh{T,DIM} <: AbstractMesh{T,DIM} br::Brillouin{T,DIM} @@ -49,6 +48,11 @@ struct UniformBZMesh{T,DIM} <: AbstractMesh{T,DIM} shift::SVector{DIM,Rational} end +# default shift is 1/2, result in Monkhorst-Pack mesh +# with shift = 0, result in Gamma-centered +# can also customize with shift::SVector by calling default constructor +UniformBZMesh(; br::Brillouin{T,DIM}, size, shift::Number=1 // 2) where {T,DIM} = UniformBZMesh{T,DIM}(br, size, SVector{DIM,Rational}(shift .* ones(Int, DIM))) + Base.length(mesh::UniformBZMesh) = prod(mesh.size) Base.size(mesh::UniformBZMesh) = mesh.size Base.size(mesh::UniformBZMesh, I) = mesh.size[I] @@ -57,6 +61,81 @@ function Base.show(io::IO, mesh::UniformBZMesh) println("UniformBZMesh with $(length(mesh)) mesh points") end +@generated function _inds2ind(size::NTuple{DIM,Int}, I) where {DIM} + ex = :(I[DIM] - 1) + for i = (DIM-1):-1:1 + ex = :(I[$i] - 1 + size[$i] * $ex) + end + return :($ex + 1) +end + +@generated function _ind2inds(size::NTuple{DIM,Int}, I::Int) where {DIM} + inds, quotient = :((I - 1) % size[1] + 1), :((I - 1) ÷ size[1]) + for i = 2:DIM-1 + inds, quotient = :($inds..., $quotient % size[$i] + 1), :($quotient ÷ size[$i]) + end + inds = :($inds..., $quotient + 1) + return :(SVector{DIM,Int}($inds)) +end + +function Base.getindex(mesh::UniformBZMesh{T,DIM}, inds...) where {T,DIM} + n = SVector{DIM,Int}(inds) + return mesh.br.recip_lattice * ((n .- 1 .+ mesh.shift) ./ mesh.size) +end + +function Base.getindex(mesh::UniformBZMesh, I::Int) + return Base.getindex(mesh, _ind2inds(mesh.size, I)...) +end + +Base.firstindex(mesh::UniformBZMesh) = 1 +Base.lastindex(mesh::UniformBZMesh) = length(mesh) +# # iterator +Base.iterate(mesh::UniformBZMesh) = (mesh[1], 1) +Base.iterate(mesh::UniformBZMesh, state) = (state >= length(mesh)) ? nothing : (mesh[state+1], state + 1) + +function _indfloor(x, N; edgeshift=1) + # edgeshift = 1 by default in floor function so that end point return N-1 + # edgeshift = 0 in locate function + if x < 1 + return 1 + elseif x >= N + return N - edgeshift + else + return floor(Int, x) + end +end + +function locate(mesh::UniformBZMesh{T,DIM}, x) where {T,DIM} + # find index of nearest grid point to the point + displacement = SVector{DIM,T}(x) + inds = (mesh.br.inv_recip_lattice * displacement) .* mesh.size .+ 1.5 .- mesh.shift .+ 2 .* eps.(T.(mesh.size)) + indexall = 1 + # println((mesh.invlatvec * displacement)) + # println(inds) + factor = 1 + indexall += (_indfloor(inds[1], mesh.size[1]; edgeshift=0) - 1) * factor + for i in 2:DIM + factor *= mesh.size[i-1] + indexall += (_indfloor(inds[i], mesh.size[i]; edgeshift=0) - 1) * factor + end + + return indexall +end + +volume(mesh::UniformBZMesh) = mesh.br.recip_cell_volume +function volume(mesh::UniformBZMesh{T,DIM}, i) where {T,DIM} + inds = _ind2inds(mesh.size, i) + cellarea = T(1.0) + for j in 1:DIM + if inds[j] == 1 + cellarea *= T(0.5) + mesh.shift[j] + elseif inds[j] == mesh.size[j] + cellarea *= T(1.5) - mesh.shift[j] + end + # else cellarea *= 1.0 so nothing + end + return cellarea / length(mesh) * volume(mesh) +end # c.f. DFTK.jl/src/Model.jl # UniformBZMesh iterate on 1st Brillouin Zone @@ -159,18 +238,6 @@ end return :(SVector{DIM,Int}($inds)) end -function _indfloor(x, N; edgeshift=1) - # edgeshift = 1 by default in floor function so that end point return N-1 - # edgeshift = 0 in locate function - if x < 1 - return 1 - elseif x >= N - return N - edgeshift - else - return floor(Int, x) - end -end - function Base.floor(mesh::UniformMesh{DIM,N}, x) where {DIM,N} # find index of nearest grid point to the point displacement = SVector{DIM,Float64}(x) - mesh.origin diff --git a/test/BaseMesh.jl b/test/BaseMesh.jl index 6fcd60c..9299974 100644 --- a/test/BaseMesh.jl +++ b/test/BaseMesh.jl @@ -1,6 +1,76 @@ @testset "Base Mesh" begin rng = MersenneTwister(1234) + @testset "Brillouin" begin + + # square lattice + DIM = 2 + lattice = Matrix([1.0 0; 0 1]') + br = BaseMesh.Brillouin(lattice=lattice) + @test br.inv_lattice .* 2π ≈ br.recip_lattice + @test br.unit_cell_volume ≈ abs(det(lattice)) + @test br.recip_cell_volume ≈ 1 / abs(det(lattice)) * (2π)^DIM + + # triagular lattice + DIM = 2 + lattice = Matrix([2.0 0; 1 sqrt(3)]') + br = BaseMesh.Brillouin(lattice=lattice) + @test br.inv_lattice .* 2π ≈ br.recip_lattice + @test br.unit_cell_volume ≈ abs(det(lattice)) + @test br.recip_cell_volume ≈ 1 / abs(det(lattice)) * (2π)^DIM + + # 3d testing lattice + DIM = 3 + lattice = Matrix([2.0 0 0; 1 sqrt(3) 0; 7 11 19]') + br = BaseMesh.Brillouin(lattice=lattice) + @test br.inv_lattice .* 2π ≈ br.recip_lattice + @test br.unit_cell_volume ≈ abs(det(lattice)) + @test br.recip_cell_volume ≈ 1 / abs(det(lattice)) * (2π)^DIM + end + + @testset "UniformBZMesh" begin + @testset "Indexing" begin + size = (3, 4, 5) + for i in 1:prod(size) + @test i == BaseMesh._inds2ind(size, BaseMesh._ind2inds(size, i)) + end + end + + @testset "Array Interface" begin + N1, N2 = 3, 5 + lattice = Matrix([1/N1/2 0; 0 1.0/N2/2]') .* 2π + # so that bzmesh[i,j] = (2i-1,2j-1) + br = BaseMesh.Brillouin(lattice=lattice) + bzmesh = BaseMesh.UniformBZMesh(br=br, size=(N1, N2)) + for (pi, p) in enumerate(bzmesh) + @test bzmesh[pi] ≈ p # linear index + inds = BaseMesh._ind2inds(bzmesh.size, pi) + @test p ≈ inds .* 2.0 .- 1.0 + @test bzmesh[inds...] ≈ p # cartesian index + end + end + + @testset "locate and volume" begin + size = (3, 5, 7) + lattice = Matrix([2.0 0 0; 1 sqrt(3) 0; 7 11 19]') + # size = (3, 5) + # lattice = Matrix([2.3 0; 0 7.0]') + br = BaseMesh.Brillouin(lattice=lattice) + bzmesh = BaseMesh.UniformBZMesh(br=br, size=size) + vol = 0.0 + for (pi, p) in enumerate(bzmesh) + @test bzmesh[pi] ≈ p # linear index + inds = BaseMesh._ind2inds(bzmesh.size, pi) + @test bzmesh[inds...] ≈ p # cartesian index + + @test BaseMesh.locate(bzmesh, p) == pi + vol += BaseMesh.volume(bzmesh, pi) + end + @test vol ≈ BaseMesh.volume(bzmesh) + end + end + + function dispersion(k) me = 0.5 return dot(k, k) / 2me