Coverage for local_installation/dynasor/correlation_functions.py: 100%
252 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-30 21:04 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-30 21:04 +0000
1import concurrent
2from functools import partial
3from itertools import combinations_with_replacement
4from typing import Tuple
6import numpy as np
7from ase import Atoms
8from numpy.typing import NDArray
10from dynasor.logging_tools import logger
11from dynasor.trajectory import Trajectory, WindowIterator
12from dynasor.sample import DynamicSample, StaticSample
13from dynasor.post_processing import fourier_cos
14from dynasor.core.time_averager import TimeAverager
15from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q
16from dynasor.qpoints.tools import get_index_offset
17from dynasor.units import radians_per_fs_to_meV
def compute_dynamic_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
    dt: float,
    window_size: int,
    window_step: int = 1,
    calculate_currents: bool = False,
    calculate_incoherent: bool = False,
) -> DynamicSample:
    """Compute dynamic structure factors. The results are returned in the
    form of a :class:`DynamicSample <dynasor.sample.DynamicSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots
        in the trajectory. Note that you should *not* change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    window_size
        Length of the trajectory frame window to use for time correlation calculation.
        It is expressed in terms of the number of time lags to consider
        and thus determines the smallest frequency resolved.
        Must be an even number larger than 2.
    window_step
        Window step (or stride) given as the number of frames between consecutive trajectory
        windows. This parameter does *not* affect the time between consecutive frames in the
        calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not
        be used.
    calculate_currents
        Calculate the current correlations. Requires velocities to be available in :attr:`traj`.
    calculate_incoherent
        Calculate the incoherent part (self-part) of the correlation functions
        (stored under the ``Fqt_incoh`` and ``Sqw_incoh`` keys).

    Returns
    -------
    DynamicSample
        Holds ``q_points``, ``time``, ``omega``, the coherent parts ``Fqt_coh`` and
        ``Sqw_coh`` (total and per pair as ``Fqt_coh_{s1}_{s2}`` / ``Sqw_coh_{s1}_{s2}``)
        and the totals ``Fqt`` / ``Sqw``. If requested, it also holds the current
        correlations (``Clqt``, ``Ctqt``, ``Clqw``, ``Ctqw`` plus per-pair variants)
        and the incoherent parts (``Fqt_incoh``, ``Sqw_incoh`` plus per-atom-type
        variants).

    Raises
    ------
    ValueError
        If ``q_points`` is not a 2D array with three columns, if ``dt`` or
        ``window_step`` is not positive, or if ``window_size`` is not an even
        number larger than 2.
    """
    # A bare ``import concurrent`` does not guarantee that the ``futures``
    # submodule is loaded, so import the executor explicitly.
    from concurrent.futures import ThreadPoolExecutor

    # sanity check input args
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')
    if window_size <= 2:
        raise ValueError(f'window_size must be larger than 2: window_size= {window_size}')
    if window_size % 2 != 0:
        raise ValueError(f'window_size must be even: window_size= {window_size}')
    if window_step <= 0:
        raise ValueError(f'window_step must be positive: window_step= {window_step}')

    # define internal parameters
    n_qpoints = q_points.shape[0]
    delta_t = traj.frame_step * dt
    N_tc = window_size + 1  # number of time lags including lag 0

    # log all setup information
    dw = 2 * np.pi / (window_size * delta_t)
    w_max = dw * window_size
    w_N = 2 * np.pi / (2 * delta_t)  # Nyquist angular frequency

    logger.info(f'Spacing between samples (frame_step): {traj.frame_step}')
    logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs')
    logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs')
    logger.info(f'Time window size (dt * frame_step * window_size): {delta_t * window_size:.1f} fs')
    logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = '
                f'{dw * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Maximum angular frequency (dw * window_size):'
                f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Nyquist angular frequency (2pi / frame_step / dt / 2):'
                f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV')

    if calculate_currents:
        logger.info('Calculating current (velocity) correlations')
    if calculate_incoherent:
        logger.info('Calculating incoherent part (self-part) of correlations')

    # log some info regarding q-points
    logger.info(f'Number of q-points: {n_qpoints}')

    # unit vectors along each q-point; q = 0 is left as the zero vector
    q_directions = q_points.copy()
    q_distances = np.linalg.norm(q_points, axis=1)
    nonzero = q_distances > 0
    q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1)

    # setup functions to process frames
    def f2_rho(frame):
        # attach the per-type densities rho(q) to the frame
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    def f2_rho_and_j(frame):
        # attach densities plus longitudinal / transverse currents to the frame
        rho_qs_dict = dict()
        jz_qs_dict = dict()
        jper_qs_dict = dict()

        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            v = frame.velocities_by_type[atom_type]
            rho_qs, j_qs = calc_rho_j_q(x, v, q_points)
            # split the current into its component along q and the remainder
            jz_qs = np.sum(j_qs * q_directions, axis=1)
            jper_qs = j_qs - (jz_qs[:, None] * q_directions)

            rho_qs_dict[atom_type] = rho_qs
            jz_qs_dict[atom_type] = jz_qs
            jper_qs_dict[atom_type] = jper_qs

        frame.rho_qs_dict = rho_qs_dict
        frame.jz_qs_dict = jz_qs_dict
        frame.jper_qs_dict = jper_qs_dict
        return frame

    element_processor = f2_rho_and_j if calculate_currents else f2_rho

    # setup window iterator
    window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step,
                                     element_processor=element_processor)

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f'    {pair}')

    # setup all time averager instances
    F_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
    if calculate_currents:
        Cl_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
        Ct_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
    if calculate_incoherent:
        F_s_q_t_averager = {atom_type: TimeAverager(N_tc, n_qpoints)
                            for atom_type in traj.atom_types}

    # define correlation function
    def calc_corr(window, time_i):
        # Calculate correlations between two frames in the window without normalization 1/N
        f0 = window[0]
        fi = window[time_i]
        for s1, s2 in pairs:
            Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # cross pairs are stored once, so add the symmetric term
                Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate())
            F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt)

        if calculate_currents:
            for s1, s2 in pairs:
                Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate())
                # 0.5: presumably averages over the two directions perpendicular to q
                Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] *
                                            fi.jper_qs_dict[s2].conjugate(), axis=1))
                if s1 != s2:
                    Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate())
                    Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] *
                                                 fi.jper_qs_dict[s1].conjugate(), axis=1))

                Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt)
                Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt)

        if calculate_incoherent:
            for atom_type in traj.atom_types:
                xi = fi.positions_by_type[atom_type]
                x0 = f0.positions_by_type[atom_type]
                Fsqt = np.real(calc_rho_q(xi - x0, q_points))
                F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt)

    # run calculation
    with ThreadPoolExecutor() as tpe:
        # This is the "main loop" over the trajectory
        for window in window_iterator:
            logger.debug(f'processing window {window[0].frame_index} to {window[-1].frame_index}')

            # The map conveniently applies calc_corr to all time-lags. However,
            # as everything is done in place nothing gets returned, so in order
            # to start and wait for the tasks to finish (and to re-raise any
            # exception) we must iterate over the None values returned.
            for _ in tpe.map(partial(calc_corr, window), range(len(window))):
                pass

    # collect results into dict with numpy arrays (n_qpoints, N_tc)
    data_dict_corr = dict()
    time = delta_t * np.arange(N_tc, dtype=float)
    data_dict_corr['q_points'] = q_points
    data_dict_corr['time'] = time

    F_q_t_tot = np.zeros((n_qpoints, N_tc))
    S_q_w_tot = np.zeros((n_qpoints, N_tc))
    for pair in pairs:
        key = '_'.join(pair)
        F_q_t = 1 / traj.n_atoms * F_q_t_averager[pair].get_average_all()
        w, S_q_w = zip(*[fourier_cos(F, delta_t) for F in F_q_t])
        w = w[0]
        S_q_w = np.array(S_q_w)
        data_dict_corr['omega'] = w
        data_dict_corr[f'Fqt_coh_{key}'] = F_q_t
        data_dict_corr[f'Sqw_coh_{key}'] = S_q_w

        # sum all partials to the total
        F_q_t_tot += F_q_t
        S_q_w_tot += S_q_w
    data_dict_corr['Fqt_coh'] = F_q_t_tot
    data_dict_corr['Sqw_coh'] = S_q_w_tot

    if calculate_currents:
        Cl_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_t_tot = np.zeros((n_qpoints, N_tc))
        Cl_q_w_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_w_tot = np.zeros((n_qpoints, N_tc))
        for pair in pairs:
            key = '_'.join(pair)
            Cl_q_t = 1 / traj.n_atoms * Cl_q_t_averager[pair].get_average_all()
            Ct_q_t = 1 / traj.n_atoms * Ct_q_t_averager[pair].get_average_all()
            Cl_q_w = np.array([fourier_cos(C, delta_t)[1] for C in Cl_q_t])
            Ct_q_w = np.array([fourier_cos(C, delta_t)[1] for C in Ct_q_t])
            data_dict_corr[f'Clqt_{key}'] = Cl_q_t
            data_dict_corr[f'Ctqt_{key}'] = Ct_q_t
            data_dict_corr[f'Clqw_{key}'] = Cl_q_w
            data_dict_corr[f'Ctqw_{key}'] = Ct_q_w

            # sum all partials to the total
            Cl_q_t_tot += Cl_q_t
            Ct_q_t_tot += Ct_q_t
            Cl_q_w_tot += Cl_q_w
            Ct_q_w_tot += Ct_q_w
        data_dict_corr['Clqt'] = Cl_q_t_tot
        data_dict_corr['Ctqt'] = Ct_q_t_tot
        data_dict_corr['Clqw'] = Cl_q_w_tot
        data_dict_corr['Ctqw'] = Ct_q_w_tot

    if calculate_incoherent:
        Fs_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ss_q_w_tot = np.zeros((n_qpoints, N_tc))
        for atom_type in traj.atom_types:
            Fs_q_t = 1 / traj.n_atoms * F_s_q_t_averager[atom_type].get_average_all()
            Ss_q_w = np.array([fourier_cos(F, delta_t)[1] for F in Fs_q_t])
            data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t
            data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w

            # sum all partials to the total
            Fs_q_t_tot += Fs_q_t
            Ss_q_w_tot += Ss_q_w

        data_dict_corr['Fqt_incoh'] = Fs_q_t_tot
        data_dict_corr['Sqw_incoh'] = Ss_q_w_tot

        data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'] + data_dict_corr['Fqt_incoh']
        data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'] + data_dict_corr['Sqw_incoh']
    else:
        data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'].copy()
        data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'].copy()

    # finalize results with additional meta data
    result = DynamicSample(data_dict_corr, atom_types=traj.atom_types, pairs=pairs,
                           particle_counts=particle_counts, cell=traj.cell,
                           time_between_frames=delta_t,
                           maximum_time_lag=delta_t * window_size,
                           angular_frequency_resolution=dw,
                           maximum_angular_frequency=w_max,
                           number_of_frames=traj.number_of_frames_read)

    return result
def compute_static_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
) -> StaticSample:
    r"""Compute static structure factors. The results are returned in the
    form of a :class:`StaticSample <dynasor.sample.StaticSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates

    Returns
    -------
    StaticSample
        Holds ``q_points``, the partial structure factors ``Sq_{s1}_{s2}`` for every
        pair of atom types and their sum ``Sq``, each with shape ``(N_qpoints, 1)``.

    Raises
    ------
    ValueError
        If ``q_points`` is not a 2D array with three columns.
    """
    # sanity check input args
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    n_qpoints = q_points.shape[0]
    logger.info(f'Number of q-points: {n_qpoints}')

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f'    {pair}')

    # processing function: attach the per-type densities rho(q) to the frame
    def f2_rho(frame):
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    # setup averager (time average with only timelag=0)
    Sq_averager = {pair: TimeAverager(1, n_qpoints) for pair in pairs}

    # main loop
    for frame in traj:

        # process_frame
        f2_rho(frame)
        logger.debug(f'Processing frame {frame.frame_index}')

        for s1, s2 in pairs:
            # compute correlation
            Sq_pair = np.real(frame.rho_qs_dict[s1] * frame.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # cross pairs are stored once, so add the symmetric term
                Sq_pair += np.real(frame.rho_qs_dict[s2] * frame.rho_qs_dict[s1].conjugate())
            Sq_averager[(s1, s2)].add_sample(0, Sq_pair)

    # collect results
    data_dict = dict()
    data_dict['q_points'] = q_points
    S_q_tot = np.zeros((n_qpoints, 1))
    for s1, s2 in pairs:
        Sq = 1 / traj.n_atoms * Sq_averager[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1)
        data_dict[f'Sq_{s1}_{s2}'] = Sq
        S_q_tot += Sq
    data_dict['Sq'] = S_q_tot

    # finalize results
    result = StaticSample(data_dict, atom_types=traj.atom_types, pairs=pairs,
                          particle_counts=particle_counts, cell=traj.cell,
                          number_of_frames=traj.number_of_frames_read)
    return result
def compute_spectral_energy_density(
    traj: Trajectory,
    ideal_supercell: Atoms,
    primitive_cell: Atoms,
    q_points: NDArray[float],
    dt: float,
) -> Tuple[NDArray[float], NDArray[float]]:
    r"""
    Compute the spectral energy density (SED) at specific q-points.
    The results are returned in the form of a tuple, which comprises the
    angular frequencies in an array of length ``N_times`` in units of 2π/fs (rad/fs)
    and the SED in units of Da*(Å/fs)² as an array of shape ``(N_qpoints, N_times)``.

    The angular frequencies correspond to the discrete Fourier transform bins,
    i.e. ``w[k] = 2π k / (N_times * dt * frame_step)`` for ``k = 0, ..., N_times - 1``.

    More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010),
    which should be cited when using this function along with the dynasor reference.

    **Note 1:**
    SED analysis is only suitable for crystalline materials without diffusion as
    atoms are assumed to move around fixed reference positions throughout the entire trajectory.

    **Note 2:**
    This implementation reads the full trajectory and can thus consume a lot of memory.

    Parameters
    ----------
    traj
        Input trajectory
    ideal_supercell
        Ideal structure defining the reference positions
    primitive_cell
        Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`.
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots in
        the trajectory. Note that you should not change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.

    Raises
    ------
    ValueError
        If the number of atoms in ``ideal_supercell`` differs from the trajectory,
        if a frame lacks velocities, or if the trajectory contains no frames.
    """

    delta_t = traj.frame_step * dt

    # logger
    logger.info('Running SED')
    logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs')

    # check that the ideal supercell agrees with traj
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    # collect all velocities
    velocities = []
    for it, frame in enumerate(traj):
        logger.debug(f'Reading frame {it}')
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {it}')
        v = frame.get_velocities_as_array(traj.atomic_indices)
        velocities.append(v)
    if not velocities:
        raise ValueError('Could not read any frames from the trajectory.')

    # (n_frames, n_atoms, 3) -> (n_atoms, 3, n_frames), then FFT along time;
    # fft returns a new array, so no explicit copy of the transposed view is needed
    velocities = np.array(velocities)
    velocities = np.fft.fft(velocities.transpose(1, 2, 0), axis=2)

    # calculate SED
    masses = primitive_cell.get_masses()
    indices, offsets = get_index_offset(ideal_supercell, primitive_cell)
    indices = np.asarray(indices)

    # phase factor exp(i q . R) for every q-point and supercell atom: (N_qpoints, n_atoms)
    pos = np.dot(q_points, np.dot(offsets, primitive_cell.cell).T)
    exppos = np.exp(1.0j * pos)
    n_times = velocities.shape[2]
    density = np.zeros((len(q_points), n_times))
    for alpha in range(3):
        for b, mass in enumerate(masses):
            in_basis = indices == b
            # sum over all atoms on basis site b of exp(i q . R_i) * v_{i, alpha}(w);
            # equivalent to accumulating np.outer(exppos[:, i], velocities[i, alpha])
            tmp = exppos[:, in_basis] @ velocities[in_basis, alpha]
            density += mass * np.abs(tmp)**2

    # angular frequencies of the FFT bins, 2*pi*k / (n_times * delta_t);
    # the endpoint 2*pi/delta_t itself is not a bin (it aliases to k = 0)
    w = np.linspace(0.0, 2 * np.pi / delta_t, n_times, endpoint=False)  # units of 2pi/fs

    return w, density