Coverage for local_installation/dynasor/correlation_functions.py: 93%
284 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-09-27 15:43 +0000
import concurrent
import concurrent.futures
from functools import partial
from itertools import combinations_with_replacement
from typing import Tuple

import numba
import numpy as np
from ase import Atoms
from numpy.typing import NDArray

from dynasor.logging_tools import logger
from dynasor.trajectory import Trajectory, WindowIterator
from dynasor.sample import DynamicSample, StaticSample
from dynasor.post_processing import fourier_cos_filon
from dynasor.core.time_averager import TimeAverager
from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q
from dynasor.qpoints.tools import get_index_offset
from dynasor.units import radians_per_fs_to_meV
def compute_dynamic_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
    dt: float,
    window_size: int,
    window_step: int = 1,
    calculate_currents: bool = False,
    calculate_incoherent: bool = False,
) -> DynamicSample:
    r"""Compute dynamic structure factors. The results are returned in the
    form of a :class:`DynamicSample <dynasor.sample.DynamicSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots
        in the trajectory. Note that you should *not* change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    window_size
        Length of the trajectory frame window to use for time correlation calculation.
        It is expressed in terms of the number of time lags to consider
        and thus determines the smallest frequency resolved.
        Must be even and larger than 2.
    window_step
        Window step (or stride) given as the number of frames between consecutive trajectory
        windows. This parameter does *not* affect the time between consecutive frames in the
        calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not
        be used.
    calculate_currents
        Calculate the current correlations. Requires velocities to be available in :attr:`traj`.
    calculate_incoherent
        Calculate the incoherent part (self-part) of the correlations,
        :math:`F_\text{incoh}(q, t)` and :math:`S_\text{incoh}(q, \omega)`.

    Returns
    -------
    DynamicSample
        Holds ``q_points``, ``time`` and ``omega`` together with the partial
        (``<name>_<s1>_<s2>``) and total correlation functions. All correlation
        functions are normalized by the total number of atoms. ``Fqt``/``Sqw``
        are the coherent totals plus, if requested, the incoherent totals.

    Raises
    ------
    ValueError
        If ``q_points`` is not a two-dimensional array with three columns, if
        ``dt``, ``window_size`` or ``window_step`` are invalid, or if
        ``calculate_currents`` is requested but a frame carries no velocities.
    """
    # sanity check input args; test ndim first so that a 1-D array gives a
    # ValueError instead of an IndexError from shape[1]
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')
    if window_size <= 2:
        raise ValueError(f'window_size must be larger than 2: window_size= {window_size}')
    if window_size % 2 != 0:
        raise ValueError(f'window_size must be even: window_size= {window_size}')
    if window_step <= 0:
        raise ValueError(f'window_step must be positive: window_step= {window_step}')

    # define internal parameters
    n_qpoints = q_points.shape[0]
    delta_t = traj.frame_step * dt
    N_tc = window_size + 1  # number of time lags, including lag 0

    # log all setup information
    dw = np.pi / (window_size * delta_t)
    w_max = dw * window_size
    w_N = 2 * np.pi / (2 * delta_t)  # Nyquist angular frequency

    logger.info(f'Spacing between samples (frame_step): {traj.frame_step}')
    logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs')
    logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs')
    logger.info(f'Time window size (dt * frame_step * window_size): {delta_t * window_size:.1f} fs')
    logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = '
                f'{dw * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Maximum angular frequency (dw * window_size):'
                f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Nyquist angular frequency (2pi / frame_step / dt / 2):'
                f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV')

    if calculate_currents:
        logger.info('Calculating current (velocity) correlations')
    if calculate_incoherent:
        logger.info('Calculating incoherent part (self-part) of correlations')

    # log some info regarding q-points
    logger.info(f'Number of q-points: {n_qpoints}')

    # Unit vectors along each q-point; q = 0 is left as the zero vector.
    # A float copy is made so that the in-place division also works for
    # integer-valued q-point arrays.
    q_directions = np.array(q_points, dtype=float)
    q_distances = np.linalg.norm(q_points, axis=1)
    nonzero = q_distances > 0
    q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1)

    # setup functions to process frames; they attach the Fourier-transformed
    # densities (and currents) to the frame object
    def f2_rho(frame):
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    def f2_rho_and_j(frame):
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {frame.frame_index}')
        rho_qs_dict = dict()
        jz_qs_dict = dict()
        jper_qs_dict = dict()

        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            v = frame.velocities_by_type[atom_type]
            rho_qs, j_qs = calc_rho_j_q(x, v, q_points)
            # split the current into components parallel and perpendicular to q
            jz_qs = np.sum(j_qs * q_directions, axis=1)
            jper_qs = j_qs - (jz_qs[:, None] * q_directions)

            rho_qs_dict[atom_type] = rho_qs
            jz_qs_dict[atom_type] = jz_qs
            jper_qs_dict[atom_type] = jper_qs

        frame.rho_qs_dict = rho_qs_dict
        frame.jz_qs_dict = jz_qs_dict
        frame.jper_qs_dict = jper_qs_dict
        return frame

    element_processor = f2_rho_and_j if calculate_currents else f2_rho

    # setup window iterator
    window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step,
                                     element_processor=element_processor)

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # setup all time averager instances
    F_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
    if calculate_currents:
        Cl_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
        Ct_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
    if calculate_incoherent:
        F_s_q_t_averager = {atom_type: TimeAverager(N_tc, n_qpoints)
                            for atom_type in traj.atom_types}

    # define correlation function
    def calc_corr(window, time_i):
        # Correlate the first frame of the window with frame time_i,
        # without the 1/N normalization (applied when collecting results).
        # For unlike pairs both cross terms are added.
        f0 = window[0]
        fi = window[time_i]
        for s1, s2 in pairs:
            Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate())
            F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt)

        if calculate_currents:
            for s1, s2 in pairs:
                Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate())
                # 0.5: average over the two transverse directions
                Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] *
                                            fi.jper_qs_dict[s2].conjugate(), axis=1))
                if s1 != s2:
                    Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate())
                    Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] *
                                                 fi.jper_qs_dict[s1].conjugate(), axis=1))

                Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt)
                Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt)

        if calculate_incoherent:
            for atom_type in traj.atom_types:
                xi = fi.positions_by_type[atom_type]
                x0 = f0.positions_by_type[atom_type]
                Fsqt = np.real(calc_rho_q(xi - x0, q_points))
                F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt)

    # run calculation
    with concurrent.futures.ThreadPoolExecutor() as tpe:
        # This is the "main loop" over the trajectory
        for window in window_iterator:
            logger.debug(f'processing window {window[0].frame_index} to {window[-1].frame_index}')

            # The map conveniently applies calc_corr to all time-lags. However,
            # as everything is done in place nothing gets returned, so in order
            # to start and wait for the tasks to finish we must iterate
            # over the None values returned
            for _ in tpe.map(partial(calc_corr, window), range(len(window))):
                pass

    # collect results into dict with numpy arrays (n_qpoints, N_tc)
    data_dict_corr = dict()
    time = delta_t * np.arange(N_tc, dtype=float)
    data_dict_corr['q_points'] = q_points
    data_dict_corr['time'] = time

    F_q_t_tot = np.zeros((n_qpoints, N_tc))
    S_q_w_tot = np.zeros((n_qpoints, N_tc))
    for pair in pairs:
        key = '_'.join(pair)
        F_q_t = 1 / traj.n_atoms * F_q_t_averager[pair].get_average_all()
        w, S_q_w = fourier_cos_filon(F_q_t, delta_t)
        S_q_w = np.array(S_q_w)
        data_dict_corr['omega'] = w
        data_dict_corr[f'Fqt_coh_{key}'] = F_q_t
        data_dict_corr[f'Sqw_coh_{key}'] = S_q_w

        # sum all partials to the total
        F_q_t_tot += F_q_t
        S_q_w_tot += S_q_w
    data_dict_corr['Fqt_coh'] = F_q_t_tot
    data_dict_corr['Sqw_coh'] = S_q_w_tot

    if calculate_currents:
        Cl_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_t_tot = np.zeros((n_qpoints, N_tc))
        Cl_q_w_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_w_tot = np.zeros((n_qpoints, N_tc))
        for pair in pairs:
            key = '_'.join(pair)
            Cl_q_t = 1 / traj.n_atoms * Cl_q_t_averager[pair].get_average_all()
            Ct_q_t = 1 / traj.n_atoms * Ct_q_t_averager[pair].get_average_all()
            _, Cl_q_w = fourier_cos_filon(Cl_q_t, delta_t)
            _, Ct_q_w = fourier_cos_filon(Ct_q_t, delta_t)
            data_dict_corr[f'Clqt_{key}'] = Cl_q_t
            data_dict_corr[f'Ctqt_{key}'] = Ct_q_t
            data_dict_corr[f'Clqw_{key}'] = Cl_q_w
            data_dict_corr[f'Ctqw_{key}'] = Ct_q_w

            # sum all partials to the total
            Cl_q_t_tot += Cl_q_t
            Ct_q_t_tot += Ct_q_t
            Cl_q_w_tot += Cl_q_w
            Ct_q_w_tot += Ct_q_w
        data_dict_corr['Clqt'] = Cl_q_t_tot
        data_dict_corr['Ctqt'] = Ct_q_t_tot
        data_dict_corr['Clqw'] = Cl_q_w_tot
        data_dict_corr['Ctqw'] = Ct_q_w_tot

    if calculate_incoherent:
        Fs_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ss_q_w_tot = np.zeros((n_qpoints, N_tc))
        for atom_type in traj.atom_types:
            Fs_q_t = 1 / traj.n_atoms * F_s_q_t_averager[atom_type].get_average_all()
            _, Ss_q_w = fourier_cos_filon(Fs_q_t, delta_t)
            data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t
            data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w

            # sum all partials to the total
            Fs_q_t_tot += Fs_q_t
            Ss_q_w_tot += Ss_q_w

        data_dict_corr['Fqt_incoh'] = Fs_q_t_tot
        data_dict_corr['Sqw_incoh'] = Ss_q_w_tot

        data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'] + data_dict_corr['Fqt_incoh']
        data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'] + data_dict_corr['Sqw_incoh']
    else:
        data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'].copy()
        data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'].copy()

    # finalize results with additional meta data
    result = DynamicSample(data_dict_corr, atom_types=traj.atom_types, pairs=pairs,
                           particle_counts=particle_counts, cell=traj.cell,
                           time_between_frames=delta_t,
                           maximum_time_lag=delta_t * window_size,
                           angular_frequency_resolution=dw,
                           maximum_angular_frequency=w_max,
                           number_of_frames=traj.number_of_frames_read)

    return result
