Coverage for pioupiou/__init__.py: 100%

218 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-05-25 19:58 +0000

1# Python 3 Standard Library 

2import abc 

3import builtins 

4import inspect 

5import numbers 

6import operator 

7 

8# Third-Party Libraries 

9import numpy as np 

10import numpy.random as npr 

11import scipy.special 

12import scipy.stats 

13 

14# ------------------------------------------------------------------------------ 

class Universe:
    """Sample space shared by the random variables of this module.

    ``n`` is the number of coordinates of a sample: each ``Uniform`` reserves
    one coordinate when it is created. ``rvs`` lists every random variable
    registered since the last (re)initialization. Calling the universe draws
    a sample ``omega`` made of independent uniform values, one per coordinate.
    """

    def __init__(self, seed=0):
        """(Re)initialize the universe.

        Random variables registered with a previous initialization are
        flagged as invalid; evaluating them afterwards raises
        ``InvalidRandomVariable``.

        Parameters
        ----------
        seed : int or None, optional
            Seed of the underlying ``numpy.random.SeedSequence``. The default
            (0) keeps sessions reproducible; ``None`` draws fresh entropy.
        """
        # On re-initialization (see `restart`), invalidate the variables that
        # were built against the previous universe.
        if hasattr(self, "rvs"):
            for rv in self.rvs:
                rv._valid = False
        self.rvs = []
        self.n = 0
        self.ss = np.random.SeedSequence(seed)
        self.rng = npr.default_rng(self.ss)

    def __call__(self, size=None):
        """Draw a sample ``omega`` from the universe.

        Parameters
        ----------
        size : int, sequence of ints or None, optional
            Number (or shape) of independent samples; ``None`` draws one.

        Returns
        -------
        numpy.ndarray
            Uniform values in [0, 1) of shape ``(n,)`` when ``size`` is None,
            ``(n, size)`` for an integer and ``(n, *size)`` for a shape. The
            first axis always indexes the universe coordinates.
        """
        if size is None:
            extra_shape = ()
        elif isinstance(size, numbers.Integral):
            # Accept NumPy integers (e.g. np.int64) as well as Python ints.
            extra_shape = (int(size),)
        else:
            # Any sequence of ints (tuple, list, ...) describes a shape.
            extra_shape = tuple(size)
        return self.rng.uniform(size=(self.n,) + extra_shape)

34 

35 

# The module-wide universe: every RandomVariable appends itself to
# `Omega.rvs` when created, and every Uniform reserves one of its coordinates.
Omega = Universe()

37 

38 

def restart():
    """Reset the global universe ``Omega`` in place.

    Random variables created before the call are invalidated, the coordinate
    counter goes back to zero and the random generator is reseeded with the
    default seed.
    """
    Omega.__init__()

41 

42 

43# ------------------------------------------------------------------------------ 

class Error(Exception):
    """Base class of the exceptions raised by this module."""


class InvalidRandomVariable(Error):
    """Raised when evaluating a random variable from a reset universe.

    Variables are invalidated when the global universe is re-initialized
    (see ``restart``).
    """


class InvalidSample(Error):
    """Raised when a sample does not match the current universe size."""

52 

53 

