Coverage for local_installation/dynasor/correlation_functions.py: 100%

252 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-30 21:04 +0000

1import concurrent 

2from functools import partial 

3from itertools import combinations_with_replacement 

4from typing import Tuple 

5 

6import numpy as np 

7from ase import Atoms 

8from numpy.typing import NDArray 

9 

10from dynasor.logging_tools import logger 

11from dynasor.trajectory import Trajectory, WindowIterator 

12from dynasor.sample import DynamicSample, StaticSample 

13from dynasor.post_processing import fourier_cos 

14from dynasor.core.time_averager import TimeAverager 

15from dynasor.core.reciprocal import calc_rho_q, calc_rho_j_q 

16from dynasor.qpoints.tools import get_index_offset 

17from dynasor.units import radians_per_fs_to_meV 

18 

19 

def compute_dynamic_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
    dt: float,
    window_size: int,
    window_step: int = 1,
    calculate_currents: bool = False,
    calculate_incoherent: bool = False,
) -> DynamicSample:
    """Compute dynamic structure factors. The results are returned in the
    form of a :class:`DynamicSample <dynasor.sample.DynamicSample>`
    object.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots
        in the trajectory. Note that you should *not* change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.
    window_size
        Length of the trajectory frame window to use for time correlation calculation.
        It is expressed in terms of the number of time lags to consider
        and thus determines the smallest frequency resolved. Must be an even
        integer larger than 2.
    window_step
        Window step (or stride) given as the number of frames between consecutive trajectory
        windows. This parameter does *not* affect the time between consecutive frames in the
        calculation. If, e.g., :attr:`window_step` > :attr:`window_size`, some frames will not
        be used.
    calculate_currents
        Calculate the longitudinal and transverse current correlations.
        Requires velocities to be available in :attr:`traj`.
    calculate_incoherent
        Calculate the incoherent part (self-part) of the intermediate scattering
        function and of the corresponding dynamic structure factor.

    Raises
    ------
    ValueError
        If :attr:`q_points` is not a two-dimensional array with three columns,
        if :attr:`dt`, :attr:`window_size` or :attr:`window_step` are invalid, or
        if :attr:`calculate_currents` is requested but a frame carries no velocities.
    """
    # ``import concurrent`` alone does not load the ``concurrent.futures``
    # submodule, so import the executor explicitly here.
    from concurrent.futures import ThreadPoolExecutor

    # sanity check input args
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')
    if window_size <= 2:
        raise ValueError(f'window_size must be larger than 2: window_size= {window_size}')
    if window_size % 2 != 0:
        raise ValueError(f'window_size must be even: window_size= {window_size}')
    if window_step <= 0:
        raise ValueError(f'window_step must be positive: window_step= {window_step}')

    # define internal parameters
    n_qpoints = q_points.shape[0]
    delta_t = traj.frame_step * dt  # time between the frames actually used
    N_tc = window_size + 1  # number of time lags, including lag zero

    # frequency grid implied by the time window
    dw = 2 * np.pi / (window_size * delta_t)
    w_max = dw * window_size
    w_N = 2 * np.pi / (2 * delta_t)  # Nyquist angular frequency

    # log all setup information
    logger.info(f'Spacing between samples (frame_step): {traj.frame_step}')
    logger.info(f'Time between consecutive frames in input trajectory (dt): {dt} fs')
    logger.info(f'Time between consecutive frames used (dt * frame_step): {delta_t} fs')
    logger.info(f'Time window size (dt * frame_step * window_size): {delta_t * window_size:.1f} fs')
    logger.info(f'Angular frequency resolution: dw = {dw:.6f} rad/fs = '
                f'{dw * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Maximum angular frequency (dw * window_size):'
                f' {w_max:.6f} rad/fs = {w_max * radians_per_fs_to_meV:.3f} meV')
    logger.info(f'Nyquist angular frequency (2pi / frame_step / dt / 2):'
                f' {w_N:.6f} rad/fs = {w_N * radians_per_fs_to_meV:.3f} meV')

    if calculate_currents:
        logger.info('Calculating current (velocity) correlations')
    if calculate_incoherent:
        logger.info('Calculating incoherent part (self-part) of correlations')

    # log some info regarding q-points
    logger.info(f'Number of q-points: {n_qpoints}')

    # Unit vectors along each q-point (q = 0 stays the zero vector). Work on a
    # float copy so that integer-valued q-point arrays can be normalized in place.
    q_directions = q_points.astype(float)
    q_distances = np.linalg.norm(q_points, axis=1)
    nonzero = q_distances > 0
    q_directions[nonzero] /= q_distances[nonzero].reshape(-1, 1)

    # setup functions to process frames
    def f2_rho(frame):
        """Attach the per-type densities rho(q) to ``frame``."""
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    def f2_rho_and_j(frame):
        """Attach per-type densities and longitudinal/transverse currents to ``frame``."""
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {frame.frame_index}')
        rho_qs_dict = dict()
        jz_qs_dict = dict()
        jper_qs_dict = dict()

        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            v = frame.velocities_by_type[atom_type]
            rho_qs, j_qs = calc_rho_j_q(x, v, q_points)
            # split the current into its component along q and the remainder
            jz_qs = np.sum(j_qs * q_directions, axis=1)
            jper_qs = j_qs - (jz_qs[:, None] * q_directions)

            rho_qs_dict[atom_type] = rho_qs
            jz_qs_dict[atom_type] = jz_qs
            jper_qs_dict[atom_type] = jper_qs

        frame.rho_qs_dict = rho_qs_dict
        frame.jz_qs_dict = jz_qs_dict
        frame.jper_qs_dict = jper_qs_dict
        return frame

    if calculate_currents:
        element_processor = f2_rho_and_j
    else:
        element_processor = f2_rho

    # setup window iterator
    window_iterator = WindowIterator(traj, width=N_tc, window_step=window_step,
                                     element_processor=element_processor)

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # setup all time averager instances
    F_q_t_averager = dict()
    for pair in pairs:
        F_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_currents:
        Cl_q_t_averager = dict()
        Ct_q_t_averager = dict()
        for pair in pairs:
            Cl_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
            Ct_q_t_averager[pair] = TimeAverager(N_tc, n_qpoints)
    if calculate_incoherent:
        F_s_q_t_averager = dict()
        for atom_type in traj.atom_types:
            F_s_q_t_averager[atom_type] = TimeAverager(N_tc, n_qpoints)

    # define correlation function
    def calc_corr(window, time_i):
        """Add the correlations between window[0] and window[time_i] (without 1/N)."""
        f0 = window[0]
        fi = window[time_i]
        for s1, s2 in pairs:
            Fqt = np.real(f0.rho_qs_dict[s1] * fi.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                # cross pairs are stored once, so include both orderings
                Fqt += np.real(f0.rho_qs_dict[s2] * fi.rho_qs_dict[s1].conjugate())
            F_q_t_averager[(s1, s2)].add_sample(time_i, Fqt)

        if calculate_currents:
            for s1, s2 in pairs:
                Clqt = np.real(f0.jz_qs_dict[s1] * fi.jz_qs_dict[s2].conjugate())
                # the transverse part sums over the components perpendicular to q;
                # the factor 0.5 presumably normalizes to one transverse direction
                Ctqt = 0.5 * np.real(np.sum(f0.jper_qs_dict[s1] *
                                            fi.jper_qs_dict[s2].conjugate(), axis=1))
                if s1 != s2:
                    Clqt += np.real(f0.jz_qs_dict[s2] * fi.jz_qs_dict[s1].conjugate())
                    Ctqt += 0.5 * np.real(np.sum(f0.jper_qs_dict[s2] *
                                                 fi.jper_qs_dict[s1].conjugate(), axis=1))

                Cl_q_t_averager[(s1, s2)].add_sample(time_i, Clqt)
                Ct_q_t_averager[(s1, s2)].add_sample(time_i, Ctqt)

        if calculate_incoherent:
            for atom_type in traj.atom_types:
                xi = fi.positions_by_type[atom_type]
                x0 = f0.positions_by_type[atom_type]
                Fsqt = np.real(calc_rho_q(xi - x0, q_points))
                F_s_q_t_averager[atom_type].add_sample(time_i, Fsqt)

    # run calculation
    # NOTE(review): calc_corr runs concurrently for different time lags of the same
    # window; this relies on TimeAverager.add_sample being safe for distinct time
    # lags called from several threads — confirm.
    with ThreadPoolExecutor() as tpe:
        # This is the "main loop" over the trajectory
        for window in window_iterator:
            logger.debug(f'processing window {window[0].frame_index} to {window[-1].frame_index}')

            # map applies calc_corr to all time lags; results are accumulated in
            # place, so the returned Nones are only consumed to wait for completion
            for _ in tpe.map(partial(calc_corr, window), range(len(window))):
                pass

    # collect results into dict with numpy arrays (n_qpoints, N_tc)
    data_dict_corr = dict()
    time = delta_t * np.arange(N_tc, dtype=float)
    data_dict_corr['q_points'] = q_points
    data_dict_corr['time'] = time

    F_q_t_tot = np.zeros((n_qpoints, N_tc))
    S_q_w_tot = np.zeros((n_qpoints, N_tc))
    for pair in pairs:
        key = '_'.join(pair)
        F_q_t = 1 / traj.n_atoms * F_q_t_averager[pair].get_average_all()
        w, S_q_w = zip(*[fourier_cos(F, delta_t) for F in F_q_t])
        w = w[0]
        S_q_w = np.array(S_q_w)
        data_dict_corr['omega'] = w
        data_dict_corr[f'Fqt_coh_{key}'] = F_q_t
        data_dict_corr[f'Sqw_coh_{key}'] = S_q_w

        # sum all partials to the total
        F_q_t_tot += F_q_t
        S_q_w_tot += S_q_w
    data_dict_corr['Fqt_coh'] = F_q_t_tot
    data_dict_corr['Sqw_coh'] = S_q_w_tot

    if calculate_currents:
        Cl_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_t_tot = np.zeros((n_qpoints, N_tc))
        Cl_q_w_tot = np.zeros((n_qpoints, N_tc))
        Ct_q_w_tot = np.zeros((n_qpoints, N_tc))
        for pair in pairs:
            key = '_'.join(pair)
            Cl_q_t = 1 / traj.n_atoms * Cl_q_t_averager[pair].get_average_all()
            Ct_q_t = 1 / traj.n_atoms * Ct_q_t_averager[pair].get_average_all()
            Cl_q_w = np.array([fourier_cos(C, delta_t)[1] for C in Cl_q_t])
            Ct_q_w = np.array([fourier_cos(C, delta_t)[1] for C in Ct_q_t])
            data_dict_corr[f'Clqt_{key}'] = Cl_q_t
            data_dict_corr[f'Ctqt_{key}'] = Ct_q_t
            data_dict_corr[f'Clqw_{key}'] = Cl_q_w
            data_dict_corr[f'Ctqw_{key}'] = Ct_q_w

            # sum all partials to the total
            Cl_q_t_tot += Cl_q_t
            Ct_q_t_tot += Ct_q_t
            Cl_q_w_tot += Cl_q_w
            Ct_q_w_tot += Ct_q_w
        data_dict_corr['Clqt'] = Cl_q_t_tot
        data_dict_corr['Ctqt'] = Ct_q_t_tot
        data_dict_corr['Clqw'] = Cl_q_w_tot
        data_dict_corr['Ctqw'] = Ct_q_w_tot

    if calculate_incoherent:
        Fs_q_t_tot = np.zeros((n_qpoints, N_tc))
        Ss_q_w_tot = np.zeros((n_qpoints, N_tc))
        for atom_type in traj.atom_types:
            Fs_q_t = 1 / traj.n_atoms * F_s_q_t_averager[atom_type].get_average_all()
            Ss_q_w = np.array([fourier_cos(F, delta_t)[1] for F in Fs_q_t])
            data_dict_corr[f'Fqt_incoh_{atom_type}'] = Fs_q_t
            data_dict_corr[f'Sqw_incoh_{atom_type}'] = Ss_q_w

            # sum all partials to the total
            Fs_q_t_tot += Fs_q_t
            Ss_q_w_tot += Ss_q_w

        data_dict_corr['Fqt_incoh'] = Fs_q_t_tot
        data_dict_corr['Sqw_incoh'] = Ss_q_w_tot

        data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'] + data_dict_corr['Fqt_incoh']
        data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'] + data_dict_corr['Sqw_incoh']
    else:
        data_dict_corr['Fqt'] = data_dict_corr['Fqt_coh'].copy()
        data_dict_corr['Sqw'] = data_dict_corr['Sqw_coh'].copy()

    # finalize results with additional meta data
    result = DynamicSample(data_dict_corr, atom_types=traj.atom_types, pairs=pairs,
                           particle_counts=particle_counts, cell=traj.cell,
                           time_between_frames=delta_t,
                           maximum_time_lag=delta_t * window_size,
                           angular_frequency_resolution=dw,
                           maximum_angular_frequency=w_max,
                           number_of_frames=traj.number_of_frames_read)

    return result

290 

291 

def compute_static_structure_factors(
    traj: Trajectory,
    q_points: NDArray[float],
) -> StaticSample:
    r"""Compute static structure factors. The results are returned in the
    form of a :class:`StaticSample <dynasor.sample.StaticSample>`
    object.

    The partial structure factors are the equal-time correlations of the
    per-type densities :math:`\rho_s(\boldsymbol{q})`, averaged over all frames
    of the trajectory and normalized by the total number of atoms.

    Parameters
    ----------
    traj
        Input trajectory
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates

    Raises
    ------
    ValueError
        If :attr:`q_points` is not a two-dimensional array with three columns.
    """
    # sanity check input args (ndim first so 1-D input raises ValueError, not IndexError)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')

    n_qpoints = q_points.shape[0]
    logger.info(f'Number of q-points: {n_qpoints}')

    # define all pairs
    pairs = list(combinations_with_replacement(traj.atom_types, r=2))
    particle_counts = {key: len(val) for key, val in traj.atomic_indices.items()}
    logger.debug('Considering pairs:')
    for pair in pairs:
        logger.debug(f' {pair}')

    # processing function
    def f2_rho(frame):
        """Attach the per-type densities rho(q) to ``frame``."""
        rho_qs_dict = dict()
        for atom_type in frame.positions_by_type.keys():
            x = frame.positions_by_type[atom_type]
            rho_qs_dict[atom_type] = calc_rho_q(x, q_points)
        frame.rho_qs_dict = rho_qs_dict
        return frame

    # setup averager
    Sq_averager = dict()
    for pair in pairs:
        Sq_averager[pair] = TimeAverager(1, n_qpoints)  # time average with only timelag=0

    # main loop
    for frame in traj:

        # process_frame
        f2_rho(frame)
        logger.debug(f'Processing frame {frame.frame_index}')

        for s1, s2 in pairs:
            # compute correlation; cross pairs are stored once, so add both orderings
            Sq_pair = np.real(frame.rho_qs_dict[s1] * frame.rho_qs_dict[s2].conjugate())
            if s1 != s2:
                Sq_pair += np.real(frame.rho_qs_dict[s2] * frame.rho_qs_dict[s1].conjugate())
            Sq_averager[(s1, s2)].add_sample(0, Sq_pair)

    # collect results as column arrays of shape (n_qpoints, 1)
    data_dict = dict()
    data_dict['q_points'] = q_points
    S_q_tot = np.zeros((n_qpoints, 1))
    for s1, s2 in pairs:
        Sq = 1 / traj.n_atoms * Sq_averager[(s1, s2)].get_average_at_timelag(0).reshape(-1, 1)
        data_dict[f'Sq_{s1}_{s2}'] = Sq
        S_q_tot += Sq
    data_dict['Sq'] = S_q_tot

    # finalize results
    result = StaticSample(data_dict, atom_types=traj.atom_types, pairs=pairs,
                          particle_counts=particle_counts, cell=traj.cell,
                          number_of_frames=traj.number_of_frames_read)
    return result

364 

365 

def compute_spectral_energy_density(
    traj: Trajectory,
    ideal_supercell: Atoms,
    primitive_cell: Atoms,
    q_points: NDArray[float],
    dt: float,
) -> Tuple[NDArray[float], NDArray[float]]:
    r"""
    Compute the spectral energy density (SED) at specific q-points.
    The results are returned in the form of a tuple, which comprises the
    angular frequencies in an array of length ``N_times`` in units of rad/fs
    and the SED in units of Da*(Å/fs)² as an array of shape ``(N_qpoints, N_times)``.
    The angular frequencies follow the ordering of the discrete Fourier transform,
    i.e., ``w[k] = 2π k / (N_times * dt * frame_step)`` for ``k = 0, ..., N_times - 1``.

    More details can be found in Thomas *et al.*, Physical Review B **81**, 081411 (2010),
    which should be cited when using this function along with the dynasor reference.

    **Note 1:**
    SED analysis is only suitable for crystalline materials without diffusion as
    atoms are assumed to move around fixed reference positions throughout the entire trajectory.

    **Note 2:**
    This implementation reads the full trajectory and can thus consume a lot of memory.

    Parameters
    ----------
    traj
        Input trajectory
    ideal_supercell
        Ideal structure defining the reference positions
    primitive_cell
        Underlying primitive structure. Must be aligned correctly with :attr:`ideal_supercell`.
    q_points
        Array of q-points in units of 2π/Å with shape ``(N_qpoints, 3)`` in Cartesian coordinates
    dt
        Time difference in femtoseconds between two consecutive snapshots in
        the trajectory. Note that you should not change :attr:`dt` if you change
        :attr:`frame_step <dynasor.trajectory.Trajectory.frame_step>` in :attr:`traj`.

    Raises
    ------
    ValueError
        If :attr:`q_points` has the wrong shape, :attr:`dt` is not positive,
        :attr:`ideal_supercell` and :attr:`traj` differ in the number of atoms,
        a frame has no velocities, or the trajectory contains no frames.
    """
    # sanity check input args (consistent with the structure factor functions)
    if q_points.ndim != 2 or q_points.shape[1] != 3:
        raise ValueError('q-points array has the wrong shape.')
    if dt <= 0:
        raise ValueError(f'dt must be positive: dt= {dt}')

    delta_t = traj.frame_step * dt

    # logger
    logger.info('Running SED')
    logger.info(f'Time between consecutive frames (dt * frame_step): {delta_t} fs')

    # check that the ideal supercell agrees with traj
    if traj.n_atoms != len(ideal_supercell):
        raise ValueError('ideal_supercell must contain the same number of atoms as the trajectory.')

    # collect all velocities; each entry is indexed as [atom, cartesian component]
    velocities = []
    for it, frame in enumerate(traj):
        logger.debug(f'Reading frame {it}')
        if frame.velocities_by_type is None:
            raise ValueError(f'Could not read velocities from frame {it}')
        velocities.append(frame.get_velocities_as_array(traj.atomic_indices))
    if not velocities:
        raise ValueError('The trajectory does not contain any frames.')

    # (frame, atom, component) -> (atom, component, frame), then transform along time.
    # np.fft.fft returns a new array, so no explicit copy of the transposed view is needed.
    velocities = np.fft.fft(np.array(velocities).transpose(1, 2, 0), axis=2)
    n_times = velocities.shape[2]

    # calculate SED
    masses = primitive_cell.get_masses()
    indices, offsets = get_index_offset(ideal_supercell, primitive_cell)
    indices = np.asarray(indices)  # basis-atom index for every supercell atom

    # phase factors exp(i q·R) for every (q-point, supercell atom)
    pos = np.dot(q_points, np.dot(offsets, primitive_cell.cell).T)
    exppos = np.exp(1.0j * pos)
    density = np.zeros((len(q_points), n_times))
    for alpha in range(3):
        for b, mass in enumerate(masses):
            selected = indices == b
            # sum over all supercell atoms belonging to basis atom b of
            # exp(i q·R) * v_alpha(w); equivalent to accumulating outer products
            tmp = exppos[:, selected] @ velocities[selected, alpha, :]
            density += mass * np.abs(tmp)**2

    # angular frequencies of the DFT bins in rad/fs: w_k = 2π k / (n_times * delta_t)
    w = np.linspace(0.0, 2 * np.pi / delta_t, n_times, endpoint=False)

    return w, density