def compute_static_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
) -> StaticSample:
    r"""Compute static structure factors. The results are returned in the
    form of a :class:`StaticSample <dynasor.sample.StaticSample>`
    object.

    For every frame the partial structure factors are accumulated at time
    lag 0. The averages are normalized by the total number of atoms, and the
    total ``Sq`` is the sum of all partials.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates

    Returns
    -------
    StaticSample
        Holds ``q_points``, ``Sq_<s1>_<s2>`` for every pair and the total ``Sq``;
        each structure factor has shape ``(N_qpoints, 1)``.

    Raises
    ------
    ValueError
        If ``q_points`` is not a two-dimensional array with three columns.
    """
    # sanity check input args; test ndim first so that a 1-D array gives a
    # ValueError instead of an IndexError from shape[1]
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    n_qpoints = q_points.shape[0]
    logger.info(f'Number of q-points: {n_qpoints}')

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # processing function: attaches the per-type Fourier-transformed density to the frame
    def f2_rho(frame):
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    # setup averager
    Sq_averager = dict()
    for pair in pairs:
        Sq_averager[pair] = TimeAverager(1, n_qpoints)  # time average with only timelag=0

    # main loop
    for frame in traj:

        # process_frame
        f2_rho(frame)
        logger.debug(f'Processing frame {frame.frame_index}')

        for s1, s2 in pairs:
            # compute correlation; for unlike pairs both cross terms are added
            Sq_pair = np.real(frame.rho_qs_dict[s1] * frame.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                Sq_pair += np.real(frame.rho_qs_dict[s2] * frame.rho_qs_dict[s1].conjugate())
            Sq_averager[(s1, s2)].add_sample(0, Sq_pair)

    # collect results
    data_dict = dict()
    data_dict['q_points'] = q_points
    S_q_tot = np.zeros((n_qpoints, 1))
    for s1, s2 in pairs:
        Sq = 1 / traj.n_atoms * Sq_averager[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1)
        data_dict[f'Sq_{s1}_{s2}'] = Sq
        S_q_tot += Sq
    data_dict['Sq'] = S_q_tot

    # finalize results
    result = StaticSample(data_dict, atom_types=traj.atom_types, pairs=pairs,
                          particle_counts=particle_counts, cell=traj.cell,
                          number_of_frames=traj.number_of_frames_read)
    return result
def compute_spectral_energy_density(
    traj: Trajectory,
    ideal_supercell: Atoms,
    primitive_cell: Atoms,
    q_points: NDArray[float],
    dt: float,
) -> Tuple[NDArray[float], NDArray[float]]:
    r"""
    Compute the spectral energy density (SED) at specific q-points.
    The results are returned as a tuple with two entries.
    The first is the angular frequencies in rad/fs, an array of length
    ``N_times // 2 + 1``. The second is the SED in units of Da*(Å/fs)²,
    an array of shape ``(N_qpoints, N_times // 2 + 1)``.
    Here ``N_times`` is the number of frames read from the trajectory.

    More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010),
    which should be cited when using this function along with the dynasor reference.

    **Note 1:**
    SED analysis is only suitable for crystalline materials without diffusion as
    atoms are assumed to move around fixed reference positions throughout the entire trajectory.

    **Note 2:**
    This implementation reads the full trajectory and can thus consume a lot of memory.

    Parameters
    ----------
    traj
        Input trajectory
    ideal_supercell
        Ideal structure defining the reference positions
    primitive_cell
        Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`.
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots in
        the trajectory. Note that you should not change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.

    Raises
    ------
    ValueError
        If ``q_points`` has the wrong shape, ``dt`` is not positive, the
        atom counts of the structures and trajectory are inconsistent, a frame
        has no velocities, or no frames could be read.
    """
    # sanity check input args (same checks as for the structure factors)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')

    delta_t = traj.frame_step * dt

    # logger
    logger.info('Running SED')
    logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs')
    logger.info(f'Number of atoms in primitive_cell: {len(primitive_cell)}')
    logger.info(f'Number of atoms in ideal_supercell: {len(ideal_supercell)}')
    logger.info(f'Number of q-points: {q_points.shape[0]}')

    # check that the ideal supercell agrees with traj
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    # equal atom counts are rejected as well, so the message says "fewer"
    if len(primitive_cell) >= len(ideal_supercell):
        raise ValueError('primitive_cell must contain fewer atoms than ideal_supercell.')

    # collect all velocities, and scale with sqrt(masses)
    sqrt_masses = np.sqrt(ideal_supercell.get_masses().reshape(-1, 1))
    velocities = []
    for it, frame in enumerate(traj):
        logger.debug(f'Reading frame {it}')
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {it}')
        v = frame.get_velocities_as_array(traj.atomic_indices)
        velocities.append(sqrt_masses * v)
    logger.info(f'Number of snapshots: {len(velocities)}')
    if not velocities:
        raise ValueError('No frames could be read from the trajectory.')

    # Perform the FFT on the last axis for extra speed (maybe not needed).
    # The time index is placed last and a copy is made for contiguity.
    N_samples = len(velocities)
    velocities = np.array(velocities).transpose(1, 2, 0).copy()
    # #atoms in supercell x 3 directions x #frequencies
    velocities = np.fft.rfft(velocities, axis=2)

    # Calculate indices and offsets needed for the sed method
    indices, offsets = get_index_offset(ideal_supercell, primitive_cell)

    # Phase factor for use in FT. #qpoints x #atoms in supercell
    cell_positions = np.dot(offsets, primitive_cell.cell)
    phase = np.dot(q_points, cell_positions.T)
    phase_factors = np.exp(1.0j * phase)

    # This dict maps the offsets to an index so ndarrays can be over
    # offset,index instead of atoms in supercell
    offset_dict = {off: n for n, off in enumerate(set(tuple(offset) for offset in offsets))}

    # Pick out some shapes
    n_super, _, n_w = velocities.shape
    n_qpts = len(q_points)
    n_prim = len(primitive_cell)
    n_offsets = len(offset_dict)

    # Re-index velocities (also transposed) and phase factors by
    # (basis atom, offset) instead of supercell atom
    new_velocities = np.zeros((n_w, 3, n_prim, n_offsets), dtype=velocities.dtype)
    new_phase_factors = np.zeros((n_qpts, n_prim, n_offsets), dtype=phase_factors.dtype)
    for i in range(n_super):
        j = indices[i]  # atom with index i in the supercell is of basis type j ...
        n = offset_dict[tuple(offsets[i])]  # ... and its offset has index n
        new_velocities[:, :, j, n] = velocities[i].T
        new_phase_factors[:, j, n] = phase_factors[:, i]

    # calculate the density in a numba function
    density = _sed_inner_loop(new_phase_factors, new_velocities)

    # angular frequencies
    w = 2 * np.pi * np.fft.rfftfreq(N_samples, delta_t)  # radians/fs

    return w, density
@numba.njit(parallel=True, fastmath=True)
def _sed_inner_loop(phase_factors, velocities):
    """Accumulate the squared modulus of the phase-weighted spatial Fourier sums.

    ``phase_factors`` is indexed as (q-point, basis atom, unit cell) and
    ``velocities`` as (frequency, direction, basis atom, unit cell). For each
    frequency, q-point, Cartesian direction and basis atom, the product of
    phase factor and velocity is summed over unit cells. The squared modulus of
    that sum is added to ``density[q-point, frequency]``.

    Since often only one or a few q-points are requested, the frequency loop
    is the one distributed over threads (``prange``).

    Returns a float64 array of shape (number of q-points, number of frequencies).
    """
    n_qpoints, n_basis, n_cells = phase_factors.shape
    n_freqs = velocities.shape[0]

    density = np.zeros((n_qpoints, n_freqs), dtype=np.float64)

    for iw in numba.prange(n_freqs):
        for iq in range(n_qpoints):
            for alpha in range(3):
                for basis in range(n_basis):
                    amplitude = 0.0j
                    for cell in range(n_cells):
                        amplitude += phase_factors[iq, basis, cell] * velocities[iw, alpha, basis, cell]
                    density[iq, iw] += np.abs(amplitude)**2
    return density