Coverage for local_installation/dynasor/sample.py: 92%

123 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-12-21 12:02 +0000

1import numpy as np 

2import pandas as pd 

3from typing import Dict, Any 

4 

5 

class Sample:
    """
    Class for holding correlation functions and some additional meta data.
    Sample objects can be written to and read from file.

    Every entry of ``data_dict`` is set as an attribute on the object and can
    also be retrieved with ``sample['key']``.

    Parameters
    ----------
    data_dict
        Dictionary with correlation functions.
    meta_data
        Dictionary with meta data, for example atom-types, simulation cell, number of atoms,
        time stamps, user names, etc. The keys ``atom_types``, ``pairs``,
        ``particle_counts`` and ``cell`` must be present (a missing one raises ``KeyError``).
    """

    def __init__(self, data_dict: Dict[str, Any], **meta_data: Any):

        # set data dict as attributes; remember the keys so they can be listed and saved later
        self._data_keys = list(data_dict)
        for key, value in data_dict.items():
            setattr(self, key, value)

        # set system parameters
        self.meta_data = meta_data
        self._atom_types = meta_data['atom_types']
        self._pairs = meta_data['pairs']
        self._particle_counts = meta_data['particle_counts']
        self._cell = meta_data['cell']

    def __getitem__(self, key):
        """ Makes it possible to get the attributes using Sample['key'] """
        try:
            return getattr(self, key)
        except AttributeError:
            # the AttributeError is an implementation detail, do not chain it
            raise KeyError(key) from None

    def write_to_npz(self, fname: str):
        """ Write object to file in numpy's npz format.

        The archive holds three entries: ``name`` (the class name),
        ``meta_data`` and ``data_dict``.

        Parameters
        ----------
        fname
            Name of the file in which to store the Sample object.
        """
        data_dict = {key: getattr(self, key) for key in self._data_keys}
        np.savez_compressed(
            fname,
            name=self.__class__.__name__,
            meta_data=self.meta_data,
            data_dict=data_dict,
        )

    @property
    def available_correlation_functions(self):
        """ All the available correlation functions in sample. """
        keys_to_skip = {'q_points', 'q_norms', 'time', 'omega'}
        return sorted(set(self._data_keys) - keys_to_skip)

    @property
    def dimensions(self):
        r"""The dimensions for the samples, e.g., for :math:`S(q, \omega)`
        the dimensions would be the :math:`q` and :math:`\omega` axes.
        """
        keys_to_skip = set(self.available_correlation_functions)
        return sorted(set(self._data_keys) - keys_to_skip)

    @property
    def atom_types(self):
        """ Copy of the ``atom_types`` meta data entry (or None if it is None). """
        return None if self._atom_types is None else self._atom_types.copy()

    @property
    def particle_counts(self):
        """ Copy of the ``particle_counts`` meta data entry (or None if it is None). """
        return None if self._particle_counts is None else self._particle_counts.copy()

    @property
    def pairs(self):
        """ Copy of the ``pairs`` meta data entry (or None if it is None). """
        return None if self._pairs is None else self._pairs.copy()

    @property
    def cell(self):
        """ Copy of the ``cell`` meta data entry (or None if it is None). """
        return None if self._cell is None else self._cell.copy()

    def __repr__(self):
        return str(self)

    def __str__(self):
        s_contents = [
            self.__class__.__name__,
            f'Atom types: {self.atom_types}',
            f'Pairs: {self.pairs}',
            f'Particle counts: {self.particle_counts}',
            'Simulations cell:',
            f'{self.cell}',
        ]
        # dimensions are listed first, followed by the correlation functions
        for key in self.dimensions + self.available_correlation_functions:
            s_contents.append(f'{key:15} with shape: {np.shape(getattr(self, key))}')
        return '\n'.join(s_contents)

    def _repr_html_(self) -> str:
        """ HTML table summarizing the sample (meta data and shape of each data entry). """

        def row(field, content):
            # one table row with a left-aligned field name
            return f'<tr><td style="text-align: left">{field}</td><td>{content}</td></tr>'

        s = [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left">Field</th>'
              '<th>Content/Size</th></tr></thead>']
        s += ['<tbody>']
        s += [row('Atom types', self.atom_types)]
        s += [row('Pairs', self.pairs)]
        s += [row('Particle counts', self.particle_counts)]
        s += [row('Simulations cell', self.cell)]
        s += [row(key, np.shape(getattr(self, key))) for key in self._data_keys]
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)

    @property
    def has_incoherent(self):
        """ Whether this sample contains the incoherent correlation functions or not. """
        return False

    @property
    def has_currents(self):
        """ Whether this sample contains the current correlation functions or not. """
        return False

