Coverage for dynasor / correlation_functions.py: 94%
292 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 12:31 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 12:31 +0000
1import concurrent
2from functools import partial
3from itertools import combinations_with_replacement
4from typing import Optional
6import numba
7import numpy as np
8from ase import Atoms
9from ase.units import fs
10from numpy.typing import NDArray
12from dynasor.logging_tools import logger
13from dynasor.trajectory import Trajectory, WindowIterator
14from dynasor.sample import DynamicSample, StaticSample
16from dynasor.post_processing import fourier_cos_filon
17from dynasor.core.time_averager import TimeAverager
18from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q
19from dynasor.tools.structures import get_offset_index
20from dynasor.units import radians_per_fs_to_meV
def compute_dynamic_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
    dt: float,
    window_size: int,
    window_step: Optional[int] = 1,
    calculate_currents: Optional[bool] = False,
    calculate_incoherent: Optional[bool] = False,
) -> DynamicSample:
    r"""Compute the dynamic structure factors. The results are returned in the
    form of a :class:`DynamicSample <dynasor.sample.DynamicSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.
        Integer-valued input is converted to floating point.
    dt
        Time difference in femtoseconds between two consecutive snapshots
        in the trajectory. Note that you should *not* change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    window_size
        Length of the trajectory frame window to use for time correlation calculation.
        It is expressed in terms of the number of time lags to consider
        and thus determines the smallest frequency resolved.
    window_step
        Window step (or stride) given as the number of frames between consecutive trajectory
        windows. This parameter does *not* affect the time between consecutive frames in the
        calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not
        be used.
    calculate_currents
        Calculate the current correlations. Requires velocities to be available in :attr:`traj`.
    calculate_incoherent
        Calculate the incoherent part (self-part) of :math:`F_\text{incoh}`.

    Raises
    ------
    ValueError
        If :attr:`q_points` is not a two-dimensional array with three columns,
        if :attr:`dt` or :attr:`window_step` is not positive, or if
        :attr:`window_size` is not an even number larger than 2.
    """
    # ``import concurrent`` alone does not load the ``concurrent.futures``
    # submodule, so import the executor explicitly instead of relying on some
    # other module having imported it first.
    from concurrent.futures import ThreadPoolExecutor

    # sanity check input args
    q_points = np.asarray(q_points)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if np.issubdtype(q_points.dtype, np.integer):
        # q_directions (below) is normalized in place, which fails for integer arrays
        q_points = q_points.astype(float)
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')
    if window_size <= 2:
        raise ValueError(f'window_size must be larger than 2: window_size= {window_size}')
    if window_size % 2 != 0:
        raise ValueError(f'window_size must be even: window_size= {window_size}')
    if window_step <= 0:
        raise ValueError(f'window_step must be positive: window_step= {window_step}')

    # define internal parameters
    n_qpoints = q_points.shape[0]
    delta_t = traj.frame_step * dt
    N_tc = window_size + 1  # number of time lags, including lag 0

    # log all setup information
    dw = np.pi / (window_size * delta_t)
    w_max = dw * window_size
    w_N = 2 * np.pi / (2 * delta_t)  # Nyquist angular frequency

    logger.info(f'Spacing between samples (frame_step): {traj.frame_step}')
    logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs')
    logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs')
    logger.info(f'Time window size (dt * frame_step * window_size): {delta_t * window_size:.1f} fs')
    logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = '
                f'{dw * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Maximum angular frequency (dw * window_size):'
                f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Nyquist angular frequency (2pi / frame_step / dt / 2):'
                f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV')

    if calculate_currents:
        logger.info('Calculating current (velocity) correlations')
    if calculate_incoherent:
        logger.info('Calculating incoherent part (self-part) of correlations')

    # log some info regarding q-points
    logger.info(f'Number of q-points: {n_qpoints}')

    # unit vectors along each q-point; a zero q-vector is left as the zero vector
    q_directions = q_points.copy()
    q_distances = np.linalg.norm(q_points, axis=1)
    nonzero = q_distances > 0
    q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1)

    # setup functions to process frames
    def f2_rho(frame):
        # attach rho(q) per atom type to the frame
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    def f2_rho_and_j(frame):
        # attach rho(q) and the current j(q) split into its component along the
        # q-direction (jz) and the remainder perpendicular to it (jper)
        rho_qs_dict = dict()
        jz_qs_dict = dict()
        jper_qs_dict = dict()

        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            v = frame.velocities_by_type[atom_type]
            rho_qs, j_qs = calc_rho_j_q(x, v, q_points)
            jz_qs = np.sum(j_qs * q_directions, axis=1)
            jper_qs = j_qs - (jz_qs[:, None] * q_directions)

            rho_qs_dict[atom_type] = rho_qs
            jz_qs_dict[atom_type] = jz_qs
            jper_qs_dict[atom_type] = jper_qs

        frame.rho_qs_dict = rho_qs_dict
        frame.jz_qs_dict = jz_qs_dict
        frame.jper_qs_dict = jper_qs_dict
        return frame

    if calculate_currents:
        element_processor = f2_rho_and_j
    else:
        element_processor = f2_rho

    # setup window iterator
    window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step,
                                     element_processor=element_processor)

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # set up all time averager instances
    F_q_t_averager = dict()
    for pair in pairs:
        F_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_currents:
        Cl_q_t_averager = dict()
        Ct_q_t_averager = dict()
        for pair in pairs:
            Cl_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
            Ct_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_incoherent:
        F_s_q_t_averager = dict()
        for atom_type in traj.atom_types:
            F_s_q_t_averager[atom_type] = TimeAverager(N_tc, n_qpoints)

    # define correlation function
    def calc_corr(window, time_i):
        # Calculate correlations between two frames in the window without normalization 1/N
        f0 = window[0]
        fi = window[time_i]
        for s1, s2 in pairs:
            Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # cross pairs are stored once, so add the (s2, s1) term as well
                Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate())
            F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt)

        if calculate_currents:
            for s1, s2 in pairs:
                Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate())
                Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] *
                                            fi.jper_qs_dict[s2].conjugate(), axis=1))
                if s1 != s2:
                    Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate())
                    Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] *
                                                 fi.jper_qs_dict[s1].conjugate(), axis=1))

                Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt)
                Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt)

        if calculate_incoherent:
            for atom_type in traj.atom_types:
                xi = fi.positions_by_type[atom_type]
                x0 = f0.positions_by_type[atom_type]
                Fsqt = np.real(calc_rho_q(xi - x0, q_points))
                F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt)

    # run calculation
    logging_interval = 1000
    with ThreadPoolExecutor() as tpe:
        # This is the "main loop" over the trajectory
        for window in window_iterator:
            logger.debug(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')

            if window[0].frame_index % logging_interval == 0:
                logger.info(f'Processing window {window[0].frame_index} to {window[-1].frame_index}')  # noqa

            # map conveniently applies calc_corr to all time lags. Everything is
            # done in place, so nothing is returned; the None results are
            # consumed only to wait for all tasks to finish (which also
            # re-raises any exception raised in a worker).
            # NOTE(review): calls for different time lags run concurrently and
            # rely on TimeAverager.add_sample tolerating that for distinct
            # time_i -- confirm.
            for _ in tpe.map(partial(calc_corr, window), range(len(window))):
                pass

    # collect results into dict with numpy arrays (n_qpoints, N_tc)
    data_dict_corr = dict()
    time = delta_t * np.arange(N_tc, dtype=float)
    data_dict_corr['q_points'] = q_points
    data_dict_corr['time'] = time

    F_q_t_tot = np.zeros((n_qpoints, N_tc))
    S_q_w_tot = np.zeros((n_qpoints, N_tc))
    for pair in pairs:
        key = '_'.join(pair)
        F_q_t = 1 / traj.n_atoms * F_q_t_averager[pair].get_average_all()
        w, S_q_w = fourier_cos_filon(F_q_t, delta_t)
        S_q_w = np.array(S_q_w)
        data_dict_corr['omega'] = w
        data_dict_corr[f'Fqt_coh_{key}'] = F_q_t
        data_dict_corr[f'Sqw_coh_{key}'] = S_q_w

        # sum all partials to the total
        F_q_t_tot += F_q_t
        S_q_w_tot += S_q_w
    data_dict_corr['Fqt_coh'] = F_q_t_tot
    data_dict_corr['Sqw_coh'] = S_q_w_tot

    if calculate_currents:
        Cl_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_t_tot = np.zeros((n_qpoints, N_tc))
        Cl_q_w_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_w_tot = np.zeros((n_qpoints, N_tc))
        for pair in pairs:
            key = '_'.join(pair)
            Cl_q_t = 1 / traj.n_atoms * Cl_q_t_averager[pair].get_average_all()
            Ct_q_t = 1 / traj.n_atoms * Ct_q_t_averager[pair].get_average_all()
            _, Cl_q_w = fourier_cos_filon(Cl_q_t, delta_t)
            _, Ct_q_w = fourier_cos_filon(Ct_q_t, delta_t)
            data_dict_corr[f'Clqt_{key}'] = Cl_q_t
            data_dict_corr[f'Ctqt_{key}'] = Ct_q_t
            data_dict_corr[f'Clqw_{key}'] = Cl_q_w
            data_dict_corr[f'Ctqw_{key}'] = Ct_q_w

            # sum all partials to the total
            Cl_q_t_tot += Cl_q_t
            Ct_q_t_tot += Ct_q_t
            Cl_q_w_tot += Cl_q_w
            Ct_q_w_tot += Ct_q_w
        data_dict_corr['Clqt'] = Cl_q_t_tot
        data_dict_corr['Ctqt'] = Ct_q_t_tot
        data_dict_corr['Clqw'] = Cl_q_w_tot
        data_dict_corr['Ctqw'] = Ct_q_w_tot

    if calculate_incoherent:
        Fs_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ss_q_w_tot = np.zeros((n_qpoints, N_tc))
        for atom_type in traj.atom_types:
            Fs_q_t = 1 / traj.n_atoms * F_s_q_t_averager[atom_type].get_average_all()
            _, Ss_q_w = fourier_cos_filon(Fs_q_t, delta_t)
            data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t
            data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w

            # sum all partials to the total
            Fs_q_t_tot += Fs_q_t
            Ss_q_w_tot += Ss_q_w

        data_dict_corr['Fqt_incoh'] = Fs_q_t_tot
        data_dict_corr['Sqw_incoh'] = Ss_q_w_tot

    # finalize results with additional metadata
    new_sample = DynamicSample(
        data_dict_corr,
        simulation_data=dict(
            atom_types=traj.atom_types, pairs=pairs,
            particle_counts=particle_counts,
            cell=traj.cell,
            time_between_frames=delta_t,
            maximum_time_lag=delta_t * window_size,
            angular_frequency_resolution=dw,
            maximum_angular_frequency=w_max,
            number_of_frames=traj.number_of_frames_read,
        ))
    new_sample._append_history(
        'compute_dynamic_structure_factors',
        dict(
            dt=dt,
            window_size=window_size,
            window_step=window_step,
            calculate_currents=calculate_currents,
            calculate_incoherent=calculate_incoherent,
        ))

    return new_sample
def compute_static_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
) -> StaticSample:
    r"""Compute the static structure factors. The results are returned in the
    form of a :class:`StaticSample <dynasor.sample.StaticSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.

    Raises
    ------
    ValueError
        If :attr:`q_points` is not a two-dimensional array with three columns.
    """
    # sanity check input args; checking ndim first turns a 1-D input into a
    # clear ValueError instead of an IndexError from shape[1]
    q_points = np.asarray(q_points)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    n_qpoints = q_points.shape[0]
    logger.info(f'Number of q-points: {n_qpoints}')

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # processing function: attach rho(q) per atom type to the frame
    def f2_rho(frame):
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    # setup averager
    Sq_averager = dict()
    for pair in pairs:
        Sq_averager[pair] = TimeAverager(1, n_qpoints)  # time average with only timelag=0

    # main loop
    for frame in traj:

        # process_frame (modifies the frame in place)
        f2_rho(frame)
        logger.debug(f'Processing frame {frame.frame_index}')

        for s1, s2 in pairs:
            # compute correlation; cross pairs are stored once, so add both orders
            Sq_pair = np.real(frame.rho_qs_dict[s1] * frame.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                Sq_pair += np.real(frame.rho_qs_dict[s2] * frame.rho_qs_dict[s1].conjugate())
            Sq_averager[(s1, s2)].add_sample(0, Sq_pair)

    # collect results as column arrays of shape (n_qpoints, 1)
    data_dict = dict()
    data_dict['q_points'] = q_points
    S_q_tot = np.zeros((n_qpoints, 1))
    for s1, s2 in pairs:
        Sq = 1 / traj.n_atoms * Sq_averager[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1)
        data_dict[f'Sq_{s1}_{s2}'] = Sq
        S_q_tot += Sq
    data_dict['Sq'] = S_q_tot

    # finalize results
    new_sample = StaticSample(
        data_dict,
        simulation_data=dict(
            atom_types=traj.atom_types,
            pairs=pairs,
            particle_counts=particle_counts,
            cell=traj.cell,
            number_of_frames=traj.number_of_frames_read,
        ))
    new_sample._append_history('compute_static_structure_factors')

    return new_sample
def compute_spectral_energy_density(
    traj: Trajectory,
    ideal_supercell: Atoms,
    primitive_cell: Atoms,
    q_points: NDArray[float],
    dt: float,
    partial: Optional[bool] = False
) -> tuple[NDArray[float], NDArray[float]]:
    r"""
    Compute the spectral energy density (SED) at specific q-points. The results
    are returned in the form of a tuple, which comprises the angular
    frequencies in an array of length ``N_times`` in units of rad/fs and the
    SED in units of eV/(rad/fs) as an array of shape ``(N_qpoints, N_times)``.
    The normalization is chosen such that integrating the SED of a q-point
    together with the supplied angular frequencies omega (rad/fs) yields
    `1/2 kB T` * number of bands (where number of bands = `len(prim) * 3`)

    More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010),
    which should be cited when using this function along with the dynasor reference.

    **Note 1:**
    SED analysis is only suitable for crystalline materials without diffusion as
    atoms are assumed to move around fixed reference positions throughout the entire trajectory.

    **Note 2:**
    This implementation reads the full trajectory and can thus consume a lot of memory.

    Parameters
    ----------
    traj
        Input trajectory.
    ideal_supercell
        Ideal structure defining the reference positions. Do not change the masses
        in the ASE :class:`Atoms` objects to dynasor internal units, this will be
        done internally.
    primitive_cell
        Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`.
    q_points
        Array of q-points in units of rad/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates.
    dt
        Time difference in femtoseconds between two consecutive snapshots in
        the trajectory. Note that you should not change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    partial
        If True the SED will be returned decomposed per basis and Cartesian direction.
        The shape is ``(N_qpoints, N_frequencies, len(primitive_cell), 3)``.

    Raises
    ------
    ValueError
        If :attr:`q_points` is not a two-dimensional array with three columns,
        if :attr:`ideal_supercell` and :attr:`traj` differ in number of atoms,
        if :attr:`primitive_cell` does not contain fewer atoms than
        :attr:`ideal_supercell`, if a frame lacks velocities, or if no frames
        could be read from :attr:`traj`.
    """
    q_points = np.asarray(q_points)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    delta_t = traj.frame_step * dt

    # logger
    logger.info('Running SED')
    logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs')
    logger.info(f'Number of atoms in primitive_cell: {len(primitive_cell)}')
    logger.info(f'Number of atoms in ideal_supercell: {len(ideal_supercell)}')
    logger.info(f'Number of q-points: {q_points.shape[0]}')

    # check that the ideal supercell agrees with traj
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    if len(primitive_cell) >= len(ideal_supercell):
        raise ValueError('primitive_cell must contain fewer atoms than ideal_supercell.')

    # collect all velocities, and scale with sqrt(masses)
    masses = ideal_supercell.get_masses().reshape(-1, 1) / fs**2  # From Dalton to dmu
    sqrt_masses = np.sqrt(masses)  # loop invariant
    velocities = []
    for it, frame in enumerate(traj):
        logger.debug(f'Reading frame {it}')
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {it}')
        v = frame.get_velocities_as_array(traj.atomic_indices)  # in Å/fs
        velocities.append(sqrt_masses * v)
    logger.info(f'Number of snapshots: {len(velocities)}')
    if not velocities:
        # an empty list would otherwise fail obscurely in the transpose below
        raise ValueError('No frames could be read from the trajectory.')

    # Perform the FFT on the last axis for extra speed (maybe not needed)
    N_samples = len(velocities)
    velocities = np.array(velocities)  # #frames x #atoms in supercell x 3
    # places time index last and makes a copy for contiguity
    velocities = velocities.transpose(1, 2, 0).copy()
    # #atoms in supercell x 3 directions x #frequencies
    velocities = np.fft.rfft(velocities, axis=2)

    # Calculate indices and offsets needed for the sed method
    offsets, indices = get_offset_index(primitive_cell, ideal_supercell)

    # Phase factor for use in FT. #qpoints x #atoms in supercell
    cell_positions = np.dot(offsets, primitive_cell.cell)
    phase = np.dot(q_points, cell_positions.T)
    phase_factors = np.exp(1.0j * phase)

    # This dict maps the offsets to an index so ndarrays can be indexed over
    # (basis atom, offset) instead of atoms in supercell
    offset_dict = {off: n for n, off in enumerate(set(tuple(offset) for offset in offsets))}

    # Pick out some shapes
    n_super, _, n_w = velocities.shape
    n_qpts = len(q_points)
    n_prim = len(primitive_cell)
    n_offsets = len(offset_dict)

    # Re-index velocities to (#frequencies, 3, basis atom, offset) and the
    # spatial phase factors to (#qpoints, basis atom, offset) in a single pass
    new_velocities = np.zeros((n_w, 3, n_prim, n_offsets), dtype=velocities.dtype)
    new_phase_factors = np.zeros((n_qpts, n_prim, n_offsets), dtype=phase_factors.dtype)
    for i in range(n_super):
        j = indices[i]  # atom with index i in the supercell is of basis type j ...
        n = offset_dict[tuple(offsets[i])]  # ... and its offset has index n
        new_velocities[:, :, j, n] = velocities[i].T
        new_phase_factors[:, j, n] = phase_factors[:, i]
    velocities = new_velocities
    phase_factors = new_phase_factors

    # calculate the density in a numba function
    density = _sed_inner_loop(phase_factors, velocities)

    if not partial:
        density = np.sum(density, axis=(2, 3))

    # units
    # make so that the velocities were originally in Angstrom / fs to be compatible with eV and Da

    # the time delta in the fourier transform
    density = density * delta_t**2

    # Divide by the length of the time signal
    density = density / (N_samples * delta_t)

    # Divide by the number of primitive cells
    density = density / (n_super / n_prim)

    # Factor so the sed can be integrated together with the returned omega
    # numpy fft works with ordinary/linear frequencies and not angular freqs
    density = density / (2*np.pi)

    # angular frequencies
    w = 2 * np.pi * np.fft.rfftfreq(N_samples, delta_t)  # rad/fs

    return w, density
@numba.njit(parallel=True, fastmath=True)
def _sed_inner_loop(phase_factors, velocities):
    """Evaluate the squared spatial Fourier amplitude of the velocities.

    For each q-point, frequency, basis atom and Cartesian direction, the
    velocity components of all unit cells are summed with the precomputed
    phase factors, and the squared modulus of that sum is stored.

    The parallel loop runs over the temporal frequency components, because
    callers may ask for anything from one to many q-points.

    Parameters
    ----------
    phase_factors
        Array indexed as ``[q-point, basis atom, unit cell]``.
    velocities
        Array indexed as ``[frequency, direction, basis atom, unit cell]``.

    Returns
    -------
    Float64 array of shape ``(n_qpoints, n_frequencies, n_basis, 3)``.
    """
    n_q, n_basis, n_cells = phase_factors.shape
    n_freq = velocities.shape[0]

    result = np.zeros((n_q, n_freq, n_basis, 3), dtype=np.float64)

    for iw in numba.prange(n_freq):
        for iq in range(n_q):
            for alpha in range(3):
                for ib in range(n_basis):
                    amplitude = 0.0j
                    for ic in range(n_cells):
                        amplitude += phase_factors[iq, ib, ic] * velocities[iw, alpha, ib, ic]
                    result[iq, iw, ib, alpha] += np.abs(amplitude)**2
    return result