Skip to content

Commit

Permalink
Merge pull request #31 from numericalEFT/uniformbzmesh
Browse files Browse the repository at this point in the history
Uniformbzmesh
  • Loading branch information
kunyuan authored Oct 25, 2022
2 parents 3890dc0 + cc5464c commit 0f4b01b
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 22 deletions.
5 changes: 4 additions & 1 deletion src/AbstractMeshes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module AbstractMeshes

using ..StaticArrays

export AbstractMesh, locate, volume
export AbstractMesh, locate, volume, fractional_coordinates

# the return value of AbstractMesh should be a SVector{T,DIM}
abstract type AbstractMesh{T,DIM} <: AbstractArray{SVector{T,DIM},DIM} end
Expand Down Expand Up @@ -52,4 +52,7 @@ end
return :(SVector{DIM,Int}($inds))
end

# optional functions
fractional_coordinates(mesh::AbstractMesh, I::Int) = error("not implemented!")

end
25 changes: 23 additions & 2 deletions src/BaseMesh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ function Base.show(io::IO, mesh::UMesh)
println("UMesh with $(length(mesh)) mesh points")
end

function AbstractMeshes.fractional_coordinates(mesh::UMesh{T,DIM}, I::Int) where {T,DIM}
n = SVector{DIM,Int}(ind2inds(mesh.size, I))
return (n .- 1 .+ mesh.shift) ./ mesh.size
end

function AbstractMeshes.fractional_coordinates(mesh::UMesh{T,DIM}, x::AbstractVector) where {T,DIM}
displacement = SVector{DIM,T}(x)
return (mesh.inv_lattice * displacement)
end

function Base.getindex(mesh::UMesh{T,DIM}, inds...) where {T,DIM}
n = SVector{DIM,Int}(inds)
return mesh.origin + mesh.lattice * ((n .- 1 .+ mesh.shift) ./ mesh.size)
Expand All @@ -63,8 +73,8 @@ end

function AbstractMeshes.locate(mesh::UMesh{T,DIM}, x) where {T,DIM}
# find index of nearest grid point to the point
displacement = SVector{DIM,T}(x) - mesh.origin
inds = (mesh.inv_lattice * displacement) .* mesh.size .+ 1.5 .- mesh.shift .+ 2 .* eps.(T.(mesh.size))
svx = SVector{DIM,T}(x)
inds = fractional_coordinates(mesh, svx - mesh.origin) .* mesh.size .+ 1.5 .- mesh.shift .+ 2 .* eps.(T.(mesh.size))
indexall = 1
# println((mesh.invlatvec * displacement))
# println(inds)
Expand All @@ -81,6 +91,17 @@ AbstractMeshes.volume(mesh::UMesh) = mesh.volume
AbstractMeshes.volume(mesh::UMesh, i) = mesh.volume / length(mesh)


# in spglib, grid_address runs from 1-ceil(N/2) to N-ceil(N/2)
# thus -1:2 for N=4 and -2:2 for N=5

function spglib_grid_address_to_index(mesh::UMesh{T,DIM}, ga) where {T,DIM}
inds = ga[1:DIM] # if length(x)==3 but DIM==2, take first two
fcoords = (inds .+ mesh.shift) ./ mesh.size
# shift fcoords
fcoords = [(fcoords[i] < 0) ? (fcoords[i] + 1) : fcoords[i] for i in 1:DIM]
x = mesh.origin + mesh.lattice * fcoords
return locate(mesh, x)
end
#####################################
# LEGACY CODE BELOW
#####################################
Expand Down
2 changes: 1 addition & 1 deletion src/BrillouinZoneMeshes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export PolarMeshes
include("MeshMaps.jl")
using .MeshMaps
export MeshMaps
export SymMap, MappedData
export SymMap, MappedData, MeshMap, ReducedBZMesh

include("meshes/reduced_uiniform_map.jl")

Expand Down
37 changes: 19 additions & 18 deletions src/MeshMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ using ..AbstractMeshes
using ..Model
using ..TreeMeshes
using ..BaseMesh

export MeshMap, ReducedBZMesh

"""
struct MeshMap
Expand All @@ -20,23 +23,24 @@ struct MeshMap
map::Vector{Int}
inv_map::Dict{Int,Vector{Int}}

# function MeshMap(map::Vector{Int})
# reduced_length = maximum(map)
# inv_map = Vector{Vector{Int}}(undef, reduced_length)
# for (i, ind) in enumerate(map)
# # i is index of full mesh, ind is index of reduced mesh
# if isassigned(inv_map, ind)
# push!(inv_map[ind], i)
# else
# inv_map[ind] = Vector{Int}([i,])
# end
# end

# return new(map, reduced_length, inv_map)
# end
function MeshMap(map::Vector{Int})
irreducible_indices = Vector{Int}([])
inv_map = Dict{Int,Vector{Int}}([])
for (i, ind) in enumerate(map)
if !(ind in irreducible_indices)
push!(irreducible_indices, ind)
push!(inv_map, (ind => [i,]))
else
push!(inv_map[ind], i)
end
end
return new(irreducible_indices, map, inv_map)
end
end

# TODO: constructors that generate map for specific type of mesh and symmetry
MeshMap(mesh::AbstractMesh) = error("Map reduce not defined for $(typeof(mesh))!")


## TODO: 1st step: symmetry reduce for M-P mesh(centered uniform mesh)

Expand All @@ -50,14 +54,11 @@ Map-reduced mesh constructed from mesh::MT with symmetry reduction.
- `mesh`: bare mesh from which the reduced mesh constructed
- `meshmap`: map from mesh to the reduced mesh
"""
struct ReducedMesh{T,DIM,MT<:AbstractMesh{T,DIM}} <: AbstractMesh{T,DIM}
struct ReducedBZMesh{T,DIM,MT<:AbstractMesh{T,DIM}} <: AbstractMesh{T,DIM}
mesh::MT
meshmap::MeshMap
end

# TODO: implement AbstractMesh interface
# including AbstractArray interface and locate/volume functions




Expand Down
13 changes: 13 additions & 0 deletions test/MeshMap.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
@testset "MeshMaps" begin
@testset "MeshMap" begin

# test MeshMap constructor
map = [1, 2, 2, 1, 2, 6, 6, 2, 2, 6, 6, 2, 1, 2, 2, 1]
mm = MeshMap(map)
# println(mm.irreducible_indices)
@test mm.irreducible_indices == [1, 2, 6]
@test mm.inv_map[1] == [1, 4, 13, 16]
@test mm.inv_map[2] == [2, 3, 5, 8, 9, 12, 14, 15]
@test mm.inv_map[6] == [6, 7, 10, 11]
end
end

0 comments on commit 0f4b01b

Please # to comment.