Coverage for local_installation/dynasor/correlation_functions.py: 93%
290 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-04-16 06:13 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-04-16 06:13 +0000
import concurrent.futures
from functools import partial
from itertools import combinations_with_replacement
from typing import Tuple

import numba
import numpy as np
from ase import Atoms
from ase.units import fs
from numpy.typing import NDArray

from dynasor.logging_tools import logger
from dynasor.trajectory import Trajectory, WindowIterator
from dynasor.sample import DynamicSample, StaticSample
from dynasor.post_processing import fourier_cos_filon
from dynasor.core.time_averager import TimeAverager
from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q
from dynasor.qpoints.tools import get_index_offset
from dynasor.units import radians_per_fs_to_meV
def compute_dynamic_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
    dt: float,
    window_size: int,
    window_step: int = 1,
    calculate_currents: bool = False,
    calculate_incoherent: bool = False,
) -> DynamicSample:
    r"""Compute dynamic structure factors. The results are returned in the
    form of a :class:`DynamicSample <dynasor.sample.DynamicSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots
        in the trajectory. Note that you should *not* change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    window_size
        Length of the trajectory frame window to use for time correlation calculation.
        It is expressed in terms of the number of time lags to consider
        and thus determines the smallest frequency resolved.
        Must be an even number larger than 2.
    window_step
        Window step (or stride) given as the number of frames between consecutive trajectory
        windows. This parameter does *not* affect the time between consecutive frames in the
        calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not
        be used.
    calculate_currents
        Calculate the current correlations. Requires velocities to be available in :attr:`traj`.
    calculate_incoherent
        Calculate the incoherent part (self-part) of the correlation functions,
        i.e., :math:`F_\mathrm{incoh}(q, t)` and :math:`S_\mathrm{incoh}(q, \omega)`.

    Raises
    ------
    ValueError
        If an input argument is invalid, or if :attr:`calculate_currents` is
        requested but a frame provides no velocities.
    """
    # sanity check input args
    # (check ndim first so a 1-D array gives a ValueError rather than an IndexError)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')
    if window_size <= 2:
        raise ValueError(f'window_size must be larger than 2: window_size= {window_size}')
    if window_size % 2 != 0:
        raise ValueError(f'window_size must be even: window_size= {window_size}')
    if window_step <= 0:
        raise ValueError(f'window_step must be positive: window_step= {window_step}')

    # define internal parameters
    n_qpoints = q_points.shape[0]
    delta_t = traj.frame_step * dt
    N_tc = window_size + 1  # number of time lags, including lag zero

    # log all setup information
    dw = np.pi / (window_size * delta_t)
    w_max = dw * window_size
    w_N = 2 * np.pi / (2 * delta_t)  # Nyquist angular frequency

    logger.info(f'Spacing between samples (frame_step): {traj.frame_step}')
    logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs')
    logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs')
    logger.info(f'Time window size (dt * frame_step * window_size): {delta_t * window_size:.1f} fs')
    logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = '
                f'{dw * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Maximum angular frequency (dw * window_size):'
                f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Nyquist angular frequency (2pi / frame_step / dt / 2):'
                f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV')

    if calculate_currents:
        logger.info('Calculating current (velocity) correlations')
    if calculate_incoherent:
        logger.info('Calculating incoherent part (self-part) of correlations')

    # log some info regarding q-points
    logger.info(f'Number of q-points: {n_qpoints}')

    # Unit vectors along each q-point; the zero q-vector is left as zero.
    # astype(float) makes a float copy so the in-place division below also
    # works when q_points has an integer dtype.
    q_directions = q_points.astype(float)
    q_distances = np.linalg.norm(q_points, axis=1)
    nonzero = q_distances > 0
    q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1)

    # setup functions to process frames
    def f2_rho(frame):
        """Attach the per-type densities rho(q) to the frame."""
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    def f2_rho_and_j(frame):
        """Attach per-type densities and the longitudinal/transverse currents to the frame."""
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {frame.frame_index}')
        rho_qs_dict = dict()
        jz_qs_dict = dict()
        jper_qs_dict = dict()

        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            v = frame.velocities_by_type[atom_type]
            rho_qs, j_qs = calc_rho_j_q(x, v, q_points)
            # projection of the current onto q (longitudinal) ...
            jz_qs = np.sum(j_qs * q_directions, axis=1)
            # ... and the remainder perpendicular to q (transverse)
            jper_qs = j_qs - (jz_qs[:, None] * q_directions)

            rho_qs_dict[atom_type] = rho_qs
            jz_qs_dict[atom_type] = jz_qs
            jper_qs_dict[atom_type] = jper_qs

        frame.rho_qs_dict = rho_qs_dict
        frame.jz_qs_dict = jz_qs_dict
        frame.jper_qs_dict = jper_qs_dict
        return frame

    if calculate_currents:
        element_processor = f2_rho_and_j
    else:
        element_processor = f2_rho

    # setup window iterator
    window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step,
                                     element_processor=element_processor)

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f'  {pair}')

    # setup all time averager instances
    F_q_t_averager = dict()
    for pair in pairs:
        F_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_currents:
        Cl_q_t_averager = dict()
        Ct_q_t_averager = dict()
        for pair in pairs:
            Cl_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
            Ct_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_incoherent:
        F_s_q_t_averager = dict()
        for atom_type in traj.atom_types:
            F_s_q_t_averager[atom_type] = TimeAverager(N_tc, n_qpoints)

    # define correlation function
    def calc_corr(window, time_i):
        # Correlate the first frame of the window with the frame at lag time_i.
        # The 1/N normalization is applied later when the results are collected.
        f0 = window[0]
        fi = window[time_i]
        for s1, s2 in pairs:
            Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # cross terms are stored once per unordered pair, so add both orderings
                Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate())
            F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt)

        if calculate_currents:
            for s1, s2 in pairs:
                Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate())
                Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] *
                                            fi.jper_qs_dict[s2].conjugate(), axis=1))
                if s1 != s2:
                    Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate())
                    Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] *
                                                 fi.jper_qs_dict[s1].conjugate(), axis=1))

                Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt)
                Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt)

        if calculate_incoherent:
            for atom_type in traj.atom_types:
                xi = fi.positions_by_type[atom_type]
                x0 = f0.positions_by_type[atom_type]
                Fsqt = np.real(calc_rho_q(xi - x0, q_points))
                F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt)

    # run calculation
    logging_interval = 1000
    with concurrent.futures.ThreadPoolExecutor() as tpe:
        # This is the "main loop" over the trajectory
        for window in window_iterator:
            logger.debug(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')

            if window[0].frame_index % logging_interval == 0:
                logger.info(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')  # noqa

            # map conveniently applies calc_corr to all time lags. Since calc_corr
            # only updates the averagers in place and returns None, the results
            # are consumed solely to wait for all lags of this window to finish.
            for _ in tpe.map(partial(calc_corr, window), range(len(window))):
                pass

    # collect results into dict with numpy arrays (n_qpoints, N_tc)
    data_dict_corr = dict()
    time = delta_t * np.arange(N_tc, dtype=float)
    data_dict_corr['q_points'] = q_points
    data_dict_corr['time'] = time

    F_q_t_tot = np.zeros((n_qpoints, N_tc))
    S_q_w_tot = np.zeros((n_qpoints, N_tc))
    for pair in pairs:
        key = '_'.join(pair)
        F_q_t = 1 / traj.n_atoms * F_q_t_averager[pair].get_average_all()
        w, S_q_w = fourier_cos_filon(F_q_t, delta_t)
        S_q_w = np.array(S_q_w)
        data_dict_corr['omega'] = w
        data_dict_corr[f'Fqt_coh_{key}'] = F_q_t
        data_dict_corr[f'Sqw_coh_{key}'] = S_q_w

        # sum all partials to the total
        F_q_t_tot += F_q_t
        S_q_w_tot += S_q_w
    data_dict_corr['Fqt_coh'] = F_q_t_tot
    data_dict_corr['Sqw_coh'] = S_q_w_tot

    if calculate_currents:
        Cl_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_t_tot = np.zeros((n_qpoints, N_tc))
        Cl_q_w_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_w_tot = np.zeros((n_qpoints, N_tc))
        for pair in pairs:
            key = '_'.join(pair)
            Cl_q_t = 1 / traj.n_atoms * Cl_q_t_averager[pair].get_average_all()
            Ct_q_t = 1 / traj.n_atoms * Ct_q_t_averager[pair].get_average_all()
            _, Cl_q_w = fourier_cos_filon(Cl_q_t, delta_t)
            _, Ct_q_w = fourier_cos_filon(Ct_q_t, delta_t)
            data_dict_corr[f'Clqt_{key}'] = Cl_q_t
            data_dict_corr[f'Ctqt_{key}'] = Ct_q_t
            data_dict_corr[f'Clqw_{key}'] = Cl_q_w
            data_dict_corr[f'Ctqw_{key}'] = Ct_q_w

            # sum all partials to the total
            Cl_q_t_tot += Cl_q_t
            Ct_q_t_tot += Ct_q_t
            Cl_q_w_tot += Cl_q_w
            Ct_q_w_tot += Ct_q_w
        data_dict_corr['Clqt'] = Cl_q_t_tot
        data_dict_corr['Ctqt'] = Ct_q_t_tot
        data_dict_corr['Clqw'] = Cl_q_w_tot
        data_dict_corr['Ctqw'] = Ct_q_w_tot

    if calculate_incoherent:
        Fs_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ss_q_w_tot = np.zeros((n_qpoints, N_tc))
        for atom_type in traj.atom_types:
            Fs_q_t = 1 / traj.n_atoms * F_s_q_t_averager[atom_type].get_average_all()
            _, Ss_q_w = fourier_cos_filon(Fs_q_t, delta_t)
            data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t
            data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w

            # sum all partials to the total
            Fs_q_t_tot += Fs_q_t
            Ss_q_w_tot += Ss_q_w

        data_dict_corr['Fqt_incoh'] = Fs_q_t_tot
        data_dict_corr['Sqw_incoh'] = Ss_q_w_tot

    # finalize results with additional meta data
    result = DynamicSample(data_dict_corr, atom_types=traj.atom_types, pairs=pairs,
                           particle_counts=particle_counts, cell=traj.cell,
                           time_between_frames=delta_t,
                           maximum_time_lag=delta_t * window_size,
                           angular_frequency_resolution=dw,
                           maximum_angular_frequency=w_max,
                           number_of_frames=traj.number_of_frames_read)

    return result
def compute_static_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
) -> StaticSample:
    r"""Compute static structure factors. The results are returned in the
    form of a :class:`StaticSample <dynasor.sample.StaticSample>`
    object.

    For every frame, the partial densities :math:`\rho_s(\boldsymbol{q})` are
    correlated at zero time lag. The results are accumulated with a
    :class:`TimeAverager` and normalized by the total number of atoms in
    :attr:`traj`.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates

    Raises
    ------
    ValueError
        If :attr:`q_points` does not have shape ``(N_qpoints, 3)``.
    """
    # sanity check input args
    # (check ndim first so a 1-D array gives a ValueError rather than an IndexError)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    n_qpoints = q_points.shape[0]
    logger.info(f'Number of q-points: {n_qpoints}')

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f'  {pair}')

    # processing function
    def f2_rho(frame):
        """Attach the per-type densities rho(q) to the frame."""
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    # setup averager
    Sq_averager = dict()
    for pair in pairs:
        Sq_averager[pair] = TimeAverager(1, n_qpoints)  # time average with only timelag=0

    # main loop
    for frame in traj:

        # process_frame
        f2_rho(frame)
        logger.debug(f'Processing frame {frame.frame_index}')

        for s1, s2 in pairs:
            # compute correlation; cross terms are stored once per unordered
            # pair, so both orderings are added
            Sq_pair = np.real(frame.rho_qs_dict[s1] * frame.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                Sq_pair += np.real(frame.rho_qs_dict[s2] * frame.rho_qs_dict[s1].conjugate())
            Sq_averager[(s1, s2)].add_sample(0, Sq_pair)

    # collect results
    data_dict = dict()
    data_dict['q_points'] = q_points
    S_q_tot = np.zeros((n_qpoints, 1))
    for s1, s2 in pairs:
        Sq = 1 / traj.n_atoms * Sq_averager[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1)
        data_dict[f'Sq_{s1}_{s2}'] = Sq
        S_q_tot += Sq
    data_dict['Sq'] = S_q_tot

    # finalize results
    result = StaticSample(data_dict, atom_types=traj.atom_types, pairs=pairs,
                          particle_counts=particle_counts, cell=traj.cell,
                          number_of_frames=traj.number_of_frames_read)
    return result
def compute_spectral_energy_density(
    traj: Trajectory,
    ideal_supercell: Atoms,
    primitive_cell: Atoms,
    q_points: NDArray[float],
    dt: float,
    partial: bool = False
) -> Tuple[NDArray[float], NDArray[float]]:
    r"""
    Compute the spectral energy density (SED) at specific q-points. The results
    are returned in the form of a tuple, which comprises the angular
    frequencies in an array of length ``N_times`` in units of rad/fs and the
    SED in units of eV/(rad/fs) as an array of shape ``(N_qpoints, N_times)``.
    The normalization is chosen such that integrating the SED of a q-point
    together with the supplied angular frequencies omega (rad/fs) yields
    1/2kBT * number of bands (where number of bands = len(prim) * 3)

    More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010),
    which should be cited when using this function along with the dynasor reference.

    **Note 1:**
    SED analysis is only suitable for crystalline materials without diffusion as
    atoms are assumed to move around fixed reference positions throughout the entire trajectory.

    **Note 2:**
    This implementation reads the full trajectory and can thus consume a lot of memory.

    Parameters
    ----------
    traj
        Input trajectory
    ideal_supercell
        Ideal structure defining the reference positions. Do not change the
        masses in the ASE atoms objects to dynasor internal units, this will be
        done internally
    primitive_cell
        Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots in
        the trajectory. Note that you should not change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    partial
        If True the SED will be returned decomposed per basis and Cartesian direction.
        The shape is ``(N_qpoints, N_frequencies, len(primitive_cell), 3)``

    Raises
    ------
    ValueError
        If :attr:`q_points` has the wrong shape, if the cells are inconsistent
        with the trajectory, if a frame has no velocities, or if no frames
        could be read.
    """
    # NOTE: the public parameter `partial` shadows functools.partial inside this
    # function; functools.partial is not used here.

    # sanity check input args
    # (check ndim first so a 1-D array gives a ValueError rather than an IndexError)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    delta_t = traj.frame_step * dt

    # logger
    logger.info('Running SED')
    logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs')
    logger.info(f'Number of atoms in primitive_cell: {len(primitive_cell)}')
    logger.info(f'Number of atoms in ideal_supercell: {len(ideal_supercell)}')
    logger.info(f'Number of q-points: {q_points.shape[0]}')

    # check that the ideal supercell agrees with traj
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    if len(primitive_cell) >= len(ideal_supercell):
        raise ValueError('primitive_cell must contain fewer atoms than ideal_supercell.')

    # collect all velocities, and scale with sqrt(masses)
    masses = ideal_supercell.get_masses().reshape(-1, 1) / fs**2  # From Dalton to dmu
    velocities = []
    for it, frame in enumerate(traj):
        logger.debug(f'Reading frame {it}')
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {it}')
        v = frame.get_velocities_as_array(traj.atomic_indices)  # in Å/fs
        velocities.append(np.sqrt(masses) * v)
    logger.info(f'Number of snapshots: {len(velocities)}')

    N_samples = len(velocities)
    if N_samples == 0:
        # without this check the transpose below fails with an obscure axes error
        raise ValueError('No frames could be read from the trajectory.')

    # Perform the FFT on the last axis for extra speed (maybe not needed)
    velocities = np.array(velocities)
    # places time index last and makes a copy for continuity
    velocities = velocities.transpose(1, 2, 0).copy()
    # #atoms in supercell x 3 directions x #frequencies
    velocities = np.fft.rfft(velocities, axis=2)

    # Calculate indices and offsets needed for the sed method
    indices, offsets = get_index_offset(ideal_supercell, primitive_cell)

    # Phase factor for use in FT. #qpoints x #atoms in supercell
    cell_positions = np.dot(offsets, primitive_cell.cell)
    phase = np.dot(q_points, cell_positions.T)  # #qpoints x #unit cells
    phase_factors = np.exp(1.0j * phase)

    # This dict maps the offsets to an index so ndarrays can be over
    # offset,index instead of atoms in supercell
    offset_dict = {off: n for n, off in enumerate(set(tuple(offset) for offset in offsets))}

    # Pick out some shapes
    n_super, _, n_w = velocities.shape
    n_qpts = len(q_points)
    n_prim = len(primitive_cell)
    n_offsets = len(offset_dict)

    # Re-index velocities (also transposed) and spatial phase factors by
    # (basis index, offset index) instead of by atom in the supercell
    new_velocities = np.zeros((n_w, 3, n_prim, n_offsets), dtype=velocities.dtype)
    new_phase_factors = np.zeros((n_qpts, n_prim, n_offsets), dtype=phase_factors.dtype)
    for i in range(n_super):
        j = indices[i]  # atom with index i in the supercell is of basis type j ...
        n = offset_dict[tuple(offsets[i])]  # and its offset has index n
        new_velocities[:, :, j, n] = velocities[i].T
        new_phase_factors[:, j, n] = phase_factors[:, i]

    velocities = new_velocities
    phase_factors = new_phase_factors

    # calculate the density in a numba function
    density = _sed_inner_loop(phase_factors, velocities)

    if not partial:
        density = np.sum(density, axis=(2, 3))

    # units
    # make so that the velocities were originally in Angstrom / fs to be compatible with eV and Da

    # the time delta in the fourier transform
    density = density * delta_t**2

    # Divide by the length of the time signal
    density = density / (N_samples * delta_t)

    # Divide by the number of primitive cells
    density = density / (n_super / n_prim)

    # Factor so the sed can be integrated together with the returned omega
    # numpy fft works with ordinary/linear frequencies and not angular freqs
    density = density / (2*np.pi)

    # angular frequencies
    w = 2 * np.pi * np.fft.rfftfreq(N_samples, delta_t)  # rad/fs

    return w, density
@numba.njit(parallel=True, fastmath=True)
def _sed_inner_loop(phase_factors, velocities):
    """Spatial Fourier transform of the frequency-domain velocities.

    For every frequency ``iw``, q-point ``iq``, Cartesian direction ``alpha`` and
    basis atom ``basis`` the amplitude ``sum_cell phase_factors[iq, basis, cell] *
    velocities[iw, alpha, basis, cell]`` is formed. Its squared modulus is stored in
    ``density[iq, iw, basis, alpha]``.

    The phase factors are precomputed by the caller. The outermost (parallel)
    loop runs over the frequency components rather than the q-points, since
    callers may pass only a single q-point.

    Returns a float64 array of shape ``(n_qpoints, n_frequencies, n_basis, 3)``.
    """
    n_q, n_basis, n_cells = phase_factors.shape[0], phase_factors.shape[1], phase_factors.shape[2]
    n_omega = velocities.shape[0]  # velocities axes: frequency, direction, basis atom, unit cell

    density = np.zeros((n_q, n_omega, n_basis, 3), dtype=np.float64)

    for iw in numba.prange(n_omega):
        for iq in range(n_q):
            for alpha in range(3):
                for basis in range(n_basis):
                    amp = 0.0j
                    for cell in range(n_cells):
                        amp += phase_factors[iq, basis, cell] * velocities[iw, alpha, basis, cell]
                    # each (iq, iw, basis, alpha) entry is visited exactly once
                    density[iq, iw, basis, alpha] = np.abs(amp)**2
    return density