class RandomVariable(abc.ABC):
    """Base class of all random variables.

    A random variable is a callable that maps a sample ``omega`` of the
    global universe ``Omega`` to a value. Arithmetic, bitwise and comparison
    operators return new random variables built with ``function``, so that
    expressions such as ``X + 1``, ``2 ** X`` or ``X < Y`` are themselves
    random variables, evaluated sample-wise.

    Because ``__eq__`` is overridden (and ``__hash__`` is not), Python makes
    instances unhashable: they cannot be dict keys or set members.
    """

    def __init__(self):
        # Register with the current universe so that re-initializing it
        # (`restart`) can flag this variable as invalid.
        Omega.rvs.append(self)
        self._valid = True

    def check(self, omega):
        """Validate ``omega`` before evaluating this variable on it.

        Raises
        ------
        InvalidRandomVariable
            If the universe has been reset since this variable was created.
        InvalidSample
            If ``len(omega)`` differs from the number of universe coordinates.
        """
        if not self._valid:
            raise InvalidRandomVariable()
        if len(omega) != Omega.n:
            error = "sample omega does not match the current universe size"
            raise InvalidSample(error)

    # Binary operators
    # ----------------
    # Each operator lifts the matching `operator` function with `function`.
    # Reflected methods (`__r<op>__`) implement `other <op> self`, so they
    # must pass the operands in that order: `+` and `*` are not commutative
    # for every type (e.g. strings and other sequences).
    # NOTE(review): `function(...)` is re-applied on every call; the lifted
    # operators could be built once, at class-definition time.

    def __add__(self, other):
        return function(operator.add)(self, other)

    def __radd__(self, other):
        return function(operator.add)(other, self)

    def __sub__(self, other):
        return function(operator.sub)(self, other)

    def __rsub__(self, other):
        return function(operator.sub)(other, self)

    def __mul__(self, other):
        return function(operator.mul)(self, other)

    def __rmul__(self, other):
        return function(operator.mul)(other, self)

    def __truediv__(self, other):
        return function(operator.truediv)(self, other)

    def __rtruediv__(self, other):
        return function(operator.truediv)(other, self)

    def __floordiv__(self, other):
        return function(operator.floordiv)(self, other)

    def __rfloordiv__(self, other):
        return function(operator.floordiv)(other, self)

    def __mod__(self, other):
        return function(operator.mod)(self, other)

    def __rmod__(self, other):
        return function(operator.mod)(other, self)

    def __pow__(self, other):
        return function(operator.pow)(self, other)

    def __rpow__(self, other):
        return function(operator.pow)(other, self)

    # Bitwise operators: handy to combine boolean random variables such as
    # `Bernoulli` outcomes or comparisons (`(X > 0) & (Y < 1)`).

    def __and__(self, other):
        return function(operator.and_)(self, other)

    def __rand__(self, other):
        return function(operator.and_)(other, self)

    def __or__(self, other):
        return function(operator.or_)(self, other)

    def __ror__(self, other):
        return function(operator.or_)(other, self)

    def __xor__(self, other):
        return function(operator.xor)(self, other)

    def __rxor__(self, other):
        return function(operator.xor)(other, self)

    # TODO: divmod, lshift, rshift, matmul.

    # Comparison operators

    def __lt__(self, other):
        return function(operator.lt)(self, other)

    def __le__(self, other):
        return function(operator.le)(self, other)

    def __eq__(self, other):
        return function(operator.eq)(self, other)

    def __ne__(self, other):
        return function(operator.ne)(self, other)

    def __ge__(self, other):
        return function(operator.ge)(self, other)

    def __gt__(self, other):
        return function(operator.gt)(self, other)

    # Unary operators

    def __neg__(self):
        return function(operator.neg)(self)

    def __pos__(self):
        return function(operator.pos)(self)

    def __abs__(self):
        return function(operator.abs)(self)

    def __invert__(self):
        # Elementwise negation for NumPy booleans; note that for a plain
        # Python bool `~True == -2`.
        return function(operator.invert)(self)

    def __bool__(self):
        # Python requires `__bool__` to return an actual bool, so a random
        # variable cannot be used where a truth value is needed: `if X > 0:`,
        # `X and Y`, `not X`, chained comparisons such as `0 < X < 1`, ...
        # Workaround: put the test inside a function randomized with
        # `function`; its body is then run on sample values, not on random
        # variables. The module-level randomized `bool` turns a random
        # variable into a random boolean.
        raise TypeError("you cannot use a random value where a boolean is required")

    # The conversions `__int__`, `__float__`, `__complex__` and `__index__`
    # hit the same limitation as `__bool__` (they must return concrete Python
    # objects); randomize the conversion instead, e.g. `function(float)(X)`.

146 

147 

# How can functions that contain tests (e.g. `if x > 0: ...`) work with
# random variables, given that `__bool__` can only return a genuine bool?
# Answer: decorating such a function with `function` merely delays its
# evaluation, and the body of the randomized function only ever sees concrete
# sample values, never random variables. This trick should be documented
# carefully: wrapping code that contains tests into a randomized function is
# what allows random variables to be used with it.
# Open question: why does `np.vectorize` also (apparently) make such functions
# work? Presumably because it, too, calls the function on individual scalar
# values -- TODO confirm.

155 

