【WALT】predict_and_update_buckets() 与 update_task_pred_demand() 代码详解
@
【WALT】predict_and_update_buckets() 与 update_task_pred_demand() 代码详解
代码版本:Linux4.9 android-msm-crosshatch-4.9-android12
代码展示
static inline u32 predict_and_update_buckets(struct rq *rq,
struct task_struct *p, u32 runtime) {
int bidx;
u32 pred_demand;
if (!sched_predl)
return 0;
// ⑴ 根据 runtime 给出桶的下标
bidx = busy_to_bucket(runtime);
// ⑵ 根据桶的下标预测 pred_demand
pred_demand = get_pred_busy(rq, p, bidx, runtime);
// ⑶ 根据桶的下标更新桶
bucket_increase(p->ravg.busy_buckets, bidx);
return pred_demand;
}
void update_task_pred_demand(struct rq *rq, struct task_struct *p, int event)
{
u32 new, old;
if (!sched_predl)
return;
if (is_idle_task(p) || exiting_task(p))
return;
if (event != PUT_PREV_TASK && event != TASK_UPDATE &&
(!SCHED_FREQ_ACCOUNT_WAIT_TIME ||
(event != TASK_MIGRATE &&
event != PICK_NEXT_TASK)))
return;
if (event == TASK_UPDATE) {
if (!p->on_rq && !SCHED_FREQ_ACCOUNT_WAIT_TIME)
return;
}
new = calc_pred_demand(rq, p);
old = p->ravg.pred_demand;
if (old >= new)
return;
if (task_on_rq_queued(p) && (!task_has_dl_policy(p) ||
!p->dl.dl_throttled))
p->sched_class->fixup_walt_sched_stats(rq, p, p->ravg.demand, new);
p->ravg.pred_demand = new;
}
代码逻辑
sched_predl 默认为 1,意味着正常情况下这两个函数都会执行。
predict_and_update_buckets() 主要通过桶算法来预测 pred_demand。其中,runtime 是在 update_task_demand() 中根据 scale_exec_time() 归一化过后的执行时间。
update_task_pred_demand() 主要是在任务的 perd_demand 值小于 curr_window 时再次计算 pred_demand。
桶算法的核心是数组 busy_buckets[NUM_BUSY_BUCKETS]
。
代码中对数组 busy_buckets 的描述是:'busy_buckets' 将历史 busy time 分组到用于预测的不同桶中。('busy_buckets' groups historical busy time into different buckets used for prediction.)
我将数组 busy_buckets 称之为一种类似 PELT 的机制,数组中每一个下标对应的值可以算作是一种权重,数组会根据窗口的不断翻滚对权重进行衰减或加强。
桶算法将一个 WALT 窗口的时间分为 NUM_BUSY_BUCKETS 块(默认为 10 块),可以看做是 NUM_BUSY_BUCKETS 个等级。runtime 越大,它所处的等级就越高。
每次调用函数 predict_and_update_buckets()
时会先判断上一个窗口中当前任务所执行的时间(归一化后)处于的等级 bidx,根据 bidx 和数组 busy_buckets 来预测 pred_demand,然后对数组 busy_buckets 中下标非 bidx 的权重进行衰减,对下标为 bidx 的权重进行加强。
⑴ 根据 runtime 给出桶的下标
static inline int busy_to_bucket(u32 normalized_rt)
{
int bidx;
bidx = mult_frac(normalized_rt, NUM_BUSY_BUCKETS, max_task_load());
bidx = min(bidx, NUM_BUSY_BUCKETS - 1);
if (!bidx)
bidx++;
return bidx;
}
参数解释:
- normalized_rt:传入的 runtime(已归一化);
- NUM_BUSY_BUCKETS:设定的桶的数量,是数组
busy_buckets[NUM_BUSY_BUCKETS]
的长度,当前版本中默认为 10; - max_task_load():函数返回值为窗口大小 sched_ravg_window。
因此可以得到 \(bidx = NUM\_BUSY\_BUCKETS\times\dfrac{runtime}{window\_size}\)
为了防止 bidx 超出数组 busy_buckets
的下标,还需要在 bidx 和 NUM_BUSY_BUCKETS - 1 中取最小值。
如果得到的 bidx 为 0,就让 bidx = 1。这是因为最低频率落入第二个桶,因此继续预测最低的桶是没有用的。(源自代码注释:The lowest frequency falls into 2nd bucket and thus keep predicting lowest bucket is not useful.)
函数返回下标 bidx,用于 pred_demand 的计算当中。
⑵ 根据桶的下标预测 pred_demand
static u32 get_pred_busy(struct rq *rq, struct task_struct *p,
int start, u32 runtime)
{
int i;
u8 *buckets = p->ravg.busy_buckets;
u32 *hist = p->ravg.sum_history;
u32 dmin, dmax;
u64 cur_freq_runtime = 0;
int first = NUM_BUSY_BUCKETS, final;
u32 ret = runtime;
u32 temp = runtime;
// 1. 如果任务刚被创建,直接结束
if (unlikely(is_new_task(p)))
goto out;
// 2. 根据下标 bidx 和 busy_buckets 数组找到上下界 dmin 和 dmax
for (i = start; i < NUM_BUSY_BUCKETS; i++) {
if (buckets[i]) {
first = i;
break;
}
}
if (first >= NUM_BUSY_BUCKETS)
goto out;
final = first;
if (final < 2) {
dmin = 0;
final = 1;
} else {
dmin = mult_frac(final, max_task_load(), NUM_BUSY_BUCKETS);
}
dmax = mult_frac(final + 1, max_task_load(), NUM_BUSY_BUCKETS);
// 3. 根据 dmin 和 dmax,以及历史执行时间来预测 pred_demand
for (i = 0; i < sched_ravg_hist_size; i++) {
if (hist[i] >= dmin && hist[i] < dmax) {
ret = hist[i];
break;
}
}
if (ret < dmin)
ret = (dmin + dmax) / 2;
ret = max(runtime, ret);
out:
trace_sched_update_pred_demand(rq, p, runtime,
mult_frac((unsigned int)cur_freq_runtime, 100,
sched_ravg_window), ret);
return ret;
}
get_pred_busy() 函数根据 runtime 和桶下标 bidx 来找到一个 busy time。
1. 如果任务刚被创建,直接结束
新创建的任务没有历史执行时间的数据,因此没法进行预测,直接结束,返回 runtime。
2. 根据下标 bidx 和数组 busy_buckets 找到上下界 dmin 和 dmax
数组 busy_buckets 中值非 0 意味着至少 4 个历史窗口内有某个窗口的 runtime 处于该等级。
遍历数组 busy_buckets 时发现第一个非 0 的值,就把该值所在下标作为 first,并让 final = first。
如果 final < 2,意味着过去的 4 个历史窗口中曾经有一个窗口的 runtime 处于等级 0 或 1,就让 \(dmin = 0\),\(dmax = window\_size \times \dfrac{2}{NUM\_BUSY\_BUCKETS}\)。
如果 final ≥ 2,就让 \(dmin = window\_size \times \dfrac{final}{NUM\_BUSY\_BUCKETS}\),\(dmax = window\_size \times \dfrac{final+1}{NUM\_BUSY\_BUCKETS}\)。
3. 根据 dmin 和 dmax,以及历史执行时间来预测 pred_demand
ret 的初始值设置为 runtime。
得到上下界 dmin 和 dmax 之后,遍历数组 sum_history,看看保存的历史窗口的时间中是否有处于该上下界之间的,如果有就将 ret 设置为该时间,没有的话 ret 的值仍为 runtime。
如果 ret 的值比下界 dmin 还小,意味着 runtime 比之前的所有历史时间都小,就让 ret 的值设置为 (dmin + dmax) / 2。
最后再在 runtime 和 ret 之间取最大值,作为 pred_demand 返回。
⑶ 根据桶的下标更新桶
#define INC_STEP 8
#define DEC_STEP 2
#define CONSISTENT_THRES 16
#define INC_STEP_BIG 16
static inline void bucket_increase(u8 *buckets, int idx)
{
int i, step;
for (i = 0; i < NUM_BUSY_BUCKETS; i++) {
// 如果下标不是 bidx,衰减直至0
if (idx != i) {
if (buckets[i] > DEC_STEP)
buckets[i] -= DEC_STEP;
else
buckets[i] = 0;
} else {
step = buckets[i] >= CONSISTENT_THRES ?
INC_STEP_BIG : INC_STEP;
// U8_MAX = 255
if (buckets[i] > U8_MAX - step)
buckets[i] = U8_MAX;
else
buckets[i] += step;
}
}
}
这一部分就是根据 bidx 更新数组 busy_buckets。
busy_buckets[bidx] 的值如果大于等于 16,就要累加 16,直至 255;如果小于16,就累加 8,直至 255。
下标不是 bidx 的值都要衰减 2,直到该值为 0 为止。
⑷ 更新 pred_demand
不需要更新的情况有:
- 任务是 idle 任务或任务要出队
if (event != PUT_PREV_TASK && event != TASK_UPDATE && (!SCHED_FREQ_ACCOUNT_WAIT_TIME || (event != TASK_MIGRATE && event != PICK_NEXT_TASK)))
但是当前版本中 SCHED_FREQ_ACCOUNT_WAIT_TIME 默认为 1,所以该情况不会发生if (event == TASK_UPDATE) and if (!p->on_rq && !SCHED_FREQ_ACCOUNT_WAIT_TIME)
同上,不会发生
需要更新时,执行函数 calc_pred_demand():
static inline u32 calc_pred_demand(struct rq *rq, struct task_struct *p)
{
if (p->ravg.pred_demand >= p->ravg.curr_window)
return p->ravg.pred_demand;
return get_pred_busy(rq, p, busy_to_bucket(p->ravg.curr_window),
p->ravg.curr_window);
}
如果任务的 pred_demand 比 curr_window 大,就不更新,否则重新通过 get_pred_busy()
更新 pred_demand。
其中,curr_window 在 update_cpu_busy_time() 中计算,并不在 update_task_demand() 中计算。
最后,如果任务在运行队列上,且任务不是 dl 任务或 dl 任务的时间没有耗尽,就通过 fixup_walt_sched_stats()
计算负载。该部分见 【WALT】调度与负载计算