import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import fsolve
def offspring_generator():
    """Draw one offspring count X from a fixed discrete distribution.

    The distribution is
        P(X=0)=0.1, P(X=1)=0.7, P(X=2)=0.1, P(X=3)=0.05, P(X=4)=0.05.

    Exactly one uniform number is taken from NumPy's global random state
    (``np.random.rand``) and inverted through the cumulative distribution.

    Returns:
        int: the number of offspring, in {0, 1, 2, 3, 4}.
    """
    u = np.random.rand()
    # Cumulative bounds F(0), F(1), F(2), F(3); the first bound that u does
    # not exceed gives the outcome, and anything above F(3) = 0.95 is a 4.
    for count, upper in enumerate((0.1, 0.8, 0.9, 0.95)):
        if u <= upper:
            return count
    return 4
def branching_process(Max_gen, offspring=None):
    """Simulate one branching process started from a single individual.

    Generation sizes are tracked in an array Z with Z[0] = 1; each member of
    generation ``gen - 1`` independently produces ``offspring()`` children.

    Parameters:
        Max_gen: maximal number of generations to simulate (must be >= 0).
        offspring: optional zero-argument callable returning one individual's
            number of children. Defaults to ``offspring_generator``; e.g.
            ``lambda: np.random.poisson(2)`` gives a Poisson(2) process.

    Returns:
        (Extinct_gen, End_size) where
        Extinct_gen: the first generation whose size is 0, or ``Max_gen + 1``
            if the population is still alive at generation ``Max_gen``.
        End_size: the size of the population at generation ``Max_gen``
            (0 if it went extinct earlier).

    Raises:
        ValueError: if ``Max_gen`` is negative.
    """
    if Max_gen < 0:
        raise ValueError(f"Max_gen must be non-negative, got {Max_gen}")
    if offspring is None:
        # Resolved at call time so the default is this module's generator.
        offspring = offspring_generator

    Z = np.zeros(Max_gen + 1, dtype=int)
    Z[0] = 1
    # Sentinel meaning "survived through generation Max_gen".
    Extinct_gen = Max_gen + 1
    for gen in range(1, Max_gen + 1):
        Z[gen] = sum(offspring() for _ in range(Z[gen - 1]))
        if Z[gen] == 0:
            Extinct_gen = gen
            break  # remaining entries of Z stay 0
    End_size = Z[Max_gen]
    return Extinct_gen, End_size
# A Monte Carlo simulation to estimate the probability of extinction.
Max_gen = 25
sample_size = 50000
# Max_gen + 1 is the value branching_process returns for a population that
# is still alive at generation Max_gen.
Extinct_gen_list = np.full(sample_size, Max_gen + 1)
End_size_list = np.full(sample_size, 0)
for i in range(sample_size):
    Extinct_gen_list[i], End_size_list[i] = branching_process(Max_gen)

# One bar per integer outcome: extinction generations 1..Max_gen, and the
# non-extinction sentinel Max_gen + 1 in its own separate bin.
plt.hist(Extinct_gen_list, bins=np.arange(0.5, Max_gen + 2.5, 1.0))
plt.xlabel('Generation to go extinct')
plt.ylabel('Counts')
# Numeric ticks stop at least 2 short of the sentinel so that the
# 'non-extinction' label does not overlap a neighbouring number.
tick_positions = list(range(0, Max_gen - 1, 5)) + [Max_gen + 1]
tick_labels = [str(t) for t in tick_positions[:-1]] + ['non-extinction']
plt.xticks(tick_positions, tick_labels)
plt.show()

# 1 - prob_nonextinct estimates P(extinct by generation Max_gen). Because a
# surviving population can still die out later, this is at most the ultimate
# extinction probability (and close to it for moderate Max_gen).
prob_nonextinct = np.count_nonzero(Extinct_gen_list == Max_gen + 1) / sample_size
prob_extinct = 1 - prob_nonextinct
print(f"Probability of extinction = {prob_extinct:.4f}")
print(f"Probability of non-extinction = {prob_nonextinct:.4f}")
# Output:
# Probability of extinction = 0.4030
# Probability of non-extinction = 0.5970
# Plot the generating function and compute the theoretical extinction probability.
def f1(s):
    """Probability generating function E(s^X) of the offspring distribution.

    Evaluates 0.1 + 0.7 s + 0.1 s^2 + 0.05 s^3 + 0.05 s^4, term by term in
    increasing power. Works element-wise on NumPy arrays as well as scalars.
    """
    total = 0.1  # P(X=0) * s**0
    for power, prob in enumerate((0.7, 0.1, 0.05, 0.05), start=1):
        total = total + prob * s ** power
    return total


def f2(s):
    """The diagonal y = s; returns its argument unchanged."""
    return s
# Evaluate the generating function and the diagonal on a fine grid over [0, 1].
s_values = np.linspace(0, 1, 1000)
f1_values, f2_values = f1(s_values), f2(s_values)
plt.plot(s_values, f1_values, color='blue', label=r'y=$E(s^X)$')
plt.plot(s_values, f2_values, color='red', linestyle='--', label='y=s')
def equation_to_solve(s):
    """Residual f1(s) - f2(s); zero exactly where the PGF meets the line y = s."""
    gap = f1(s) - f2(s)
    return gap
# Find the extinction probability q: the smallest root of f1(s) = s in [0, 1].
# s = 1 is always a root, so a local solver started at an arbitrary guess can
# return the trivial root. Iterating s <- f1(s) from s = 0 instead increases
# monotonically to the smallest root (the n-th iterate equals
# P(extinct by generation n)), so it is guaranteed to pick the right one.
intersection_s = 0.0
for _ in range(100_000):
    next_s = f1(intersection_s)
    converged = abs(next_s - intersection_s) < 1e-12
    intersection_s = next_s
    if converged:
        break
else:
    # Convergence is very slow when the mean offspring number is exactly 1.
    print("Warning: fixed-point iteration did not converge; estimate may be inaccurate.")
intersection_f = f1(intersection_s)
plt.plot(intersection_s, intersection_f, 'ro')  # red dot at (q, f1(q))
plt.xlabel('s')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('y')
plt.legend()
plt.show()
print(f"Intersection point(theorectic extinction probability): s = {intersection_s:.4f}")
# Output:
# Intersection point(theorectic extinction probability): s = 0.4026