1+ #
2+ # License-Identifier: GPL
3+ #
4+ # Copyright (C) 2024 The Yambo Team
5+ #
6+ # Authors: HPC, FP, RR
7+ #
8+ # This file is part of the yambopy project
9+ #
10+ import os
11+ import numpy as np
12+ from netCDF4 import Dataset
13+
14+ class YamboCollisionDB (object ):
15+ """
16+ Class to handle COLLISION databases from Yambo.
17+
18+ Reads collision matrix elements V_nm(k,k',q) or W_nm(k,k',q)
19+ from ndb.COLLISIONS_HXC or ndb.COLLISIONS_COH databases.
20+
21+ Database structure:
22+ - Header file: ndb.COLLISIONS_{type}_header
23+ Contains COLLISIONS_STATE array: [n, m, k_idx, spin] for each collision
24+ Contains X_X_band_range: band range used in response function
25+
26+ - Main data file: ndb.COLLISIONS_{type}
27+ Contains COLLISIONS_v: collision matrix elements (n_collisions, n_kpts, 2)
28+ Contains N_COLLISIONS_STATES: number of k-points in IBZ
29+
30+ Note: COLLISION databases do NOT have fragment files (unlike em1s databases).
31+ Data is stored for IBZ k-points only. Use expand_to_bz() to get full BZ.
32+
33+ Attributes:
34+ collision_type: 'HXC' for bare Coulomb V, 'COH' for screened W
35+ collision_state: Array (4, n_collisions) with [n, m, k_idx, spin] for each collision
36+ collision_v: Array (n_collisions, n_kpts_ibz) of complex collision matrix elements [Hartree]
37+ n_collisions: Number of unique collision pairs (n,m) combinations
38+ n_kpts_ibz: Number of k-points in IBZ
39+ coll_band_range: [min_band, max_band] actual collision bands (1-based Fortran indexing)
40+ response_band_range: [min_band, max_band] from X_X_band_range (response function)
41+ """
42+
43+ def __init__ (self , path : str = '.' , collision_type : str = 'HXC' ):
44+ """
45+ Initialize COLLISION database reader.
46+
47+ Args:
48+ path: Path to directory containing ndb.COLLISIONS_* files
49+ collision_type: 'HXC' for V_nm (bare), 'COH' for W_nm (screened)
50+ """
51+ self .path = path
52+ self .collision_type = collision_type .upper ()
53+
54+ if self .collision_type not in ['HXC' , 'COH' ]:
55+ raise ValueError (f"collision_type must be 'HXC' or 'COH', got { collision_type } " )
56+
57+ # File names
58+ self .base_name = f'ndb.COLLISIONS_{ self .collision_type } '
59+ header_file = f'{ self .base_name } _header'
60+ main_file = self .base_name
61+
62+ header_path = os .path .join (path , header_file )
63+ main_path = os .path .join (path , main_file )
64+
65+ # Check files exist
66+ if not os .path .isfile (header_path ):
67+ raise FileNotFoundError (f"Header file { header_path } not found" )
68+ if not os .path .isfile (main_path ):
69+ raise FileNotFoundError (f"Main file { main_path } not found" )
70+
71+ # Read header file
72+ print (f"Loading { header_file } ..." )
73+ with Dataset (header_path , 'r' ) as f :
74+ # Read collision states: (4, n_collisions) array
75+ # Each column is [n, m, k_idx, spin] for a collision
76+ self .collision_state = np .array (f .variables ['COLLISIONS_STATE' ][:])
77+ self .n_collisions = self .collision_state .shape [1 ]
78+
79+ # Read response function band range (not necessarily collision band range!)
80+ if 'X_X_band_range' in f .variables :
81+ self .response_band_range = np .array (f .variables ['X_X_band_range' ][:])
82+ else :
83+ self .response_band_range = None
84+
85+ # Determine actual collision band range from collision states
86+ n_bands = self .collision_state [0 , :] # First index band
87+ m_bands = self .collision_state [1 , :] # Second index band
88+ self .coll_band_range = np .array ([
89+ min (n_bands .min (), m_bands .min ()),
90+ max (n_bands .max (), m_bands .max ())
91+ ])
92+ self .n_coll_bands = self .coll_band_range [1 ] - self .coll_band_range [0 ] + 1
93+
94+ # Get unique k-points in IBZ
95+ k_indices = np .unique (self .collision_state [2 , :])
96+ self .ibz_k_indices = k_indices
97+ self .n_kpts_ibz = len (k_indices )
98+
99+ print (f" Collision bands: { self .coll_band_range [0 ]} - { self .coll_band_range [1 ]} (Fortran indexing)" )
100+ print (f" Number of collision pairs: { self .n_collisions } " )
101+ print (f" K-points in IBZ: { self .n_kpts_ibz } " )
102+
103+ # Read main data file
104+ print (f"Loading { main_file } ..." )
105+ with Dataset (main_path , 'r' ) as f :
106+ # Read number of k-points (should match IBZ count)
107+ n_kpts_check = int (f .variables ['N_COLLISIONS_STATES' ][:])
108+ if n_kpts_check != self .n_kpts_ibz :
109+ print (f" Warning: N_COLLISIONS_STATES ({ n_kpts_check } ) != n_kpts_ibz ({ self .n_kpts_ibz } )" )
110+
111+ # Read collision matrix elements
112+ # Shape: (n_collisions, n_kpts, 2) where last dim is [real, imag]
113+ coll_v_raw = np .array (f .variables ['COLLISIONS_v' ][:])
114+
115+ # Convert to complex array: (n_collisions, n_kpts_ibz)
116+ self .collision_v_ibz = coll_v_raw [:, :, 0 ] + 1j * coll_v_raw [:, :, 1 ]
117+
118+ # Initially no expansion to full BZ
119+ self .collision_v_bz = None
120+ self .lattice = None
121+
122+ print (f"✓ COLLISION database loaded successfully" )
123+
124+ def expand_to_bz (self , lattice ):
125+ """
126+ Expand collision matrix elements from IBZ to full BZ using YamboLatticeDB.
127+
128+ Args:
129+ lattice: YamboLatticeDB object with kpoints_indexes defined
130+ """
131+ if not hasattr (lattice , 'kpoints_indexes' ):
132+ raise ValueError ("YamboLatticeDB must have kpoints_indexes. Call expand_kpts() first." )
133+
134+ self .lattice = lattice
135+ n_kpts_bz = len (lattice .kpoints_indexes )
136+
137+ print (f"Expanding collision data from IBZ ({ self .n_kpts_ibz } ) to full BZ ({ n_kpts_bz } )..." )
138+
139+ # Expand: collision_v_bz[icoll, ik_bz] = collision_v_ibz[icoll, kpoints_indexes[ik_bz]]
140+ # Note: kpoints_indexes maps ik_bz -> ik_ibz, and k-points in collision_state are 1-based
141+ # So we need to convert: collision_state k-index (1-based) -> Python index (0-based)
142+
143+ # Build full BZ collision data
144+ self .collision_v_bz = np .zeros ((self .n_collisions , n_kpts_bz ), dtype = np .complex128 )
145+
146+ for icoll in range (self .n_collisions ):
147+ # Get the IBZ k-index for this collision (1-based Fortran)
148+ k_ibz_fortran = self .collision_state [2 , icoll ]
149+ k_ibz_python = k_ibz_fortran - 1 # Convert to 0-based
150+
151+ # This collision is defined at k_ibz_python in the IBZ
152+ # Expand it to all k-points in BZ that map to this IBZ point
153+ for ik_bz in range (n_kpts_bz ):
154+ if lattice .kpoints_indexes [ik_bz ] == k_ibz_python :
155+ self .collision_v_bz [icoll , ik_bz ] = self .collision_v_ibz [icoll , k_ibz_python ]
156+
157+ print (f"✓ Collision data expanded to full BZ" )
158+
159+ def get_collision_by_index (self , icoll : int , ik : int , use_bz : bool = False ) -> complex :
160+ """
161+ Get collision matrix element by collision index and k-point index.
162+
163+ Args:
164+ icoll: Collision pair index (0-based Python indexing)
165+ ik: k-point index (0-based Python indexing)
166+ use_bz: If True, use full BZ data; if False, use IBZ data
167+
168+ Returns:
169+ Matrix element in Hartree units (complex)
170+ """
171+ if icoll < 0 or icoll >= self .n_collisions :
172+ raise ValueError (f"Invalid collision index { icoll } . Must be 0 <= icoll < { self .n_collisions } " )
173+
174+ if use_bz :
175+ if self .collision_v_bz is None :
176+ raise ValueError ("BZ data not available. Call expand_to_bz() first." )
177+ if ik < 0 or ik >= self .collision_v_bz .shape [1 ]:
178+ raise ValueError (f"Invalid k-point index { ik } . Must be 0 <= ik < { self .collision_v_bz .shape [1 ]} " )
179+ return self .collision_v_bz [icoll , ik ]
180+ else :
181+ if ik < 0 or ik >= self .n_kpts_ibz :
182+ raise ValueError (f"Invalid k-point index { ik } . Must be 0 <= ik < { self .n_kpts_ibz } " )
183+ return self .collision_v_ibz [icoll , ik ]
184+
185+ def get_collision (self , n : int , m : int , ik : int , spin : int = 1 , use_bz : bool = False ) -> complex :
186+ """
187+ Get collision matrix element V_nm or W_nm for given bands and k-point.
188+
189+ Args:
190+ n, m: Band indices (1-based Fortran indexing, as in Yambo)
191+ ik: k-point index (1-based Fortran indexing, as in Yambo)
192+ spin: Spin index (1 or 2)
193+ use_bz: If True, search in full BZ; if False, search in IBZ
194+
195+ Returns:
196+ Matrix element in Hartree units (complex)
197+ Returns 0 if collision not found in database
198+ """
199+ # Find collision index matching (n, m, k_idx, spin)
200+ for icoll in range (self .n_collisions ):
201+ state = self .collision_state [:, icoll ]
202+ if (state [0 ] == n and state [1 ] == m and
203+ state [2 ] == ik and state [3 ] == spin ):
204+ # Found matching collision
205+ # Return value at the k-point (convert ik from 1-based to 0-based)
206+ return self .get_collision_by_index (icoll , ik - 1 , use_bz = use_bz )
207+
208+ # Collision not found in database
209+ return 0.0 + 0.0j
210+
211+ def get_collision_state (self , icoll : int ) -> tuple :
212+ """
213+ Get the (n, m, k_idx, spin) tuple for a given collision index.
214+
215+ Args:
216+ icoll: Collision pair index (0-based Python indexing)
217+
218+ Returns:
219+ Tuple (n, m, k_idx, spin) with 1-based Fortran indices
220+ """
221+ if icoll < 0 or icoll >= self .n_collisions :
222+ raise ValueError (f"Invalid collision index { icoll } . Must be 0 <= icoll < { self .n_collisions } " )
223+
224+ state = self .collision_state [:, icoll ]
225+ return tuple (int (x ) for x in state )
226+
227+ def get_collision_array (self , icoll : int , use_bz : bool = False ) -> np .ndarray :
228+ """
229+ Get collision matrix elements at all k-points for a given collision.
230+
231+ Args:
232+ icoll: Collision pair index (0-based Python indexing)
233+ use_bz: If True, return full BZ data; if False, return IBZ data
234+
235+ Returns:
236+ Array (n_kpts,) of complex collision matrix elements
237+ """
238+ if icoll < 0 or icoll >= self .n_collisions :
239+ raise ValueError (f"Invalid collision index { icoll } . Must be 0 <= icoll < { self .n_collisions } " )
240+
241+ if use_bz :
242+ if self .collision_v_bz is None :
243+ raise ValueError ("BZ data not available. Call expand_to_bz() first." )
244+ return self .collision_v_bz [icoll , :]
245+ else :
246+ return self .collision_v_ibz [icoll , :]
247+
248+ def get_collision_matrix_at_k (self , ik : int , use_bz : bool = False ,
249+ band_offset : int = 0 ) -> np .ndarray :
250+ """
251+ Get full collision matrix V_nm(k) or W_nm(k) for all bands at a given k-point.
252+
253+ Useful for integration with WannierYamboInterface where band indexing may differ.
254+
255+ Args:
256+ ik: k-point index (0-based Python indexing)
257+ use_bz: If True, use full BZ data; if False, use IBZ data
258+ band_offset: Offset to apply to band indices for 0-based Python indexing
259+ (e.g., if collision bands are 3-6 in Fortran, use band_offset=-3
260+ to get 0-based Python indices 0-3)
261+
262+ Returns:
263+ V_nm or W_nm matrix (n_coll_bands, n_coll_bands) in Hartree units
264+ """
265+ # Initialize matrix
266+ V_nm = np .zeros ((self .n_coll_bands , self .n_coll_bands ), dtype = np .complex128 )
267+
268+ # Determine which k-point to use
269+ if use_bz :
270+ if self .collision_v_bz is None :
271+ raise ValueError ("BZ data not available. Call expand_to_bz() first." )
272+ k_fortran = ik + 1 # Convert to 1-based
273+ else :
274+ k_fortran = ik + 1 # Convert to 1-based
275+
276+ # Fill matrix
277+ for icoll in range (self .n_collisions ):
278+ n , m , k_coll , spin = self .get_collision_state (icoll )
279+
280+ # Check if this collision is for the requested k-point
281+ if k_coll == k_fortran :
282+ # Convert band indices
283+ n_idx = n - self .coll_band_range [0 ] + band_offset
284+ m_idx = m - self .coll_band_range [0 ] + band_offset
285+
286+ if 0 <= n_idx < self .n_coll_bands and 0 <= m_idx < self .n_coll_bands :
287+ V_nm [n_idx , m_idx ] = self .get_collision_by_index (icoll , ik , use_bz = use_bz )
288+
289+ return V_nm
290+
291+ def __str__ (self ):
292+ lines = []
293+ lines .append ("=" * 70 )
294+ lines .append (f"Yambo COLLISION Database ({ self .collision_type } )" )
295+ lines .append ("=" * 70 )
296+ lines .append (f"Path: { self .path } " )
297+ lines .append (f"Base file: { self .base_name } " )
298+ lines .append (f"Type: { 'Bare Coulomb V_nm' if self .collision_type == 'HXC' else 'Screened W_nm' } " )
299+ lines .append (f"N collision pairs: { self .n_collisions } " )
300+ lines .append (f"Collision bands: { self .coll_band_range [0 ]} - { self .coll_band_range [1 ]} ({ self .n_coll_bands } bands)" )
301+ if self .response_band_range is not None :
302+ lines .append (f"Response bands: { self .response_band_range [0 ]} - { self .response_band_range [1 ]} " )
303+ lines .append (f"K-points (IBZ): { self .n_kpts_ibz } " )
304+ if self .collision_v_bz is not None :
305+ lines .append (f"K-points (Full BZ): { self .collision_v_bz .shape [1 ]} " )
306+ lines .append (f"Expanded to BZ: Yes" )
307+ else :
308+ lines .append (f"Expanded to BZ: No" )
309+ lines .append ("" )
310+ lines .append ("Collision states (first 5):" )
311+ for i in range (min (5 , self .n_collisions )):
312+ n , m , k_idx , spin = self .get_collision_state (i )
313+ value_ibz = self .collision_v_ibz [i , k_idx - 1 ] # k_idx is 1-based
314+ lines .append (f" [{ i :3d} ] bands ({ n } ,{ m } ) k={ k_idx } spin={ spin } | V_ibz = { value_ibz :.6e} " )
315+ if self .n_collisions > 5 :
316+ lines .append (f" ... and { self .n_collisions - 5 } more" )
317+ return "\n " .join (lines )
0 commit comments