def function(f):
    """Randomize ``f``: lift it so that it accepts random variables.

    The returned wrapper calls ``f`` directly when none of its positional or
    keyword arguments is a ``RandomVariable``. Otherwise it returns a new
    random variable whose value at a sample ``omega`` is ``f`` applied to the
    values of the arguments at ``omega``. Non-random arguments are first
    passed through ``randomize``.

    Because ``f`` is only ever called with concrete sample values, its body
    may contain tests on its arguments. With vectorized samples, those values
    are arrays.
    """

    def wrapped_function(*args, **kwargs):
        all_args = list(args) + list(kwargs.values())
        if not any(isinstance(arg, RandomVariable) for arg in all_args):
            return f(*args, **kwargs)

        # NOTE(review): this class is re-created on every call; it could be
        # defined once if `f` were stored on the instance.
        class Deterministic(RandomVariable):
            def __init__(self, *args, **kwargs):
                # TODO: check args and kwargs against the signature of `f`.
                super().__init__()
                self.args = [randomize(arg) for arg in args]
                self.kwargs = {k: randomize(v) for k, v in kwargs.items()}

            def __call__(self, omega):
                self.check(omega)
                args_values = [arg(omega) for arg in self.args]
                # Evaluate the *randomized* keyword arguments: the raw ones
                # may be plain constants, which are not callable.
                kwargs_values = {k: v(omega) for k, v in self.kwargs.items()}
                return f(*args_values, **kwargs_values)

        return Deterministic(*args, **kwargs)

    return wrapped_function

179 

180 

# Randomized version of the builtin `bool`: applied to a random variable, it
# returns a random variable evaluating the builtin `bool` on sample values.
# This is fine as long as the result itself is not used in a test (see
# `RandomVariable.__bool__`). Note: this shadows the builtin inside the module.
bool = function(builtins.bool)

183 

184 

class Constant(RandomVariable):
    """Random variable that ignores the sample and returns a fixed value.

    When ``value`` is itself a random variable it is used as is, so the
    "constant" then returns that variable's value at ``omega``.
    """

    def __init__(self, value):
        super().__init__()
        # The value of a constant can be randomized too: delegate to it.
        self.rv = (
            value if isinstance(value, RandomVariable) else (lambda _omega: value)
        )

    def __call__(self, omega):
        self.check(omega)
        evaluate = self.rv
        return evaluate(omega)

197 

198 

199# Distributions 

200# ------------------------------------------------------------------------------ 

class Uniform(RandomVariable):
    """Uniform random variable between ``a`` and ``b``.

    Each instance reserves its own coordinate of the universe (index
    ``self.n``); its value is ``a * (1 - u) + b * u`` where ``u`` is that
    coordinate of the sample.
    """

    def __init__(self, a=0.0, b=1.0):
        super().__init__()
        # Claim the next free coordinate of the universe.
        self.n = Omega.n
        Omega.n += 1
        self.a = randomize(a)
        self.b = randomize(b)

    def __call__(self, omega):
        self.check(omega)
        # Indexing omega by coordinate is where the sample layout of the
        # universe leaks into the distributions.
        coordinate = omega[self.n]
        low = self.a(omega)
        high = self.b(omega)
        return low * (1 - coordinate) + high * coordinate

213 

214 

class Bernoulli(RandomVariable):
    """Boolean random variable that is true with probability ``p``."""

    def __init__(self, p=0.5):
        super().__init__()
        self.U = Uniform()
        self.P = randomize(p)

    def __call__(self, omega):
        self.check(omega)
        draw, threshold = self.U(omega), self.P(omega)
        # A uniform draw falls below `threshold` with probability `threshold`.
        return draw <= threshold

226 

class Binomial(RandomVariable):
    """Number of successes among ``n`` Bernoulli(``p``) trials.

    ``n`` must be a plain integer (it is passed to ``range``); one
    ``Bernoulli`` variable, with its own uniform coordinate, is created per
    trial.
    """

    def __init__(self, n, p=0.5):
        super().__init__()
        self.Bs = [Bernoulli(p) for _ in range(n)]

    def __call__(self, omega):
        self.check(omega)
        return sum(trial(omega) for trial in self.Bs)

236 

class Poisson(RandomVariable):
    """Poisson random variable with (possibly random) mean ``lambda_``.

    Sampled by inverse transform from a single uniform coordinate.
    """

    def __init__(self, lambda_):
        super().__init__()
        self.L = randomize(lambda_)
        self.U = Uniform()

    def __call__(self, omega):
        self.check(omega)
        mean = self.L(omega)
        q = self.U(omega)
        # Quantile function, as in scipy's `_discrete_distns.py`: `pdtrik`
        # inverts the Poisson CDF with respect to k; rounding up and then
        # checking the previous integer against the CDF (`pdtr`) yields the
        # smallest integer whose CDF reaches q.
        k = np.ceil(scipy.special.pdtrik(q, mean))
        k_below = np.maximum(k - 1, 0)
        cdf_below = scipy.special.pdtr(k_below, mean)
        return np.where(q <= cdf_below, k_below, k)

