-
Notifications
You must be signed in to change notification settings - Fork 0
/
index.js
157 lines (132 loc) · 3.71 KB
/
index.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
const PD = require('probability-distributions');
/**
 * A single bandit arm holding two counters, `a` and `b`.
 *
 * `a` is bumped on every recorded success and `b` on every recorded failure.
 * The counters are passed to `PD.rbeta` as its two shape arguments when a
 * click-through-rate sample is drawn.
 */
class Arm {
  /**
   * @param {*} idex - Identifier stored on the arm as-is.
   * @param {number} [a=1] - Starting value of the success counter.
   * @param {number} [b=1] - Starting value of the failure counter.
   */
  constructor(idex, a = 1, b = 1) {
    Object.assign(this, { idex, a, b });
  }

  /** Increment the success counter by one. */
  record_success() {
    ++this.a;
  }

  /** Increment the failure counter by one. */
  record_fail() {
    ++this.b;
  }

  /**
   * Draw one sample from `PD.rbeta` using the current counters.
   * @returns {number} The first (and only) value of the one-element draw.
   */
  draw_ctr() {
    const [sample] = PD.rbeta(1, this.a, this.b);
    return sample;
  }

  /**
   * @returns {number} `a / (a + b)` for the current counters.
   */
  mean() {
    const { a, b } = this;
    return a / (a + b);
  }
}
/**
 * Draw one sample from every arm and return the arm whose draw is largest.
 *
 * Each arm's `draw_ctr()` is called exactly once. The maximum is found in a
 * single pass instead of sorting every sample.
 *
 * @param {Array<{idex: *, draw_ctr: function(): number}>} arms - Arms to sample.
 * @returns {{idex: *, ctr: number}|undefined} The winning arm's `idex` and the
 *   value it drew, or `undefined` when `arms` is empty.
 */
function thompson_sampling(arms) {
  let best;
  for (const arm of arms) {
    const ctr = arm.draw_ctr();
    // Strict `>` keeps the earliest arm on ties, the same result a stable
    // descending sort would put in first position.
    if (best === undefined || ctr > best.ctr) {
      best = { idex: arm.idex, ctr };
    }
  }
  return best;
}
/**
 * Repeatedly sample every arm and estimate how often each arm has the
 * highest draw.
 *
 * @param {Array<{draw_ctr: function(): number}>} arms - Arms to sample.
 * @param {number} [draw=100] - Number of sampling rounds.
 * @returns {{mc: number[][], p_winner: number[]}} `mc` holds one row per round
 *   with one sample per arm (in `arms` order); `p_winner[k]` is the fraction
 *   of rounds in which arm `k` produced the largest sample.
 */
function monte_carlo_simulation(arms, draw = 100) {
  const mc = [];
  // One win counter per arm, all starting at zero.
  const counts = new Array(arms.length).fill(0);

  for (let round = 0; round < draw; round++) {
    const samples = arms.map((arm) => arm.draw_ctr());
    mc.push(samples);
    // indexOf returns the first position of the maximum, so ties go to the
    // lower-indexed arm.
    counts[samples.indexOf(Math.max(...samples))]++;
  }

  const p_winner = counts.map((count) => count / draw);
  return { mc, p_winner };
}
/**
 * Decide whether the experiment can stop, based on how much relative
 * improvement over the current leader is still plausible.
 *
 * For every Monte Carlo row the "value remaining" is
 * `(max(row) - row[winner]) / row[winner]`, where `winner` is the arm with the
 * highest `p_winner`. The function returns true when the `(1 - alpha)`
 * quantile of those values is below 1% of the winner's estimated rate.
 *
 * @param {number[]} p_winner - Win probability per arm.
 * @param {number[]} est_ctrs - Estimated rate per arm.
 * @param {number[][]} mc - Monte Carlo samples, one row per round, one column per arm.
 * @param {number} [alpha=0.05] - Tail mass; the quantile used is `1 - alpha`.
 * @returns {boolean} True when the stopping criterion is met; false otherwise
 *   (including when `mc` is empty).
 */
function should_terminate(p_winner, est_ctrs, mc, alpha = 0.05) {
  const winner_idx = p_winner.indexOf(Math.max(...p_winner));

  const values_remaining = mc
    .map((draws) => (Math.max(...draws) - draws[winner_idx]) / draws[winner_idx])
    // A numeric comparator is required: the default sort compares string
    // forms, which misorders values such as 0.5 and 1e-7.
    .sort((x, y) => x - y);

  if (values_remaining.length === 0) {
    return false;
  }

  const last = values_remaining.length - 1;
  const position = values_remaining.length * (1 - alpha);
  // Clamp to the last element so a position at or beyond the end of the
  // array (e.g. 10 samples with alpha 0.05 -> 9.5) does not read `undefined`.
  const lower = Math.min(Math.floor(position), last);
  const upper = Math.min(Math.ceil(position), last);
  // When `position` is an integer, lower === upper and this is that element.
  const quantile = (values_remaining[lower] + values_remaining[upper]) / 2;

  return quantile < 0.01 * est_ctrs[winner_idx];
}
/**
 * Simulate a k-armed Bernoulli bandit driven by Thompson sampling.
 *
 * Each iteration picks an arm with `thompson_sampling`, simulates a
 * success/failure against that arm's true rate, re-estimates win
 * probabilities with `monte_carlo_simulation`, and (once past `burn_in`)
 * stops as soon as `should_terminate` reports true.
 *
 * @param {number[]} ctrs - True success probability of each arm.
 * @param {number} [alpha=0.05] - Passed through to `should_terminate`.
 * @param {number} [burn_in=1000] - Iterations to run before termination is considered.
 * @param {number} [max_iter=100000] - Hard cap on iterations.
 * @param {number} [draw=100] - Monte Carlo rounds per iteration.
 * @param {boolean} [slient=false] - When true, suppresses console output.
 * @returns {{idx: number, i: number, est_ctrs: number[], history_p: number[][], traffic: number[]}}
 *   `idx` is the arm pulled on the final iteration (not necessarily the arm
 *   with the highest win probability); `i` is the iteration index at which the
 *   loop stopped (`max_iter` if it never terminated early); `est_ctrs` holds
 *   each arm's `mean()`; `history_p[k]` is arm `k`'s win probability per
 *   iteration; `traffic[k]` is the number of pulls arm `k` received.
 */
function k_arm_bandit(ctrs, alpha = 0.05, burn_in = 1000, max_iter = 100000, draw = 100, slient = false) {
  const n_arms = ctrs.length;
  const arms = Array.from({ length: n_arms }, (_, index) => new Arm(index));
  const history_p = Array.from({ length: n_arms }, () => []);

  // Initialised up front so the result is well-formed even if the loop body
  // never runs (max_iter <= 0).
  let est_ctrs = arms.map((arm) => arm.mean());
  let idx;
  let i = 0;

  for (; i < max_iter; i++) {
    idx = thompson_sampling(arms).idex;
    const pulled = arms[idx];

    // Simulate one Bernoulli trial against the arm's true rate.
    if (Math.random() < ctrs[idx]) {
      pulled.record_success();
    } else {
      pulled.record_fail();
    }

    const { mc, p_winner } = monte_carlo_simulation(arms, draw);
    p_winner.forEach((p, j) => {
      history_p[j].push(p);
    });

    est_ctrs = arms.map((arm) => arm.mean());

    if (i >= burn_in && should_terminate(p_winner, est_ctrs, mc, alpha)) {
      if (!slient) {
        console.log(`Terminated at iteration ${i}`);
      }
      break;
    }
  }

  const traffic = arms.map((arm) => {
    if (!slient) {
      console.log('arm.a', arm.a, 'arm.b', arm.b);
    }
    // Arms are created with a = b = 1, so subtracting 2 leaves the pull count.
    return arm.a + arm.b - 2;
  });

  return {
    idx,
    i,
    est_ctrs,
    history_p,
    traffic,
  };
}
// Default export: a single object bundling the Arm class and the four
// bandit functions defined in this file.
// NOTE(review): this file loads its dependency with `require` but exports
// with ES-module syntax; presumably a bundler/transpiler handles the mix —
// confirm before running it directly under Node.
export default {
Arm,
thompson_sampling,
monte_carlo_simulation,
should_terminate,
k_arm_bandit
};