Simulation of a Branching Process

Huarui Zhou

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import fsolve
In [2]:
# Offspring distribution of a single individual:
# P(X=0)=0.1, P(X=1)=0.7, P(X=2)=0.1, P(X=3)=0.05, P(X=4)=0.05
def offspring_generator():
    """Draw one individual's number of offspring by inverse-CDF sampling.

    A single uniform draw from ``np.random.rand()`` is compared against the
    cumulative probabilities 0.1, 0.8, 0.9, 0.95 of the distribution above.

    Returns
    -------
    int
        The number of offspring, one of 0, 1, 2, 3 or 4.
    """
    u = np.random.rand()
    # Cumulative boundaries are written as literals (not summed at runtime) so
    # the comparisons are exactly those of a hand-written if/elif chain.
    for count, upper in enumerate((0.1, 0.8, 0.9, 0.95)):
        if u <= upper:
            return count
    return 4
In [3]:
def branching_process(Max_gen, offspring=None):
    """Simulate one branching process started from a single individual.

    Parameters
    ----------
    Max_gen : int
        Maximal number of generations to simulate.
    offspring : callable, optional
        Zero-argument function returning the offspring count of one
        individual. Defaults to ``offspring_generator``. Pass e.g.
        ``lambda: np.random.poisson(2)`` to try a different offspring law.

    Returns
    -------
    Extinct_gen : int
        The first generation whose size is 0, or ``Max_gen + 1`` if the
        population is still alive at generation ``Max_gen``.
    End_size : int
        Size of the population at generation ``Max_gen`` (0 if it went
        extinct earlier).
    """
    if offspring is None:
        offspring = offspring_generator
    Z = np.zeros(Max_gen + 1, dtype=int)
    Z[0] = 1
    Extinct_gen = Max_gen + 1  # sentinel meaning "not extinct by Max_gen"
    for gen in range(1, Max_gen + 1):
        # Every individual of the previous generation reproduces independently.
        Z[gen] = sum(offspring() for _ in range(Z[gen - 1]))
        if Z[gen] == 0:
            Extinct_gen = gen
            break
    # Later entries stay 0 after an early extinction, so this is the size at Max_gen.
    End_size = Z[Max_gen]
    return Extinct_gen, End_size
In [6]:
# A simulation to estimate the probability of extinction.
# Seed the global NumPy RNG (offspring_generator draws from np.random.rand)
# so that Restart & Run All reproduces the same estimate.
SEED = 2024
np.random.seed(SEED)

Max_gen = 25
sample_size = 50000
# Max_gen + 1 is the sentinel value for "still alive after Max_gen generations".
Extinct_gen_list = np.full(sample_size, Max_gen + 1)
End_size_list = np.full(sample_size, 0)
for i in range(sample_size):
    # Record both outputs; previously End_size_list was allocated but never filled.
    Extinct_gen_list[i], End_size_list[i] = branching_process(Max_gen)
In [9]:
# One bar per generation 1, ..., Max_gen, plus a final bar at the sentinel
# value Max_gen + 1 collecting the runs that never went extinct. (The default
# 10 bins merged late extinctions with the non-extinct runs.)
bin_edges = np.arange(0.5, Max_gen + 2.5)  # edges centred on integer generations
plt.hist(Extinct_gen_list, bins=bin_edges)
plt.title(f'Extinction generation over {sample_size} simulated processes')
plt.xlabel('Generation to go extinct')
plt.ylabel('Counts')
# Build ticks and labels from the same list so their lengths always match,
# and place the 'non-extinction' tick on the sentinel bar itself.
tick_gens = list(range(0, Max_gen, 5))
plt.xticks(tick_gens + [Max_gen + 1], [str(g) for g in tick_gens] + ['non-extinction'])
plt.show()
prob_nonextinct = np.count_nonzero(Extinct_gen_list == Max_gen + 1) / sample_size
prob_extinct = 1 - prob_nonextinct
print(f"Probability of extinction = {prob_extinct:.4f}")
print(f"Probability of non-extinction = {prob_nonextinct:.4f}")
Probability of extinction = 0.4030
Probability of non-extinction = 0.5970
In [10]:
# Plot the probability generating function and compute the theoretical extinction probability.
def f1(s):
    """Probability generating function E(s^X) of the offspring distribution.

    Uses P(X=0)=0.1, P(X=1)=0.7, P(X=2)=0.1, P(X=3)=0.05, P(X=4)=0.05.
    Works elementwise on NumPy arrays as well as on scalars.
    """
    # One term per offspring count k = 0..4: P(X=k) * s**k.
    return (0.1
            + 0.7 * s
            + 0.1 * s ** 2
            + 0.05 * s ** 3
            + 0.05 * s ** 4)

def f2(s):
    """Identity line y = s; its crossings with f1 are the fixed points of f1."""
    return s

# Evaluate both curves on a fine grid over [0, 1] and draw them on one figure.
s_values = np.linspace(0, 1, 1000)
f1_values, f2_values = f1(s_values), f2(s_values)

curve_styles = [
    (f1_values, dict(label=r'y=$E(s^X)$', color='blue')),
    (f2_values, dict(label='y=s', linestyle='--', color='red')),
]
for y_vals, style in curve_styles:
    plt.plot(s_values, y_vals, **style)

def equation_to_solve(s):
    """Residual f1(s) - f2(s); its zeros are where E(s^X) crosses the line y = s."""
    residual = f1(s) - f2(s)
    return residual

# The theoretical extinction probability q is the SMALLEST root of f1(s) = s
# in [0, 1]. Since f1(1) = 1, s = 1 is always a root, so a root finder started
# from an arbitrary guess (e.g. fsolve from 0.5) may return the wrong one.
# Instead iterate q_{n+1} = f1(q_n) from q_0 = 0: q_n is the probability of
# extinction by generation n, and it increases monotonically to the smallest
# fixed point. Stop once the residual f1(q) - q is negligible.
intersection_s = 0.0
for _ in range(100_000):
    if abs(equation_to_solve(intersection_s)) < 1e-13:
        break
    intersection_s = f1(intersection_s)
# If the loop is exhausted (e.g. a critical process, where convergence to 1 is
# slow), intersection_s is the last iterate, which is a lower bound on q.
intersection_f = f1(intersection_s)

plt.plot(intersection_s, intersection_f, 'ro')  # red dot at the fixed point

plt.xlabel('s')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('y')
plt.legend()

plt.show()


print(f"Intersection point (theoretical extinction probability): s = {intersection_s:.4f}")
Intersection point(theorectic extinction probability): s = 0.4026
In [ ]: