Coverage for local_installation/dynasor/correlation_functions.py: 100%

252 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-05 09:53 +0000

1import concurrent 

2from functools import partial 

3from itertools import combinations_with_replacement 

4from typing import Tuple 

5 

6import numpy as np 

7from ase import Atoms 

8from numpy.typing import NDArray 

9 

10from dynasor.logging_tools import logger 

11from dynasor.trajectory import Trajectory, WindowIterator 

12from dynasor.sample import DynamicSample, StaticSample 

13from dynasor.post_processing import fourier_cos_filon 

14from dynasor.core.time_averager import TimeAverager 

15from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q 

16from dynasor.qpoints.tools import get_index_offset 

17from dynasor.units import radians_per_fs_to_meV 

18 

19 

def compute_dynamic_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
    dt: float,
    window_size: int,
    window_step: int = 1,
    calculate_currents: bool = False,
    calculate_incoherent: bool = False,
) -> DynamicSample:
    """Compute dynamic structure factors. The results are returned in the
    form of a :class:`DynamicSample <dynasor.sample.DynamicSample>`
    object.

    For every window of ``window_size + 1`` frames the first frame is correlated
    with each frame of the window (one time lag per frame), the samples are
    averaged over all windows, normalized by the number of atoms and finally
    transformed to the frequency domain with :func:`fourier_cos_filon`.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots
        in the trajectory. Note that you should *not* change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    window_size
        Length of the trajectory frame window to use for time correlation calculation.
        It is expressed in terms of the number of time lags to consider
        and thus determines the smallest frequency resolved.
        Must be even and larger than 2.
    window_step
        Window step (or stride) given as the number of frames between consecutive trajectory
        windows. This parameter does *not* affect the time between consecutive frames in the
        calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not
        be used.
    calculate_currents
        Calculate the current correlations. Requires velocities to be available in :attr:`traj`.
    calculate_incoherent
        Calculate the incoherent part (self-part) of :math:`F_incoh`.

    Returns
    -------
    DynamicSample
        Sample built from a dictionary holding ``q_points``, ``time``, ``omega``,
        the partial and total coherent functions (``Fqt_coh_<A>_<B>``, ``Sqw_coh_<A>_<B>``,
        ``Fqt_coh``, ``Sqw_coh``), the totals ``Fqt`` and ``Sqw`` and, if requested,
        the current correlations (``Clqt*``, ``Ctqt*``, ``Clqw*``, ``Ctqw*``) and the
        incoherent functions (``Fqt_incoh*``, ``Sqw_incoh*``).

    Raises
    ------
    ValueError
        If :attr:`q_points` is not a two-dimensional array with three columns, if
        :attr:`dt` or :attr:`window_step` is not positive, or if :attr:`window_size`
        is odd or not larger than 2.
    """
    # A bare ``import concurrent`` does not load the ``futures`` submodule, so
    # import the executor explicitly instead of relying on some other module
    # having imported ``concurrent.futures`` already.
    from concurrent.futures import ThreadPoolExecutor

    # sanity check input args (the ndim test turns a 1-D array into a
    # ValueError instead of an IndexError from ``shape[1]``)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')
    if window_size <= 2:
        raise ValueError(f'window_size must be larger than 2: window_size= {window_size}')
    if window_size % 2 != 0:
        raise ValueError(f'window_size must be even: window_size= {window_size}')
    if window_step <= 0:
        raise ValueError(f'window_step must be positive: window_step= {window_step}')

    # define internal parameters
    n_qpoints = q_points.shape[0]
    delta_t = traj.frame_step * dt  # time between the frames actually used
    N_tc = window_size + 1  # number of time lags, including lag zero

    # log all setup information
    dw = np.pi / (window_size * delta_t)
    w_max = dw * window_size
    w_N = 2 * np.pi / (2 * delta_t)  # Nyquist angular frequency

    logger.info(f'Spacing between samples (frame_step): {traj.frame_step}')
    logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs')
    logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs')
    logger.info('Time window size (dt * frame_step * window_size): '
                f'{delta_t * window_size:.1f} fs')
    logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = '
                f'{dw * radians_per_fs_to_meV:.3f} meV')
    logger.info('Maximum angular frequency (dw * window_size):'
                f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV')
    logger.info('Nyquist angular frequency (2pi / frame_step / dt / 2):'
                f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV')

    if calculate_currents:
        logger.info('Calculating current (velocity) correlations')
    if calculate_incoherent:
        logger.info('Calculating incoherent part (self-part) of correlations')

    # log some info regarding q-points
    logger.info(f'Number of q-points: {n_qpoints}')

    # unit vectors along each q-point; rows with |q| = 0 are left untouched
    # to avoid a division by zero
    q_directions = q_points.copy()
    q_distances = np.linalg.norm(q_points, axis=1)
    nonzero = q_distances > 0
    q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1)

    # setup functions to process frames; the results are attached to the
    # frame so that they are available when the frame is correlated
    def f2_rho(frame):
        frame.rho_qs_dict = {
            atom_type: calc_rho_q(frame.positions_by_type[atom_type], q_points)
            for atom_type in frame.positions_by_type.keys()
        }
        return frame

    def f2_rho_and_j(frame):
        rho_qs_dict = dict()
        jz_qs_dict = dict()
        jper_qs_dict = dict()

        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            v = frame.velocities_by_type[atom_type]
            rho_qs, j_qs = calc_rho_j_q(x, v, q_points)
            # split the current into its component along q and the remainder
            jz_qs = np.sum(j_qs * q_directions, axis=1)
            jper_qs = j_qs - (jz_qs[:, None] * q_directions)

            rho_qs_dict[atom_type] = rho_qs
            jz_qs_dict[atom_type] = jz_qs
            jper_qs_dict[atom_type] = jper_qs

        frame.rho_qs_dict = rho_qs_dict
        frame.jz_qs_dict = jz_qs_dict
        frame.jper_qs_dict = jper_qs_dict
        return frame

    element_processor = f2_rho_and_j if calculate_currents else f2_rho

    # setup window iterator
    window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step,
                                     element_processor=element_processor)

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # setup all time averager instances
    F_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
    if calculate_currents:
        Cl_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
        Ct_q_t_averager = {pair: TimeAverager(N_tc, n_qpoints) for pair in pairs}
    if calculate_incoherent:
        F_s_q_t_averager = {atom_type: TimeAverager(N_tc, n_qpoints)
                            for atom_type in traj.atom_types}

    # define correlation function
    def calc_corr(window, time_i):
        # Calculate correlations between two frames in the window without normalization 1/N
        f0 = window[0]
        fi = window[time_i]
        for s1, s2 in pairs:
            Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # only (s1, s2) is stored for unlike pairs, so add the (s2, s1) term as well
                Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate())
            F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt)

        if calculate_currents:
            for s1, s2 in pairs:
                Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate())
                Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] *
                                            fi.jper_qs_dict[s2].conjugate(), axis=1))
                if s1 != s2:
                    Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate())
                    Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] *
                                                 fi.jper_qs_dict[s1].conjugate(), axis=1))

                Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt)
                Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt)

        if calculate_incoherent:
            for atom_type in traj.atom_types:
                xi = fi.positions_by_type[atom_type]
                x0 = f0.positions_by_type[atom_type]
                # self part: evaluated on the displacements between the two frames
                Fsqt = np.real(calc_rho_q(xi - x0, q_points))
                F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt)

    # run calculation
    with ThreadPoolExecutor() as tpe:
        # This is the "main loop" over the trajectory
        for window in window_iterator:
            logger.debug(f'processing window {window[0].frame_index} to {window[-1].frame_index}')

            # The map conveniently applies calc_corr to all time-lags. However,
            # as everything is done in place nothing gets returned so in order
            # to start and wait for the processes to finish we must iterate
            # over the None values returned
            for _ in tpe.map(partial(calc_corr, window), range(len(window))):
                pass

    # collect results into dict with numpy arrays (n_qpoints, N_tc)
    data_dict_corr = dict()
    time = delta_t * np.arange(N_tc, dtype=float)
    data_dict_corr['q_points'] = q_points
    data_dict_corr['time'] = time

    F_q_t_tot = np.zeros((n_qpoints, N_tc))
    S_q_w_tot = np.zeros((n_qpoints, N_tc))
    for pair in pairs:
        key = '_'.join(pair)
        F_q_t = 1 / traj.n_atoms * F_q_t_averager[pair].get_average_all()
        w, S_q_w = fourier_cos_filon(F_q_t, delta_t)
        S_q_w = np.array(S_q_w)
        data_dict_corr['omega'] = w
        data_dict_corr[f'Fqt_coh_{key}'] = F_q_t
        data_dict_corr[f'Sqw_coh_{key}'] = S_q_w

        # sum all partials to the total
        F_q_t_tot += F_q_t
        S_q_w_tot += S_q_w
    data_dict_corr['Fqt_coh'] = F_q_t_tot
    data_dict_corr['Sqw_coh'] = S_q_w_tot

    if calculate_currents:
        Cl_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_t_tot = np.zeros((n_qpoints, N_tc))
        Cl_q_w_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_w_tot = np.zeros((n_qpoints, N_tc))
        for pair in pairs:
            key = '_'.join(pair)
            Cl_q_t = 1 / traj.n_atoms * Cl_q_t_averager[pair].get_average_all()
            Ct_q_t = 1 / traj.n_atoms * Ct_q_t_averager[pair].get_average_all()
            _, Cl_q_w = fourier_cos_filon(Cl_q_t, delta_t)
            _, Ct_q_w = fourier_cos_filon(Ct_q_t, delta_t)
            data_dict_corr[f'Clqt_{key}'] = Cl_q_t
            data_dict_corr[f'Ctqt_{key}'] = Ct_q_t
            data_dict_corr[f'Clqw_{key}'] = Cl_q_w
            data_dict_corr[f'Ctqw_{key}'] = Ct_q_w

            # sum all partials to the total
            Cl_q_t_tot += Cl_q_t
            Ct_q_t_tot += Ct_q_t
            Cl_q_w_tot += Cl_q_w
            Ct_q_w_tot += Ct_q_w
        data_dict_corr['Clqt'] = Cl_q_t_tot
        data_dict_corr['Ctqt'] = Ct_q_t_tot
        data_dict_corr['Clqw'] = Cl_q_w_tot
        data_dict_corr['Ctqw'] = Ct_q_w_tot

    if calculate_incoherent:
        Fs_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ss_q_w_tot = np.zeros((n_qpoints, N_tc))
        for atom_type in traj.atom_types:
            Fs_q_t = 1 / traj.n_atoms * F_s_q_t_averager[atom_type].get_average_all()
            _, Ss_q_w = fourier_cos_filon(Fs_q_t, delta_t)
            data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t
            data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w

            # sum all partials to the total
            Fs_q_t_tot += Fs_q_t
            Ss_q_w_tot += Ss_q_w

        data_dict_corr['Fqt_incoh'] = Fs_q_t_tot
        data_dict_corr['Sqw_incoh'] = Ss_q_w_tot

        data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'] + data_dict_corr['Fqt_incoh']
        data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'] + data_dict_corr['Sqw_incoh']
    else:
        # copies, so that the totals do not alias the coherent arrays
        data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'].copy()
        data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'].copy()

    # finalize results with additional meta data
    result = DynamicSample(data_dict_corr, atom_types=traj.atom_types, pairs=pairs,
                           particle_counts=particle_counts, cell=traj.cell,
                           time_between_frames=delta_t,
                           maximum_time_lag=delta_t * window_size,
                           angular_frequency_resolution=dw,
                           maximum_angular_frequency=w_max,
                           number_of_frames=traj.number_of_frames_read)

    return result

289 

290 

def compute_static_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
) -> StaticSample:
    r"""Compute static structure factors. The results are returned in the
    form of a :class:`StaticSample <dynasor.sample.StaticSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates

    Returns
    -------
    StaticSample
        Sample built from a dictionary holding ``q_points``, the partial structure
        factors ``Sq_<A>_<B>`` for every pair of atom types and their sum ``Sq``,
        each with shape ``(N_qpoints, 1)``.

    Raises
    ------
    ValueError
        If :attr:`q_points` is not a two-dimensional array with three columns.
    """
    # sanity check input args (the ndim test turns a 1-D array into a
    # ValueError instead of an IndexError from ``shape[1]``)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    n_qpoints = q_points.shape[0]
    logger.info(f'Number of q-points: {n_qpoints}')

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # setup averager: one time averager per pair with only timelag=0
    Sq_averager = {pair: TimeAverager(1, n_qpoints) for pair in pairs}

    # main loop
    for frame in traj:

        # process frame: density per atom type, also kept on the frame
        rho_qs_dict = {
            atom_type: calc_rho_q(frame.positions_by_type[atom_type], q_points)
            for atom_type in frame.positions_by_type.keys()
        }
        frame.rho_qs_dict = rho_qs_dict
        logger.debug(f'Processing frame {frame.frame_index}')

        for s1, s2 in pairs:
            # compute correlation
            Sq_pair = np.real(rho_qs_dict[s1] * rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # only (s1, s2) is stored for unlike pairs, so add the (s2, s1) term as well
                Sq_pair += np.real(rho_qs_dict[s2] * rho_qs_dict[s1].conjugate())
            Sq_averager[(s1, s2)].add_sample(0, Sq_pair)

    # collect results, normalized by the number of atoms
    data_dict = dict()
    data_dict['q_points'] = q_points
    S_q_tot = np.zeros((n_qpoints, 1))
    for s1, s2 in pairs:
        Sq = 1 / traj.n_atoms * Sq_averager[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1)
        data_dict[f'Sq_{s1}_{s2}'] = Sq
        S_q_tot += Sq
    data_dict['Sq'] = S_q_tot

    # finalize results
    result = StaticSample(data_dict, atom_types=traj.atom_types, pairs=pairs,
                          particle_counts=particle_counts, cell=traj.cell,
                          number_of_frames=traj.number_of_frames_read)
    return result

363 

364 

def compute_spectral_energy_density(
    traj: Trajectory,
    ideal_supercell: Atoms,
    primitive_cell: Atoms,
    q_points: NDArray[float],
    dt: float,
) -> Tuple[NDArray[float], NDArray[float]]:
    r"""
    Compute the spectral energy density (SED) at specific q-points.
    The results are returned in the form of a tuple, which comprises the
    angular frequencies in rad/fs as an array of length ``N_samples // 2 + 1``,
    where ``N_samples`` is the number of frames read from the trajectory,
    and the SED in units of Da*(Å/fs)² as an array of shape
    ``(N_qpoints, N_samples // 2 + 1)``.

    More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010),
    which should be cited when using this function along with the dynasor reference.

    **Note 1:**
    SED analysis is only suitable for crystalline materials without diffusion as
    atoms are assumed to move around fixed reference positions throughout the entire trajectory.

    **Note 2:**
    This implementation reads the full trajectory and can thus consume a lot of memory.

    Parameters
    ----------
    traj
        Input trajectory
    ideal_supercell
        Ideal structure defining the reference positions
    primitive_cell
        Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`.
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots in
        the trajectory. Note that you should not change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.

    Returns
    -------
    Tuple[NDArray[float], NDArray[float]]
        Angular frequencies and spectral energy density.

    Raises
    ------
    ValueError
        If :attr:`q_points` does not have shape ``(N_qpoints, 3)``, if :attr:`dt` is not
        positive, if :attr:`ideal_supercell` and :attr:`traj` contain different numbers
        of atoms, if a frame lacks velocities or if the trajectory yields no frames.
    """
    # sanity check input args, consistent with the structure factor functions
    q_points = np.asarray(q_points)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')

    delta_t = traj.frame_step * dt

    # logger
    logger.info('Running SED')
    logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs')

    # check that the ideal supercell agrees with traj
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    # collect all velocities, and scale with sqrt(masses); the square root
    # does not depend on the frame and is therefore computed only once
    sqrt_masses = np.sqrt(ideal_supercell.get_masses().reshape(-1, 1))
    velocities = []
    for it, frame in enumerate(traj):
        logger.debug(f'Reading frame {it}')
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {it}')
        v = frame.get_velocities_as_array(traj.atomic_indices)
        velocities.append(sqrt_masses * v)

    N_samples = len(velocities)
    if N_samples == 0:
        raise ValueError('No frames could be read from the trajectory.')

    # reorder (frame, atom, component) -> (atom, component, frame) and
    # transform the time axis to the frequency domain
    velocities = np.array(velocities).transpose(1, 2, 0).copy()
    velocities = np.fft.rfft(velocities, axis=2)

    # calculate SED
    indices, offsets = get_index_offset(ideal_supercell, primitive_cell)
    indices = np.asarray(indices)

    # phase factor exp(i q.R) for every q-point (rows) and atom (columns)
    pos = np.dot(q_points, np.dot(offsets, primitive_cell.cell).T)
    exppos = np.exp(1.0j * pos)

    # group the supercell atoms by the primitive-cell atom they map to
    atoms_by_basis = [np.flatnonzero(indices == b) for b in range(len(primitive_cell))]

    density = np.zeros((len(q_points), velocities.shape[2]))
    for alpha in range(3):
        for members in atoms_by_basis:
            # sum_i exp(i q.R_i) * v_i(alpha, w) over the atoms of this basis
            # site, written as a matrix product instead of a sum of outer
            # products; an empty group contributes zeros
            tmp = exppos[:, members] @ velocities[members, alpha, :]
            density += np.abs(tmp)**2

    # angular frequencies
    w = 2 * np.pi * np.fft.rfftfreq(N_samples, delta_t)  # radians/fs

    return w, density