144 

145 

class StaticSample(Sample):
    """ Sample whose data can be tabulated in a single dataframe. """

    def to_dataframe(self):
        """ Returns correlation functions as pandas dataframe

        The dataframe gets one column per dimension followed by one column
        per correlation function, each correlation function being flattened
        to one dimension.
        """
        frame = pd.DataFrame()

        # nested lists allow an (N, 3) array of q-points to occupy a single column
        for axis_name in self.dimensions:
            axis_values = self[axis_name]
            frame[axis_name] = axis_values.tolist()

        for name in self.available_correlation_functions:
            flattened = self[name].reshape(-1)
            frame[name] = flattened

        return frame

156 

157 

class DynamicSample(Sample):
    """ Sample whose correlation functions are indexed by q-point along the first axis. """

    @property
    def has_incoherent(self):
        """ Whether this sample contains the incoherent correlation functions or not. """
        return 'Fqt_incoh' in self.available_correlation_functions

    @property
    def has_currents(self):
        """ Whether this sample contains the current correlation functions or not.

        The check looks for the ``Clqt_<type1>_<type2>`` entry of the first pair.
        """
        pairs = self.pairs
        # pairs may be None (or empty); without a pair there is no name to look for
        if pairs is None or len(pairs) == 0:
            return False
        pair_string = '_'.join(pairs[0])
        return f'Clqt_{pair_string}' in self.available_correlation_functions

    def to_dataframe(self, q_index: int):
        """ Returns correlation functions as pandas dataframe for the given q-index.

        The q-point dimensions (``q_points`` and ``q_norms``) are left out,
        all other dimensions become columns.

        Parameters
        ----------
        q_index
            index of q-point to return
        """
        df = pd.DataFrame()
        for dim in self.dimensions:
            # the q-point is fixed by q_index, so the q-axes are not columns
            if dim in ('q_points', 'q_norms'):
                continue
            df[dim] = self[dim]
        for key in self.available_correlation_functions:
            df[key] = self[key][q_index]
        return df

185 

186 

def read_sample_from_npz(fname: str) -> Sample:
    """ Read :class:`Sample <dynasor.sample.Sample>` from file.

    The class of the returned object is selected by the ``name`` entry in the
    file; names other than ``StaticSample`` and ``DynamicSample`` yield a plain
    :class:`Sample <dynasor.sample.Sample>`.

    Parameters
    ----------
    fname
        Path to the file (numpy npz format) from which to read
        the :class:`Sample <dynasor.sample.Sample>` object.
    """
    # NOTE: allow_pickle=True unpickles the stored dictionaries, which can run
    # arbitrary code; only read files from trusted sources.
    # The context manager closes the underlying file once everything is read.
    with np.load(fname, allow_pickle=True) as data_read:
        data_dict = data_read['data_dict'].item()
        meta_data = data_read['meta_data'].item()
        name = data_read['name'].item()

    sample_classes = {'StaticSample': StaticSample, 'DynamicSample': DynamicSample}
    sample_class = sample_classes.get(name, Sample)
    return sample_class(data_dict, **meta_data)