252 

class Normal(RandomVariable):
    """Normal random variable with mean ``mu`` and variance ``sigma2``.

    Note that the second parameter is the variance, not the standard
    deviation.
    """

    def __init__(self, mu=0.0, sigma2=1.0):
        super().__init__()
        self.U = Uniform()
        self.mu = randomize(mu)
        self.sigma2 = randomize(sigma2)

    def __call__(self, omega):
        self.check(omega)
        u = self.U(omega)
        mean = self.mu(omega)
        std = np.sqrt(self.sigma2(omega))
        # Inverse CDF of N(0, 1): sqrt(2) * erfinv(2u - 1).
        standard = scipy.special.erfinv(2 * u - 1) * np.sqrt(2)
        return standard * std + mean

266 

267 

class Exponential(RandomVariable):
    """Exponential random variable with rate ``lambda_``."""

    def __init__(self, lambda_=1.0):
        super().__init__()
        self.U = Uniform()
        self.lambda_ = randomize(lambda_)

    def __call__(self, omega):
        self.check(omega)
        u = self.U(omega)
        rate = self.lambda_(omega)
        # Inverse of the CDF F(x) = 1 - exp(-rate * x).
        return -np.log(1 - u) / rate

279 

280 

class Cauchy(RandomVariable):
    """Cauchy random variable with location ``x0`` and scale ``gamma``."""

    def __init__(self, x0=0.0, gamma=1.0):
        super().__init__()
        self.U = Uniform()
        self.x0 = randomize(x0)
        self.gamma = randomize(gamma)

    def __call__(self, omega):
        self.check(omega)
        u = self.U(omega)
        location = self.x0(omega)
        scale = self.gamma(omega)
        # Inverse CDF of the Cauchy distribution.
        return location + scale * np.tan(np.pi * (u - 0.5))

294 

295 

class t(RandomVariable):
    """Student's t random variable with ``nu`` degrees of freedom."""

    def __init__(self, nu):
        super().__init__()
        self.U = Uniform()
        self.N = randomize(nu)

    def __call__(self, omega):
        self.check(omega)
        u = self.U(omega)
        dof = self.N(omega)
        # `stdtrit` inverts the Student t CDF (`stdtr`) with respect to t.
        return scipy.special.stdtrit(dof, u)

307 

308 

class Beta(RandomVariable):
    """Beta random variable with shape parameters ``alpha`` and ``beta``.

    Sampled by inverse transform: the quantile function of Beta(a, b) is the
    inverse of the regularized incomplete beta function.
    """

    def __init__(self, alpha, beta):
        super().__init__()
        self.U = Uniform()
        self.A = randomize(alpha)
        self.B = randomize(beta)

    def __call__(self, omega):
        self.check(omega)
        u = self.U(omega)
        alpha = self.A(omega)
        beta = self.B(omega)
        # `scipy.special.btdtri` (same computation) is deprecated since
        # SciPy 1.12 in favour of `betaincinv`.
        return scipy.special.betaincinv(alpha, beta, u)

322 

323 

324# ------------------------------------------------------------------------------ 

# Expose a randomized version of every NumPy ufunc (np.sin, np.exp, ...) as a
# module-level name, so that e.g. `exp(X)` is a random variable when X is.
# The comprehension keeps the scan's loop variables out of the module
# namespace (a plain `for` loop leaked `name` and a raw ufunc `item`).
globals().update(
    {
        name: function(item)
        for name, item in ((name, getattr(np, name)) for name in dir(np))
        if isinstance(item, np.ufunc)
    }
)

329 

330 

def randomize(item):
    """Turn ``item`` into something that can be evaluated on a sample ``omega``.

    - random variables are returned unchanged;
    - other callables are lifted with ``function`` (so they end up being
      called with ``omega`` itself);
    - anything else is wrapped in a ``Constant``.
    """
    if isinstance(item, RandomVariable):
        return item
    return function(item) if callable(item) else Constant(item)