summaryrefslogtreecommitdiff
path: root/main.py
blob: c851a478c1121ff180db04f25cd1d578e6cc8112 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import os
from pathlib import Path

from law import Law 

# Parameter grid: success probabilities p and integer sizes m for Law(p, m).
ps = np.linspace(0.2, 0.9, 9)
ms = np.linspace(5, 20, 10).astype(int)

path = Path("plot")

# Number of evaluation points on each mass/cdf curve.
N = 1000


def _save_curve(xs, ys, ylabel, out_file):
    """Plot ``ys`` against ``xs`` (x-axis labelled ``n``) and save to ``out_file``.

    The current figure is cleared afterwards so the next plot starts empty.
    """
    plt.plot(xs, ys)
    plt.xlabel("n")
    plt.ylabel(ylabel)
    plt.savefig(out_file, bbox_inches="tight")
    plt.clf()


# E[i, j] holds law.mean() for Law(ps[i], ms[j]). The shape is derived from
# the grids so editing ps/ms cannot desynchronize it.
E = np.zeros((len(ps), len(ms)))
for i, p in enumerate(ps):
    for j, m in enumerate(ms):
        tmp_path = path / str(p) / str(m)
        tmp_path.mkdir(parents=True, exist_ok=True)

        law = Law(p, m)
        start, end = law.cdf_position()
        mean = law.mean()
        ns = np.linspace(start, end, N)

        # Law.mass / Law.cdf are evaluated point by point, as before; whether
        # they accept arrays is not visible here.
        mass_curve = np.array([law.mass(n) for n in ns], dtype=float)
        cdf_curve = np.array([law.cdf(n) for n in ns], dtype=float)

        # Mathtext ($...$) is required for \leq to render; usetex is not yet
        # enabled at this point in the script.
        _save_curve(ns, mass_curve, r"$P(D=n)$", tmp_path / "mass.pdf")
        _save_curve(ns, cdf_curve, r"$P(D\leq n)$", tmp_path / "cdf.pdf")

        E[i, j] = mean

# Legend labels: one curve per m in the p-plot, one curve per p in the m-plot.
msl = [f"m={m}" for m in ms]
psl = [f"p={round(p,2)}" for p in ps]

# Configure text rendering before any summary figure content is created, so
# ticks, axis labels and legends all use the same settings.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "sans-serif",
    "font.size": 12,
})


def _plot_log_expectation(xs, log_E, labels, xlabel, out_file, xticks=None):
    """Plot log10 expectations against ``xs`` with decade tick labels, then save.

    Args:
        xs: x-axis values, one per row of ``log_E``.
        log_E: log10 of E(D); each column is one labelled curve.
        labels: legend labels, one per column of ``log_E``.
        xlabel: x-axis label text.
        out_file: destination file for ``plt.savefig``.
        xticks: optional explicit x tick positions.
    """
    plt.plot(xs, log_E, label=labels)
    plt.xlabel(xlabel)
    plt.ylabel("E(D)")
    ax = plt.gca()

    # Integer decades bracketing the finite data. floor/ceil (rather than
    # int truncation of auto ticks) keeps negative logs and sub-decade ranges
    # correct; non-finite values (from E <= 0) are ignored.
    finite = log_E[np.isfinite(log_E)]
    if finite.size:
        decades = np.arange(np.floor(finite.min()), np.ceil(finite.max()) + 1)
        ax.set_yticks(decades)
        ax.set_yticklabels([f"$10^{{ {int(t)} }}$" for t in decades])
    if xticks is not None:
        ax.set_xticks(xticks)

    plt.legend()
    plt.savefig(out_file, bbox_inches="tight")
    plt.clf()


log_E = np.log10(E)
# E has shape (len(ps), len(ms)): plot over p directly, over m via transpose.
_plot_log_expectation(ps, log_E, msl, "p", Path(path, "pvsE.pdf"))
_plot_log_expectation(ms, log_E.T, psl, "m", Path(path, "mvsE.pdf"),
                      xticks=[5, 10, 15, 20])