Coverage for dynasor / post_processing / atomic_weighting.py: 99%

75 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 12:31 +0000

1from copy import deepcopy 

2from typing import Optional 

3from warnings import warn 

4import numpy as np 

5from dynasor.post_processing.weights import Weights 

6from dynasor.sample import Sample, StaticSample, DynamicSample 

7from numpy.typing import NDArray 

8 

9 

10def get_weighted_sample(sample: Sample, 

11 weights: Weights, 

12 atom_type_map: Optional[dict[str, str]] = None) -> Sample: 

13 r""" 

14 Weights correlation functions with atomic weighting factors 

15 

16 The weighting of a partial dynamic structure factor 

17 :math:`S_\mathrm{AB}(\boldsymbol{q}, \omega)` 

18 for atom types :math:`A` and :math:`B` is carried out as 

19 

20 .. math:: 

21 

22 S_\mathrm{AB}(\boldsymbol{q}, \omega) 

23 = f_\mathrm{A}(\boldsymbol{q}) f_\mathrm{B}(\boldsymbol{q}) 

24 S_\mathrm{AB}(\boldsymbol{q}, \omega) 

25 

26 :math:`f_\mathrm{A}(\boldsymbol{q})` and :math:`f_\mathrm{B}(\boldsymbol{q})` 

27 are atom-type and :math:`\boldsymbol{q}`-point dependent weights. 

28 

29 If sample has incoherent correlation functions, but :attr:`weights` does not contain 

30 information on how to weight the incoherent part, then it will be dropped from the 

31 returned :attr:`Sample` object (and analogously for current correlation functions). 

32 

33 Parameters 

34 ---------- 

35 sample 

36 Input sample to be weighted. 

37 weights 

38 Object containing the weights :math:`f_\mathrm{X}(\boldsymbol{q})`. 

39 atom_type_map 

40 Map between the atom types in the :class:`Sample` and the ones used in 

41 the :class:`Weights` object, e.g., Ba → Ba(2+). 

42 

43 Returns 

44 ------- 

45 A :class:`Sample` instance with the weighted partial and total structure factors. 

46 """ 

47 

48 # check input arguments 

49 if sample.has_incoherent and not weights.supports_incoherent: 

50 warn('The Weights class does not support incoherent scattering, dropping the latter ' 

51 'from the weighted sample.') 

52 

53 if sample.has_currents and not weights.supports_currents: 

54 warn('The Weights class does not support current correlations, dropping the latter ' 

55 'from the weighted sample.') 

56 

57 # setup new input dicts for new Sample 

58 data_dict = dict() 

59 for key in sample.dimensions: 

60 data_dict[key] = sample[key] 

61 

62 # Map the atom types in the sample to the types in the weights object. 

63 # Useful, for instance, when using weights for charged atomic species. 

64 atom_types = [(at, at) for at in sample.atom_types] 

65 if atom_type_map is not None: 

66 # Fallback to use the atom type from species (`at`) if it's not mapped. 

67 atom_types = [(at, atom_type_map.get(at, at)) for at in atom_type_map] 

68 

69 # generate atomic weights for each q-point and compile to arrays 

70 if 'q_norms' in sample.dimensions: 

71 q_norms = sample.q_norms 

72 else: 

73 q_norms = np.linalg.norm(sample.q_points, axis=1) 

74 

75 weights_coh = dict() 

76 for at, weight_at in atom_types: 

77 weight_array = np.reshape([weights.get_weight_coh(weight_at, q) for q in q_norms], (-1, 1)) 

78 weights_coh[at] = weight_array 

79 if sample.has_incoherent and weights.supports_incoherent: 

80 weights_incoh = dict() 

81 for at, weight_at in atom_types: 

82 weight_array = np.reshape([ 

83 weights.get_weight_incoh(weight_at, q) for q in q_norms 

84 ], (-1, 1)) 

85 weights_incoh[at] = weight_array 

86 

87 # weighting of correlation functions 

88 if isinstance(sample, StaticSample): 

89 data_dict_Sq = _compute_weighting_coherent(sample, 'Sq', weights_coh) 

90 data_dict.update(data_dict_Sq) 

91 elif isinstance(sample, DynamicSample): 91 ↛ 117line 91 didn't jump to line 117 because the condition on line 91 was always true

92 # coherent 

93 Fqt_coh_dict = _compute_weighting_coherent(sample, 'Fqt_coh', weights_coh) 

94 data_dict.update(Fqt_coh_dict) 

95 Sqw_coh_dict = _compute_weighting_coherent(sample, 'Sqw_coh', weights_coh) 

96 data_dict.update(Sqw_coh_dict) 

97 

98 # incoherent 

99 if sample.has_incoherent and weights.supports_incoherent: 

100 Fqt_incoh_dict = _compute_weighting_incoherent(sample, 'Fqt_incoh', weights_incoh) 

101 data_dict.update(Fqt_incoh_dict) 

102 Sqw_incoh_dict = _compute_weighting_incoherent(sample, 'Sqw_incoh', weights_incoh) 

103 data_dict.update(Sqw_incoh_dict) 

104 

105 # currents 

106 if sample.has_currents and weights.supports_currents: 

107 Clqt_dict = _compute_weighting_coherent(sample, 'Clqt', weights_coh) 

108 data_dict.update(Clqt_dict) 

109 Clqw_dict = _compute_weighting_coherent(sample, 'Clqw', weights_coh) 

110 data_dict.update(Clqw_dict) 

111 

112 Ctqt_dict = _compute_weighting_coherent(sample, 'Ctqt', weights_coh) 

113 data_dict.update(Ctqt_dict) 

114 Ctqw_dict = _compute_weighting_coherent(sample, 'Ctqw', weights_coh) 

115 data_dict.update(Ctqw_dict) 

116 

117 new_sample = sample.__class__( 

118 data_dict, 

119 simulation_data=deepcopy(sample.simulation_data), 

120 history=deepcopy(sample.history)) 

121 new_sample._append_history( 

122 'get_weighted_sample', 

123 dict( 

124 atom_type_map=atom_type_map, 

125 weights_class=weights.__class__.__name__, 

126 weights_parameters=weights.parameters.to_dict(), 

127 )) 

128 

129 return new_sample 

130 

131 

132def _compute_weighting_coherent( 

133 sample: Sample, 

134 name: str, 

135 weight_dict: dict, 

136) -> dict[str, NDArray[float]]: 

137 """ 

138 Helper function for weighting and summing partial coherent correlation functions. 

139 """ 

140 data_dict = dict() 

141 total = np.zeros(sample[name].shape) 

142 for s1, s2 in sample.pairs: 

143 key_pair = f'{name}_{s1}_{s2}' 

144 partial = np.real(np.conjugate(weight_dict[s1]) * weight_dict[s2]) * sample[key_pair] 

145 data_dict[key_pair] = partial 

146 total += partial 

147 data_dict[name] = total 

148 return data_dict 

149 

150 

151def _compute_weighting_incoherent( 

152 sample: Sample, 

153 name: str, 

154 weight_dict: dict, 

155) -> dict[str, NDArray[float]]: 

156 """ 

157 Helper function for weighting and summing partial incoherent correlation functions. 

158 """ 

159 data_dict = dict() 

160 total = np.zeros(sample[name].shape) 

161 for s1 in sample.atom_types: 

162 key = f'{name}_{s1}' 

163 partial = np.real(np.conjugate(weight_dict[s1]) * weight_dict[s1]) * sample[key] 

164 data_dict[key] = partial 

165 total += partial 

166 data_dict[name] = total 

167 return data_dict