Coverage for local_installation/dynasor/correlation_functions.py: 93%
294 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-12-21 12:02 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-12-21 12:02 +0000
1import numba
2import concurrent
3from functools import partial
4from itertools import combinations_with_replacement
5from typing import Tuple
7import numpy as np
8from ase import Atoms
9from ase.units import fs
10from numpy.typing import NDArray
12from dynasor.logging_tools import logger
13from dynasor.trajectory import Trajectory, WindowIterator
14from dynasor.sample import DynamicSample, StaticSample
15from dynasor.post_processing import fourier_cos_filon
16from dynasor.core.time_averager import TimeAverager
17from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q
18from dynasor.qpoints.tools import get_index_offset
19from dynasor.units import radians_per_fs_to_meV
def compute_dynamic_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
    dt: float,
    window_size: int,
    window_step: int = 1,
    calculate_currents: bool = False,
    calculate_incoherent: bool = False,
) -> DynamicSample:
    """Compute dynamic structure factors. The results are returned in the
    form of a :class:`DynamicSample <dynasor.sample.DynamicSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots
        in the trajectory. Note that you should *not* change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    window_size
        Length of the trajectory frame window to use for time correlation calculation.
        It is expressed in terms of the number of time lags to consider
        and thus determines the smallest frequency resolved.
        Must be even and larger than 2.
    window_step
        Window step (or stride) given as the number of frames between consecutive trajectory
        windows. This parameter does *not* affect the time between consecutive frames in the
        calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not
        be used.
    calculate_currents
        Calculate the current correlations. Requires velocities to be available in :attr:`traj`.
    calculate_incoherent
        Calculate the incoherent part (self-part) of the correlation functions,
        :math:`F_{incoh}(q, t)`, together with its spectrum
        (stored under the ``Fqt_incoh`` and ``Sqw_incoh`` keys).

    Raises
    ------
    ValueError
        If ``q_points`` does not have shape ``(N_qpoints, 3)``; if ``dt``,
        ``window_size`` or ``window_step`` are invalid; or if
        ``calculate_currents`` is requested but a frame carries no velocities.
    """
    # Importing the bare ``concurrent`` package does not load its ``futures``
    # submodule, so import the executor explicitly.
    from concurrent.futures import ThreadPoolExecutor

    # sanity check input args
    q_points = np.asarray(q_points)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')
    if window_size <= 2:
        raise ValueError(f'window_size must be larger than 2: window_size= {window_size}')
    if window_size % 2 != 0:
        raise ValueError(f'window_size must be even: window_size= {window_size}')
    if window_step <= 0:
        raise ValueError(f'window_step must be positive: window_step= {window_step}')

    # define internal parameters
    n_qpoints = q_points.shape[0]
    delta_t = traj.frame_step * dt
    N_tc = window_size + 1  # number of time lags: 0, 1, ..., window_size

    # log all setup information
    dw = np.pi / (window_size * delta_t)
    w_max = dw * window_size
    w_N = 2 * np.pi / (2 * delta_t)  # Nyquist angular frequency

    logger.info(f'Spacing between samples (frame_step): {traj.frame_step}')
    logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs')
    logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs')
    logger.info(f'Time window size (dt * frame_step * window_size): {delta_t * window_size:.1f} fs')
    logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = '
                f'{dw * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Maximum angular frequency (dw * window_size):'
                f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Nyquist angular frequency (2pi / frame_step / dt / 2):'
                f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV')

    if calculate_currents:
        logger.info('Calculating current (velocity) correlations')
    if calculate_incoherent:
        logger.info('Calculating incoherent part (self-part) of correlations')

    # log some info regarding q-points
    logger.info(f'Number of q-points: {n_qpoints}')

    # Unit vectors along each q-point (q = 0 stays the zero vector).
    # A float copy is required: the in-place division below fails for
    # integer-valued q_points.
    q_directions = np.array(q_points, dtype=float)
    q_distances = np.linalg.norm(q_points, axis=1)
    nonzero = q_distances > 0
    q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1)

    # setup functions to process frames
    def f2_rho(frame):
        # attach the per-atom-type rho(q) to the frame
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    def f2_rho_and_j(frame):
        # as f2_rho, plus the current split into its component along q
        # (jz) and the remainder perpendicular to q (jper)
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {frame.frame_index}')
        rho_qs_dict = dict()
        jz_qs_dict = dict()
        jper_qs_dict = dict()

        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            v = frame.velocities_by_type[atom_type]
            rho_qs, j_qs = calc_rho_j_q(x, v, q_points)
            jz_qs = np.sum(j_qs * q_directions, axis=1)
            jper_qs = j_qs - (jz_qs[:, None] * q_directions)

            rho_qs_dict[atom_type] = rho_qs
            jz_qs_dict[atom_type] = jz_qs
            jper_qs_dict[atom_type] = jper_qs

        frame.rho_qs_dict = rho_qs_dict
        frame.jz_qs_dict = jz_qs_dict
        frame.jper_qs_dict = jper_qs_dict
        return frame

    if calculate_currents:
        element_processor = f2_rho_and_j
    else:
        element_processor = f2_rho

    # setup window iterator
    window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step,
                                     element_processor=element_processor)

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f'  {pair}')

    # setup all time averager instances
    F_q_t_averager = dict()
    for pair in pairs:
        F_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_currents:
        Cl_q_t_averager = dict()
        Ct_q_t_averager = dict()
        for pair in pairs:
            Cl_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
            Ct_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_incoherent:
        F_s_q_t_averager = dict()
        for atom_type in traj.atom_types:
            F_s_q_t_averager[atom_type] = TimeAverager(N_tc, n_qpoints)

    # define correlation function
    def calc_corr(window, time_i):
        # Calculate correlations between the first frame of the window and the
        # frame at time lag time_i, without the 1/N normalization (applied below).
        # Each call only adds samples at its own time lag.
        f0 = window[0]
        fi = window[time_i]
        for s1, s2 in pairs:
            Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate())
            F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt)

        if calculate_currents:
            for s1, s2 in pairs:
                Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate())
                # 0.5: average over the two transverse directions
                Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] *
                                            fi.jper_qs_dict[s2].conjugate(), axis=1))
                if s1 != s2:
                    Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate())
                    Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] *
                                                 fi.jper_qs_dict[s1].conjugate(), axis=1))

                Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt)
                Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt)

        if calculate_incoherent:
            for atom_type in traj.atom_types:
                xi = fi.positions_by_type[atom_type]
                x0 = f0.positions_by_type[atom_type]
                Fsqt = np.real(calc_rho_q(xi - x0, q_points))
                F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt)

    # run calculation
    logging_interval = 1000
    with ThreadPoolExecutor() as tpe:
        # This is the "main loop" over the trajectory
        for window in window_iterator:
            logger.debug(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')

            if window[0].frame_index % logging_interval == 0:
                logger.info(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')  # noqa

            # The map conveniently applies calc_corr to all time lags. However,
            # as everything is done in place nothing gets returned, so in order
            # to start and wait for the tasks to finish (and to surface any
            # exception) we must iterate over the None values returned.
            for _ in tpe.map(partial(calc_corr, window), range(len(window))):
                pass

    # collect results into dict with numpy arrays (n_qpoints, N_tc)
    data_dict_corr = dict()
    time = delta_t * np.arange(N_tc, dtype=float)
    data_dict_corr['q_points'] = q_points
    data_dict_corr['time'] = time

    F_q_t_tot = np.zeros((n_qpoints, N_tc))
    S_q_w_tot = np.zeros((n_qpoints, N_tc))
    for pair in pairs:
        key = '_'.join(pair)
        F_q_t = 1 / traj.n_atoms * F_q_t_averager[pair].get_average_all()
        w, S_q_w = fourier_cos_filon(F_q_t, delta_t)
        S_q_w = np.array(S_q_w)
        data_dict_corr['omega'] = w
        data_dict_corr[f'Fqt_coh_{key}'] = F_q_t
        data_dict_corr[f'Sqw_coh_{key}'] = S_q_w

        # sum all partials to the total
        F_q_t_tot += F_q_t
        S_q_w_tot += S_q_w
    data_dict_corr['Fqt_coh'] = F_q_t_tot
    data_dict_corr['Sqw_coh'] = S_q_w_tot

    if calculate_currents:
        Cl_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_t_tot = np.zeros((n_qpoints, N_tc))
        Cl_q_w_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_w_tot = np.zeros((n_qpoints, N_tc))
        for pair in pairs:
            key = '_'.join(pair)
            Cl_q_t = 1 / traj.n_atoms * Cl_q_t_averager[pair].get_average_all()
            Ct_q_t = 1 / traj.n_atoms * Ct_q_t_averager[pair].get_average_all()
            _, Cl_q_w = fourier_cos_filon(Cl_q_t, delta_t)
            _, Ct_q_w = fourier_cos_filon(Ct_q_t, delta_t)
            data_dict_corr[f'Clqt_{key}'] = Cl_q_t
            data_dict_corr[f'Ctqt_{key}'] = Ct_q_t
            data_dict_corr[f'Clqw_{key}'] = Cl_q_w
            data_dict_corr[f'Ctqw_{key}'] = Ct_q_w

            # sum all partials to the total
            Cl_q_t_tot += Cl_q_t
            Ct_q_t_tot += Ct_q_t
            Cl_q_w_tot += Cl_q_w
            Ct_q_w_tot += Ct_q_w
        data_dict_corr['Clqt'] = Cl_q_t_tot
        data_dict_corr['Ctqt'] = Ct_q_t_tot
        data_dict_corr['Clqw'] = Cl_q_w_tot
        data_dict_corr['Ctqw'] = Ct_q_w_tot

    if calculate_incoherent:
        Fs_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ss_q_w_tot = np.zeros((n_qpoints, N_tc))
        for atom_type in traj.atom_types:
            Fs_q_t = 1 / traj.n_atoms * F_s_q_t_averager[atom_type].get_average_all()
            _, Ss_q_w = fourier_cos_filon(Fs_q_t, delta_t)
            data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t
            data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w

            # sum all partials to the total
            Fs_q_t_tot += Fs_q_t
            Ss_q_w_tot += Ss_q_w

        data_dict_corr['Fqt_incoh'] = Fs_q_t_tot
        data_dict_corr['Sqw_incoh'] = Ss_q_w_tot

        data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'] + data_dict_corr['Fqt_incoh']
        data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'] + data_dict_corr['Sqw_incoh']
    else:
        data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'].copy()
        data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'].copy()

    # finalize results with additional meta data
    result = DynamicSample(data_dict_corr, atom_types=traj.atom_types, pairs=pairs,
                           particle_counts=particle_counts, cell=traj.cell,
                           time_between_frames=delta_t,
                           maximum_time_lag=delta_t * window_size,
                           angular_frequency_resolution=dw,
                           maximum_angular_frequency=w_max,
                           number_of_frames=traj.number_of_frames_read)

    return result
def compute_static_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
) -> StaticSample:
    r"""Compute static structure factors. The results are returned in the
    form of a :class:`StaticSample <dynasor.sample.StaticSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates

    Raises
    ------
    ValueError
        If ``q_points`` does not have shape ``(N_qpoints, 3)``.
    """
    # sanity check input args; asarray also accepts nested sequences and
    # the ndim test turns 1-D input into a ValueError instead of IndexError
    q_points = np.asarray(q_points)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    n_qpoints = q_points.shape[0]
    logger.info(f'Number of q-points: {n_qpoints}')

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f'  {pair}')

    # processing function: attach the per-atom-type rho(q) to the frame
    def f2_rho(frame):
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    # setup averager
    Sq_averager = dict()
    for pair in pairs:
        Sq_averager[pair] = TimeAverager(1, n_qpoints)  # time average with only timelag=0

    # main loop
    for frame in traj:

        # process_frame
        f2_rho(frame)
        logger.debug(f'Processing frame {frame.frame_index}')

        for s1, s2 in pairs:
            # compute equal-time correlation; cross terms are symmetrized
            Sq_pair = np.real(frame.rho_qs_dict[s1] * frame.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                Sq_pair += np.real(frame.rho_qs_dict[s2] * frame.rho_qs_dict[s1].conjugate())
            Sq_averager[(s1, s2)].add_sample(0, Sq_pair)

    # collect results, normalized by the total number of atoms
    data_dict = dict()
    data_dict['q_points'] = q_points
    S_q_tot = np.zeros((n_qpoints, 1))
    for s1, s2 in pairs:
        Sq = 1 / traj.n_atoms * Sq_averager[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1)
        data_dict[f'Sq_{s1}_{s2}'] = Sq
        S_q_tot += Sq
    data_dict['Sq'] = S_q_tot

    # finalize results
    result = StaticSample(data_dict, atom_types=traj.atom_types, pairs=pairs,
                          particle_counts=particle_counts, cell=traj.cell,
                          number_of_frames=traj.number_of_frames_read)
    return result
def compute_spectral_energy_density(
    traj: Trajectory,
    ideal_supercell: Atoms,
    primitive_cell: Atoms,
    q_points: NDArray[float],
    dt: float,
    partial: bool = False
) -> Tuple[NDArray[float], NDArray[float]]:
    r"""
    Compute the spectral energy density (SED) at specific q-points. The results
    are returned in the form of a tuple, which comprises the angular
    frequencies in an array of length ``N_frequencies = N_frames // 2 + 1`` in
    units of rad/fs and the SED in units of eV/(rad/fs) as an array of shape
    ``(N_qpoints, N_frequencies)``.
    The normalization is chosen such that integrating the SED of a q-point
    over the returned angular frequencies omega (rad/fs) yields
    :math:`\frac{1}{2} k_B T` times the number of bands (where the number of
    bands is ``3 * len(primitive_cell)``).

    More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010),
    which should be cited when using this function along with the dynasor reference.

    **Note 1:**
    SED analysis is only suitable for crystalline materials without diffusion as
    atoms are assumed to move around fixed reference positions throughout the entire trajectory.

    **Note 2:**
    This implementation reads the full trajectory and can thus consume a lot of memory.

    Parameters
    ----------
    traj
        Input trajectory
    ideal_supercell
        Ideal structure defining the reference positions. Do not change the
        masses in the ASE atoms objects to dynasor internal units, this will be
        done internally
    primitive_cell
        Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots in
        the trajectory. Note that you should not change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    partial
        If True the SED will be returned decomposed per basis and Cartesian direction.
        The shape is ``(N_qpoints, N_frequencies, len(primitive_cell), 3)``

    Raises
    ------
    ValueError
        If ``q_points`` does not have shape ``(N_qpoints, 3)``, if ``dt`` is not
        positive, if the cells are inconsistent with each other or with
        :attr:`traj`, if the trajectory yields no frames, or if a frame carries
        no velocities.
    """
    # NOTE: the boolean ``partial`` parameter shadows ``functools.partial``
    # within this function; functools.partial is not needed here.

    # sanity check input args (same checks as compute_dynamic_structure_factors)
    q_points = np.asarray(q_points)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')

    delta_t = traj.frame_step * dt

    # logger
    logger.info('Running SED')
    logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs')
    logger.info(f'Number of atoms in primitive_cell: {len(primitive_cell)}')
    logger.info(f'Number of atoms in ideal_supercell: {len(ideal_supercell)}')
    logger.info(f'Number of q-points: {q_points.shape[0]}')

    # check that the ideal supercell agrees with traj
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    if len(primitive_cell) >= len(ideal_supercell):
        raise ValueError('primitive_cell must contain fewer atoms than ideal_supercell.')

    # Collect all velocities, scaled with sqrt(masses).
    # Masses are in Da; dividing by fs**2 (ASE time units) makes m * v**2
    # come out in eV when v is given in Å/fs.
    masses = ideal_supercell.get_masses().reshape(-1, 1) / fs**2
    sqrt_masses = np.sqrt(masses)  # loop invariant
    velocities = []
    for it, frame in enumerate(traj):
        logger.debug(f'Reading frame {it}')
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {it}')
        v = frame.get_velocities_as_array(traj.atomic_indices)  # in Å/fs
        velocities.append(sqrt_masses * v)
    logger.info(f'Number of snapshots: {len(velocities)}')
    if not velocities:
        raise ValueError('The trajectory contains no frames.')

    # Perform the FFT along time, with time moved to the last (contiguous) axis
    N_samples = len(velocities)
    velocities = np.array(velocities)
    # place the time index last and make a contiguous copy
    velocities = velocities.transpose(1, 2, 0).copy()
    # #atoms in supercell x 3 directions x #frequencies
    velocities = np.fft.rfft(velocities, axis=2)

    # Calculate basis indices and cell offsets needed for the SED method
    indices, offsets = get_index_offset(ideal_supercell, primitive_cell)

    # Phase factors for the spatial FT: #qpoints x #atoms in supercell
    cell_positions = np.dot(offsets, primitive_cell.cell)
    phase = np.dot(q_points, cell_positions.T)
    phase_factors = np.exp(1.0j * phase)

    # Map each distinct offset to an index so the arrays can be indexed by
    # (basis atom, offset) instead of by supercell atom
    offset_dict = {off: n for n, off in enumerate(set(tuple(offset) for offset in offsets))}

    # Pick out some shapes
    n_super, _, n_w = velocities.shape
    n_qpts = len(q_points)
    n_prim = len(primitive_cell)
    n_offsets = len(offset_dict)

    # Reindex velocities (also transposed) and phase factors by basis atom
    # and offset in a single pass over the supercell atoms
    new_velocities = np.zeros((n_w, 3, n_prim, n_offsets), dtype=velocities.dtype)
    new_phase_factors = np.zeros((n_qpts, n_prim, n_offsets), dtype=phase_factors.dtype)
    for i in range(n_super):
        j = indices[i]  # atom with index i in the supercell is of basis type j ...
        n = offset_dict[tuple(offsets[i])]  # ... and its offset has index n
        new_velocities[:, :, j, n] = velocities[i].T
        new_phase_factors[:, j, n] = phase_factors[:, i]
    velocities = new_velocities
    phase_factors = new_phase_factors

    # calculate the density in a numba function
    density = _sed_inner_loop(phase_factors, velocities)

    if not partial:
        density = np.sum(density, axis=(2, 3))

    # units / normalization
    # the time step of the discrete Fourier transform enters squared (|FT|^2)
    density = density * delta_t**2

    # Divide by the length of the time signal
    density = density / (N_samples * delta_t)

    # Divide by the number of primitive cells
    density = density / (n_super / n_prim)

    # Factor so the SED can be integrated over the returned angular frequencies;
    # numpy's FFT works with ordinary (linear) frequencies, not angular ones
    density = density / (2*np.pi)

    # angular frequencies
    w = 2 * np.pi * np.fft.rfftfreq(N_samples, delta_t)  # rad/fs

    return w, density
@numba.njit(parallel=True, fastmath=True)
def _sed_inner_loop(phase_factors, velocities):
    """Spatial Fourier transform of the frequency-resolved velocities.

    Uses precomputed phase factors. Since the function may be called for a
    single q-point or for many, the parallel loop runs over the temporal
    frequency components rather than over q-points.

    Parameters
    ----------
    phase_factors
        Complex array indexed as ``[q-point, basis atom, unit cell]``.
    velocities
        Complex array indexed as ``[frequency, direction, basis atom, unit cell]``;
        only the first three directions are read.

    Returns
    -------
    Float array of shape ``(n_qpoints, n_frequencies, n_basis, 3)`` holding
    ``|sum_n phase_factors[k, b, n] * velocities[w, a, b, n]|**2`` at
    ``[k, w, b, a]``.
    """
    n_q, n_basis, n_cells = phase_factors.shape
    n_freq = velocities.shape[0]

    out = np.zeros((n_q, n_freq, n_basis, 3), dtype=np.float64)

    # every frequency index writes only its own out[:, iw, :, :] slice
    for iw in numba.prange(n_freq):
        for iq in range(n_q):
            for axis in range(3):
                for basis in range(n_basis):
                    acc = 0.0j
                    for cell in range(n_cells):
                        acc += phase_factors[iq, basis, cell] * velocities[iw, axis, basis, cell]
                    out[iq, iw, basis, axis] = np.abs(acc) ** 2
    return out