summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Aalmoes <jan.aalmoes@inria.fr>2024-01-16 07:29:49 +0100
committerJan Aalmoes <jan.aalmoes@inria.fr>2024-01-16 07:29:49 +0100
commitf66960aba99a37cad0e51010d837e17692d0b6d6 (patch)
tree9429d202fad03dc0a2f17666361a901a36933022
initial commitHEADmaster
-rw-r--r--law.py163
-rw-r--r--main.py92
2 files changed, 255 insertions, 0 deletions
diff --git a/law.py b/law.py
new file mode 100644
index 0000000..101200a
--- /dev/null
+++ b/law.py
@@ -0,0 +1,163 @@
+import math
+import numpy as np
+
+class Law:
+ def __init__(self, p, m):
+ self.p = p
+ self.q = 1-p
+ self.a = (1-p)*p**m
+ self.m = m
+ self.eigen_computed = False
+
+ def eigenvalues(self):
+ p = self.p
+ q = self.q
+ a = self.a
+ m = self.m
+
+ pol_char = np.zeros(m+2).astype(float)
+ pol_char[0] = 1
+ pol_char[1] = -1
+ pol_char[-1] = a
+ roots = np.roots(pol_char)
+ roots = roots.reshape(m+1)
+ self.l = roots
+ self.eigen_computed = True
+
+ #Finding initial conditions
+ u = np.zeros(m+1).astype(float)
+ for n in range(m+1):
+ u[n] = 1-p**m-n*(1-p)*p**m
+ I = u[:m+1].reshape(-1,1)
+ L = np.zeros([m+1,m+1]).astype(np.complex_)
+ for i in range(m+1):
+ L[i,:] = roots**(i)
+ try:
+ self.c = np.matmul(np.linalg.inv(L),I).reshape(m+1)
+ except:
+ self.c = float("nan")
+ print("pas invers")
+
+ def mass(self, N):
+ m = self.m
+ p = self.p
+ if N<m:
+ return 0
+ elif N==m:
+ return p**m
+ elif N<=2*m:
+ return (1-p)*p**m
+ else:
+ return np.real(np.sum(self.c*self.l**(N-2*m-1))*(1-p)*p**m)
+
+ def cdf(self, N):
+ if not(self.eigen_computed):
+ self.eigenvalues()
+ p = self.p
+ q = self.q
+ a = self.a
+ m = self.m
+ l = self.l
+ c = self.c
+
+ """P(D<=N)"""
+ if N<m:
+ return 0
+ elif N<=2*m:
+ return p**m+(N-m)*(1-p)*p**m
+ else:
+ un_sum = np.sum(c*(1-l**(N-2*m))/(1-l))
+ return np.real(p**m+m*(1-p)*p**m + (1-p)*p**m*un_sum)
+
+ def cdf_multin(self, N):
+ """N : nombre de messages"""
+ P = np.zeros_like(N).astype(float)
+ for ni,n in enumerate(N):
+ P[ni] = self.cdf(n)
+
+ return np.mean(P)
+
+
+ def cdf_position(self):
+ self.eigenvalues()
+ p = self.p
+ q = self.q
+ a = self.a
+ m = self.m
+ l = self.l
+ c = self.c
+
+ #cdf start point
+ s = 0.01
+ if self.cdf(m)<0.01:
+ start = m
+ else:
+ x0=2*m+1
+ x0=self.mean()
+ x1 = x0
+ #for i in range(1,niter):
+ #while (np.abs(f(x0,p,m,l,c,s))>1):
+ while (np.abs(self.cdf(x0)-s)>1):
+ left = (s-p**m-m*q*p**m)/(q*p**m)
+ f = np.real(np.sum(c*(1-l**(x0-2*m))/(1-l))) - left
+ fp = (-1)*np.real(np.sum(c*(l**(x0-2*m)*np.log(l))/(1-l)))
+ x1 = x0 - f/fp
+ x0 = x1
+ start = x1
+
+ #cdf end point
+ s = 0.99
+ #x0=2*m+1
+ x0=self.mean()
+ x1 = x0
+ i = 0
+
+ while (np.abs(self.cdf(x1)-s)>0.005):
+ for i in range(1000):
+ left = (s-p**m-m*q*p**m)/(q*p**m)
+ f = np.real(np.sum(c*(1-l**(x0-2*m))/(1-l))) - left
+ fp = (-1)*np.real(np.sum(c*(l**(x0-2*m)*np.log(l))/(1-l)))
+ x1 = x0 - f/fp
+ x0 = x1
+ i += 1
+ if math.isnan(x1):
+ x0 = 2*m+1
+ x1 = x0
+ s = s - 0.1*s
+ elif (np.abs(self.cdf(x1)-s)>0.005):
+ s = s - 0.1*s
+
+ end = x1
+
+ return start, end
+
+ def cdf_inv(self, s):
+ m = self.m
+ p = self.p
+ q = 1-p
+ c = self.c
+ l = self.l
+ x0=2*m+1
+ x1 = x0
+ i = 0
+
+ while (np.abs(self.cdf(x1)-s)>0.005):
+ left = (s-p**m-m*q*p**m)/(q*p**m)
+ f = np.real(np.sum(c*(1-l**(x0-2*m))/(1-l))) - left
+ fp = (-1)*np.real(np.sum(c*(l**(x0-2*m)*np.log(l))/(1-l)))
+ x1 = x0 - f/fp
+ x0 = x1
+ #i += 1
+
+ return x1
+
+
+
+ def mean(self):
+ m = self.m
+ p = self.p
+ q = 1-p
+ E = 0
+ E += m*p**m + q*p**m*(3*m**2+m)/2
+ E += q*p**m*(np.sum(self.c*self.l/(1-self.l)**2)+(2*m+1)*np.sum(self.c/(1-self.l)))
+ return np.real(E)
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..c851a47
--- /dev/null
+++ b/main.py
@@ -0,0 +1,92 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.ticker as mtick
+import os
+from pathlib import Path
+
+from law import Law
+
+ps = np.linspace(0.2,0.9,9)
+ms = np.linspace(5,20,10).astype(int)
+
+path = Path("plot")
+
+
+E = np.zeros([9,10]).astype(float)
+for i,p in enumerate(ps):
+ for j,m in enumerate(ms):
+ os.makedirs(Path(path,str(p),str(m)),exist_ok=True)
+ tmp_path = Path(path,str(p),str(m))
+ law = Law(p,m)
+ start, end = law.cdf_position()
+ mean = law.mean()
+ N = 1000
+ ns = np.linspace(start,end,N)
+
+ mass_curve = np.zeros(N).astype(float)
+ cdf_curve = np.zeros_like(mass_curve)
+ for ni,n in enumerate(ns):
+ mass_curve[ni] = law.mass(n)
+ cdf_curve[ni] = law.cdf(n)
+
+ plt.plot(ns,mass_curve)
+ plt.xlabel("n")
+ plt.ylabel("P(D)=n")
+ plt.savefig(Path(tmp_path,"mass.pdf"),bbox_inches="tight")
+ plt.clf()
+
+ plt.plot(ns,cdf_curve)
+ plt.xlabel("n")
+ plt.ylabel("P(D)\\leq n")
+ plt.savefig(Path(tmp_path,"cdf.pdf"),bbox_inches="tight")
+ plt.clf()
+
+ E[i,j] = law.mean()
+
+
+msl = [f"m={m}" for m in ms]
+psl = [f"p={round(p,2)}" for p in ps]
+plt.plot(ps,np.log(E)/np.log(10),label=msl)
+plt.rcParams.update({
+"text.usetex": True,
+"font.family": "sans-serif",
+"font.size":12
+})
+plt.xlabel("p")
+plt.ylabel("E(D)")
+ax=plt.gca()
+log_lab = ax.get_yticks()
+m = int(np.min(log_lab))
+M = int(np.max(log_lab))
+pos = np.linspace(m,M,M-m+1)
+ax.set_yticks(pos)
+#ax.set_yticks(log_lab)
+ax.set_yticklabels([f"$10^{{ {int(t)} }}$" for t in pos])
+#ax.set_yticklabels(np.round(np.exp(log_lab),2))
+plt.legend()
+plt.savefig(Path(path,"pvsE.pdf"),bbox_inches="tight")
+plt.clf()
+
+plt.plot(ms,np.transpose(np.log(E))/np.log(10),label=psl)
+plt.legend()
+plt.xlabel("m")
+plt.rcParams.update({
+"text.usetex": True,
+"font.family": "sans-serif",
+})
+plt.ylabel("E(D)")
+ax=plt.gca()
+log_lab = ax.get_yticks()
+m = int(np.min(log_lab))
+M = int(np.max(log_lab))
+pos = np.linspace(m,M,M-m+1)
+ax.set_yticks(pos)
+ax.set_xticks([5,10,15,20])
+#ax.set_yticks(log_lab)
+ax.set_yticklabels([f"$10^{{ {int(t)} }}$" for t in pos])
+plt.ylabel("E(D)")
+plt.savefig(Path(path,"mvsE.pdf"),bbox_inches="tight")
+plt.clf()
+
+
+