
[net,v2] sched: sch_cake: add bounds checks to host bulk flow fairness counts

Message ID 20250107120105.70685-1-toke@redhat.com (mailing list archive)
State New
Delegated to: Netdev Maintainers

Checks

Context Check Description
netdev/series_format success Single patches do not need cover letters
netdev/tree_selection success Clearly marked for net
netdev/ynl success Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present success Fixes tag present in non-next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 1 this patch: 1
netdev/build_tools success No tools touched, skip
netdev/cc_maintainers success CCed 10 of 10 maintainers
netdev/build_clang success Errors and warnings before: 2 this patch: 2
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success Fixes tag looks correct
netdev/build_allmodconfig_warn success Errors and warnings before: 1 this patch: 1
netdev/checkpatch warning WARNING: line length of 104 exceeds 80 columns; WARNING: line length of 81 exceeds 80 columns; WARNING: line length of 82 exceeds 80 columns; WARNING: line length of 88 exceeds 80 columns; WARNING: line length of 96 exceeds 80 columns; WARNING: line length of 98 exceeds 80 columns
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Toke Høiland-Jørgensen Jan. 7, 2025, 12:01 p.m. UTC
Even though we fixed a logic error in the commit cited below, syzbot
still managed to trigger an underflow of the per-host bulk flow
counters, leading to an out-of-bounds memory access.

To avoid any such logic errors causing out-of-bounds memory accesses,
this commit factors out all accesses to the per-host bulk flow counters
into a series of helpers that perform bounds checking before any
increments and decrements. This also has the benefit of improving
readability by moving the conditional checks for the flow mode into
these helpers, instead of having them spread throughout the
code (which was the cause of the original logic error).

v2:
- Remove now-unused srchost and dsthost local variables in cake_dequeue()

Fixes: 546ea84d07e3 ("sched: sch_cake: fix bulk flow accounting logic for host fairness")
Reported-by: syzbot+f63600d288bfb7057424@syzkaller.appspotmail.com
Signed-off-by: Toke Høiland-Jørgensen <toke@redhat.com>
---
 net/sched/sch_cake.c | 140 +++++++++++++++++++++++--------------------
 1 file changed, 75 insertions(+), 65 deletions(-)
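
The helper pattern the patch introduces can be illustrated outside the kernel. The sketch below is a minimal, standalone approximation, not the sch_cake code itself: the names bulk_count and COUNT_CAP are illustrative, with COUNT_CAP standing in for CAKE_QUEUES. Increments saturate at the cap and decrements stop at zero, so an unbalanced extra decrement can no longer wrap a u16 counter around to 65535.

/* Standalone sketch of a bounds-checked counter pair; names are
 * illustrative and do not come from sch_cake.
 */
#include <stdint.h>
#include <stdio.h>

#define COUNT_CAP 1024 /* stands in for CAKE_QUEUES */

static void inc_bulk_count(uint16_t *bulk_count)
{
	if (*bulk_count < COUNT_CAP)	/* saturate at the cap */
		(*bulk_count)++;
}

static void dec_bulk_count(uint16_t *bulk_count)
{
	if (*bulk_count)		/* never decrement past zero */
		(*bulk_count)--;
}

int main(void)
{
	uint16_t count = 0;

	dec_bulk_count(&count);		/* stray extra decrement: stays at 0,
					 * no wrap to 65535 */
	inc_bulk_count(&count);
	printf("count = %u\n", (unsigned)count);	/* prints 1 */
	return 0;
}

In the patch itself the flow-mode check (cake_dsrc()/cake_ddst()) also lives inside the helpers, which is what removes the scattered conditionals that caused the original logic error.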

Comments

Dave Taht Jan. 8, 2025, 4:10 p.m. UTC | #1
On Tue, Jan 7, 2025 at 4:01 AM Toke Høiland-Jørgensen via Cake
<cake@lists.bufferbloat.net> wrote:
>
> Even though we fixed a logic error in the commit cited below, syzbot
> still managed to trigger an underflow of the per-host bulk flow
> counters, leading to an out-of-bounds memory access.
>
> To avoid any such logic errors causing out-of-bounds memory accesses,
> this commit factors out all accesses to the per-host bulk flow counters
> into a series of helpers that perform bounds checking before any
> increments and decrements. This also has the benefit of improving
> readability by moving the conditional checks for the flow mode into
> these helpers, instead of having them spread throughout the
> code (which was the cause of the original logic error).
>
> v2:
> - Remove now-unused srchost and dsthost local variables in cake_dequeue()
>
> Fixes: 546ea84d07e3 ("sched: sch_cake: fix bulk flow accounting logic for host fairness")
> Reported-by: syzbot+f63600d288bfb7057424@syzkaller.appspotmail.com
> Signed-off-by: Toke Høiland-Jørgensen <toke@redhat.com>
> ---
>  net/sched/sch_cake.c | 140 +++++++++++++++++++++++--------------------
>  1 file changed, 75 insertions(+), 65 deletions(-)
>
> diff --git a/net/sched/sch_cake.c b/net/sched/sch_cake.c
> index 8d8b2db4653c..2c2e2a67f3b2 100644
> --- a/net/sched/sch_cake.c
> +++ b/net/sched/sch_cake.c
> @@ -627,6 +627,63 @@ static bool cake_ddst(int flow_mode)
>         return (flow_mode & CAKE_FLOW_DUAL_DST) == CAKE_FLOW_DUAL_DST;
>  }
>
> +static void cake_dec_srchost_bulk_flow_count(struct cake_tin_data *q,
> +                                            struct cake_flow *flow,
> +                                            int flow_mode)
> +{
> +       if (likely(cake_dsrc(flow_mode) &&
> +                  q->hosts[flow->srchost].srchost_bulk_flow_count))
> +               q->hosts[flow->srchost].srchost_bulk_flow_count--;
> +}
> +
> +static void cake_inc_srchost_bulk_flow_count(struct cake_tin_data *q,
> +                                            struct cake_flow *flow,
> +                                            int flow_mode)
> +{
> +       if (likely(cake_dsrc(flow_mode) &&
> +                  q->hosts[flow->srchost].srchost_bulk_flow_count < CAKE_QUEUES))
> +               q->hosts[flow->srchost].srchost_bulk_flow_count++;
> +}
> +
> +static void cake_dec_dsthost_bulk_flow_count(struct cake_tin_data *q,
> +                                            struct cake_flow *flow,
> +                                            int flow_mode)
> +{
> +       if (likely(cake_ddst(flow_mode) &&
> +                  q->hosts[flow->dsthost].dsthost_bulk_flow_count))
> +               q->hosts[flow->dsthost].dsthost_bulk_flow_count--;
> +}
> +
> +static void cake_inc_dsthost_bulk_flow_count(struct cake_tin_data *q,
> +                                            struct cake_flow *flow,
> +                                            int flow_mode)
> +{
> +       if (likely(cake_ddst(flow_mode) &&
> +                  q->hosts[flow->dsthost].dsthost_bulk_flow_count < CAKE_QUEUES))
> +               q->hosts[flow->dsthost].dsthost_bulk_flow_count++;
> +}
> +
> +static u16 cake_get_flow_quantum(struct cake_tin_data *q,
> +                                struct cake_flow *flow,
> +                                int flow_mode)
> +{
> +       u16 host_load = 1;
> +
> +       if (cake_dsrc(flow_mode))
> +               host_load = max(host_load,
> +                               q->hosts[flow->srchost].srchost_bulk_flow_count);
> +
> +       if (cake_ddst(flow_mode))
> +               host_load = max(host_load,
> +                               q->hosts[flow->dsthost].dsthost_bulk_flow_count);
> +
> +       /* The get_random_u16() is a way to apply dithering to avoid
> +        * accumulating roundoff errors
> +        */
> +       return (q->flow_quantum * quantum_div[host_load] +
> +               get_random_u16()) >> 16;
> +}
> +
>  static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
>                      int flow_mode, u16 flow_override, u16 host_override)
>  {
> @@ -773,10 +830,8 @@ static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
>                 allocate_dst = cake_ddst(flow_mode);
>
>                 if (q->flows[outer_hash + k].set == CAKE_SET_BULK) {
> -                       if (allocate_src)
> -                               q->hosts[q->flows[reduced_hash].srchost].srchost_bulk_flow_count--;
> -                       if (allocate_dst)
> -                               q->hosts[q->flows[reduced_hash].dsthost].dsthost_bulk_flow_count--;
> +                       cake_dec_srchost_bulk_flow_count(q, &q->flows[outer_hash + k], flow_mode);
> +                       cake_dec_dsthost_bulk_flow_count(q, &q->flows[outer_hash + k], flow_mode);
>                 }
>  found:
>                 /* reserve queue for future packets in same flow */
> @@ -801,9 +856,10 @@ static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
>                         q->hosts[outer_hash + k].srchost_tag = srchost_hash;
>  found_src:
>                         srchost_idx = outer_hash + k;
> -                       if (q->flows[reduced_hash].set == CAKE_SET_BULK)
> -                               q->hosts[srchost_idx].srchost_bulk_flow_count++;
>                         q->flows[reduced_hash].srchost = srchost_idx;
> +
> +                       if (q->flows[reduced_hash].set == CAKE_SET_BULK)
> +                               cake_inc_srchost_bulk_flow_count(q, &q->flows[reduced_hash], flow_mode);
>                 }
>
>                 if (allocate_dst) {
> @@ -824,9 +880,10 @@ static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
>                         q->hosts[outer_hash + k].dsthost_tag = dsthost_hash;
>  found_dst:
>                         dsthost_idx = outer_hash + k;
> -                       if (q->flows[reduced_hash].set == CAKE_SET_BULK)
> -                               q->hosts[dsthost_idx].dsthost_bulk_flow_count++;
>                         q->flows[reduced_hash].dsthost = dsthost_idx;
> +
> +                       if (q->flows[reduced_hash].set == CAKE_SET_BULK)
> +                               cake_inc_dsthost_bulk_flow_count(q, &q->flows[reduced_hash], flow_mode);
>                 }
>         }
>
> @@ -1839,10 +1896,6 @@ static s32 cake_enqueue(struct sk_buff *skb, struct Qdisc *sch,
>
>         /* flowchain */
>         if (!flow->set || flow->set == CAKE_SET_DECAYING) {
> -               struct cake_host *srchost = &b->hosts[flow->srchost];
> -               struct cake_host *dsthost = &b->hosts[flow->dsthost];
> -               u16 host_load = 1;
> -
>                 if (!flow->set) {
>                         list_add_tail(&flow->flowchain, &b->new_flows);
>                 } else {
> @@ -1852,18 +1905,8 @@ static s32 cake_enqueue(struct sk_buff *skb, struct Qdisc *sch,
>                 flow->set = CAKE_SET_SPARSE;
>                 b->sparse_flow_count++;
>
> -               if (cake_dsrc(q->flow_mode))
> -                       host_load = max(host_load, srchost->srchost_bulk_flow_count);
> -
> -               if (cake_ddst(q->flow_mode))
> -                       host_load = max(host_load, dsthost->dsthost_bulk_flow_count);
> -
> -               flow->deficit = (b->flow_quantum *
> -                                quantum_div[host_load]) >> 16;
> +               flow->deficit = cake_get_flow_quantum(b, flow, q->flow_mode);
>         } else if (flow->set == CAKE_SET_SPARSE_WAIT) {
> -               struct cake_host *srchost = &b->hosts[flow->srchost];
> -               struct cake_host *dsthost = &b->hosts[flow->dsthost];
> -
>                 /* this flow was empty, accounted as a sparse flow, but actually
>                  * in the bulk rotation.
>                  */
> @@ -1871,12 +1914,8 @@ static s32 cake_enqueue(struct sk_buff *skb, struct Qdisc *sch,
>                 b->sparse_flow_count--;
>                 b->bulk_flow_count++;
>
> -               if (cake_dsrc(q->flow_mode))
> -                       srchost->srchost_bulk_flow_count++;
> -
> -               if (cake_ddst(q->flow_mode))
> -                       dsthost->dsthost_bulk_flow_count++;
> -
> +               cake_inc_srchost_bulk_flow_count(b, flow, q->flow_mode);
> +               cake_inc_dsthost_bulk_flow_count(b, flow, q->flow_mode);
>         }
>
>         if (q->buffer_used > q->buffer_max_used)
> @@ -1933,13 +1972,11 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>  {
>         struct cake_sched_data *q = qdisc_priv(sch);
>         struct cake_tin_data *b = &q->tins[q->cur_tin];
> -       struct cake_host *srchost, *dsthost;
>         ktime_t now = ktime_get();
>         struct cake_flow *flow;
>         struct list_head *head;
>         bool first_flow = true;
>         struct sk_buff *skb;
> -       u16 host_load;
>         u64 delay;
>         u32 len;
>
> @@ -2039,11 +2076,6 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>         q->cur_flow = flow - b->flows;
>         first_flow = false;
>
> -       /* triple isolation (modified DRR++) */
> -       srchost = &b->hosts[flow->srchost];
> -       dsthost = &b->hosts[flow->dsthost];
> -       host_load = 1;
> -
>         /* flow isolation (DRR++) */
>         if (flow->deficit <= 0) {
>                 /* Keep all flows with deficits out of the sparse and decaying
> @@ -2055,11 +2087,8 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>                                 b->sparse_flow_count--;
>                                 b->bulk_flow_count++;
>
> -                               if (cake_dsrc(q->flow_mode))
> -                                       srchost->srchost_bulk_flow_count++;
> -
> -                               if (cake_ddst(q->flow_mode))
> -                                       dsthost->dsthost_bulk_flow_count++;
> +                               cake_inc_srchost_bulk_flow_count(b, flow, q->flow_mode);
> +                               cake_inc_dsthost_bulk_flow_count(b, flow, q->flow_mode);
>
>                                 flow->set = CAKE_SET_BULK;
>                         } else {
> @@ -2071,19 +2100,7 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>                         }
>                 }
>
> -               if (cake_dsrc(q->flow_mode))
> -                       host_load = max(host_load, srchost->srchost_bulk_flow_count);
> -
> -               if (cake_ddst(q->flow_mode))
> -                       host_load = max(host_load, dsthost->dsthost_bulk_flow_count);
> -
> -               WARN_ON(host_load > CAKE_QUEUES);
> -
> -               /* The get_random_u16() is a way to apply dithering to avoid
> -                * accumulating roundoff errors
> -                */
> -               flow->deficit += (b->flow_quantum * quantum_div[host_load] +
> -                                 get_random_u16()) >> 16;
> +               flow->deficit += cake_get_flow_quantum(b, flow, q->flow_mode);
>                 list_move_tail(&flow->flowchain, &b->old_flows);
>
>                 goto retry;
> @@ -2107,11 +2124,8 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>                                 if (flow->set == CAKE_SET_BULK) {
>                                         b->bulk_flow_count--;
>
> -                                       if (cake_dsrc(q->flow_mode))
> -                                               srchost->srchost_bulk_flow_count--;
> -
> -                                       if (cake_ddst(q->flow_mode))
> -                                               dsthost->dsthost_bulk_flow_count--;
> +                                       cake_dec_srchost_bulk_flow_count(b, flow, q->flow_mode);
> +                                       cake_dec_dsthost_bulk_flow_count(b, flow, q->flow_mode);
>
>                                         b->decaying_flow_count++;
>                                 } else if (flow->set == CAKE_SET_SPARSE ||
> @@ -2129,12 +2143,8 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>                                 else if (flow->set == CAKE_SET_BULK) {
>                                         b->bulk_flow_count--;
>
> -                                       if (cake_dsrc(q->flow_mode))
> -                                               srchost->srchost_bulk_flow_count--;
> -
> -                                       if (cake_ddst(q->flow_mode))
> -                                               dsthost->dsthost_bulk_flow_count--;
> -
> +                                       cake_dec_srchost_bulk_flow_count(b, flow, q->flow_mode);
> +                                       cake_dec_dsthost_bulk_flow_count(b, flow, q->flow_mode);
>                                 } else
>                                         b->decaying_flow_count--;
>
> --
> 2.47.1
>

Acked-By: Dave Taht <dave.taht@gmail.com>

Patch

diff --git a/net/sched/sch_cake.c b/net/sched/sch_cake.c
index 8d8b2db4653c..2c2e2a67f3b2 100644
--- a/net/sched/sch_cake.c
+++ b/net/sched/sch_cake.c
@@ -627,6 +627,63 @@  static bool cake_ddst(int flow_mode)
 	return (flow_mode & CAKE_FLOW_DUAL_DST) == CAKE_FLOW_DUAL_DST;
 }
 
+static void cake_dec_srchost_bulk_flow_count(struct cake_tin_data *q,
+					     struct cake_flow *flow,
+					     int flow_mode)
+{
+	if (likely(cake_dsrc(flow_mode) &&
+		   q->hosts[flow->srchost].srchost_bulk_flow_count))
+		q->hosts[flow->srchost].srchost_bulk_flow_count--;
+}
+
+static void cake_inc_srchost_bulk_flow_count(struct cake_tin_data *q,
+					     struct cake_flow *flow,
+					     int flow_mode)
+{
+	if (likely(cake_dsrc(flow_mode) &&
+		   q->hosts[flow->srchost].srchost_bulk_flow_count < CAKE_QUEUES))
+		q->hosts[flow->srchost].srchost_bulk_flow_count++;
+}
+
+static void cake_dec_dsthost_bulk_flow_count(struct cake_tin_data *q,
+					     struct cake_flow *flow,
+					     int flow_mode)
+{
+	if (likely(cake_ddst(flow_mode) &&
+		   q->hosts[flow->dsthost].dsthost_bulk_flow_count))
+		q->hosts[flow->dsthost].dsthost_bulk_flow_count--;
+}
+
+static void cake_inc_dsthost_bulk_flow_count(struct cake_tin_data *q,
+					     struct cake_flow *flow,
+					     int flow_mode)
+{
+	if (likely(cake_ddst(flow_mode) &&
+		   q->hosts[flow->dsthost].dsthost_bulk_flow_count < CAKE_QUEUES))
+		q->hosts[flow->dsthost].dsthost_bulk_flow_count++;
+}
+
+static u16 cake_get_flow_quantum(struct cake_tin_data *q,
+				 struct cake_flow *flow,
+				 int flow_mode)
+{
+	u16 host_load = 1;
+
+	if (cake_dsrc(flow_mode))
+		host_load = max(host_load,
+				q->hosts[flow->srchost].srchost_bulk_flow_count);
+
+	if (cake_ddst(flow_mode))
+		host_load = max(host_load,
+				q->hosts[flow->dsthost].dsthost_bulk_flow_count);
+
+	/* The get_random_u16() is a way to apply dithering to avoid
+	 * accumulating roundoff errors
+	 */
+	return (q->flow_quantum * quantum_div[host_load] +
+		get_random_u16()) >> 16;
+}
+
 static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
 		     int flow_mode, u16 flow_override, u16 host_override)
 {
@@ -773,10 +830,8 @@  static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
 		allocate_dst = cake_ddst(flow_mode);
 
 		if (q->flows[outer_hash + k].set == CAKE_SET_BULK) {
-			if (allocate_src)
-				q->hosts[q->flows[reduced_hash].srchost].srchost_bulk_flow_count--;
-			if (allocate_dst)
-				q->hosts[q->flows[reduced_hash].dsthost].dsthost_bulk_flow_count--;
+			cake_dec_srchost_bulk_flow_count(q, &q->flows[outer_hash + k], flow_mode);
+			cake_dec_dsthost_bulk_flow_count(q, &q->flows[outer_hash + k], flow_mode);
 		}
 found:
 		/* reserve queue for future packets in same flow */
@@ -801,9 +856,10 @@  static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
 			q->hosts[outer_hash + k].srchost_tag = srchost_hash;
 found_src:
 			srchost_idx = outer_hash + k;
-			if (q->flows[reduced_hash].set == CAKE_SET_BULK)
-				q->hosts[srchost_idx].srchost_bulk_flow_count++;
 			q->flows[reduced_hash].srchost = srchost_idx;
+
+			if (q->flows[reduced_hash].set == CAKE_SET_BULK)
+				cake_inc_srchost_bulk_flow_count(q, &q->flows[reduced_hash], flow_mode);
 		}
 
 		if (allocate_dst) {
@@ -824,9 +880,10 @@  static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
 			q->hosts[outer_hash + k].dsthost_tag = dsthost_hash;
 found_dst:
 			dsthost_idx = outer_hash + k;
-			if (q->flows[reduced_hash].set == CAKE_SET_BULK)
-				q->hosts[dsthost_idx].dsthost_bulk_flow_count++;
 			q->flows[reduced_hash].dsthost = dsthost_idx;
+
+			if (q->flows[reduced_hash].set == CAKE_SET_BULK)
+				cake_inc_dsthost_bulk_flow_count(q, &q->flows[reduced_hash], flow_mode);
 		}
 	}
 
@@ -1839,10 +1896,6 @@  static s32 cake_enqueue(struct sk_buff *skb, struct Qdisc *sch,
 
 	/* flowchain */
 	if (!flow->set || flow->set == CAKE_SET_DECAYING) {
-		struct cake_host *srchost = &b->hosts[flow->srchost];
-		struct cake_host *dsthost = &b->hosts[flow->dsthost];
-		u16 host_load = 1;
-
 		if (!flow->set) {
 			list_add_tail(&flow->flowchain, &b->new_flows);
 		} else {
@@ -1852,18 +1905,8 @@  static s32 cake_enqueue(struct sk_buff *skb, struct Qdisc *sch,
 		flow->set = CAKE_SET_SPARSE;
 		b->sparse_flow_count++;
 
-		if (cake_dsrc(q->flow_mode))
-			host_load = max(host_load, srchost->srchost_bulk_flow_count);
-
-		if (cake_ddst(q->flow_mode))
-			host_load = max(host_load, dsthost->dsthost_bulk_flow_count);
-
-		flow->deficit = (b->flow_quantum *
-				 quantum_div[host_load]) >> 16;
+		flow->deficit = cake_get_flow_quantum(b, flow, q->flow_mode);
 	} else if (flow->set == CAKE_SET_SPARSE_WAIT) {
-		struct cake_host *srchost = &b->hosts[flow->srchost];
-		struct cake_host *dsthost = &b->hosts[flow->dsthost];
-
 		/* this flow was empty, accounted as a sparse flow, but actually
 		 * in the bulk rotation.
 		 */
@@ -1871,12 +1914,8 @@  static s32 cake_enqueue(struct sk_buff *skb, struct Qdisc *sch,
 		b->sparse_flow_count--;
 		b->bulk_flow_count++;
 
-		if (cake_dsrc(q->flow_mode))
-			srchost->srchost_bulk_flow_count++;
-
-		if (cake_ddst(q->flow_mode))
-			dsthost->dsthost_bulk_flow_count++;
-
+		cake_inc_srchost_bulk_flow_count(b, flow, q->flow_mode);
+		cake_inc_dsthost_bulk_flow_count(b, flow, q->flow_mode);
 	}
 
 	if (q->buffer_used > q->buffer_max_used)
@@ -1933,13 +1972,11 @@  static struct sk_buff *cake_dequeue(struct Qdisc *sch)
 {
 	struct cake_sched_data *q = qdisc_priv(sch);
 	struct cake_tin_data *b = &q->tins[q->cur_tin];
-	struct cake_host *srchost, *dsthost;
 	ktime_t now = ktime_get();
 	struct cake_flow *flow;
 	struct list_head *head;
 	bool first_flow = true;
 	struct sk_buff *skb;
-	u16 host_load;
 	u64 delay;
 	u32 len;
 
@@ -2039,11 +2076,6 @@  static struct sk_buff *cake_dequeue(struct Qdisc *sch)
 	q->cur_flow = flow - b->flows;
 	first_flow = false;
 
-	/* triple isolation (modified DRR++) */
-	srchost = &b->hosts[flow->srchost];
-	dsthost = &b->hosts[flow->dsthost];
-	host_load = 1;
-
 	/* flow isolation (DRR++) */
 	if (flow->deficit <= 0) {
 		/* Keep all flows with deficits out of the sparse and decaying
@@ -2055,11 +2087,8 @@  static struct sk_buff *cake_dequeue(struct Qdisc *sch)
 				b->sparse_flow_count--;
 				b->bulk_flow_count++;
 
-				if (cake_dsrc(q->flow_mode))
-					srchost->srchost_bulk_flow_count++;
-
-				if (cake_ddst(q->flow_mode))
-					dsthost->dsthost_bulk_flow_count++;
+				cake_inc_srchost_bulk_flow_count(b, flow, q->flow_mode);
+				cake_inc_dsthost_bulk_flow_count(b, flow, q->flow_mode);
 
 				flow->set = CAKE_SET_BULK;
 			} else {
@@ -2071,19 +2100,7 @@  static struct sk_buff *cake_dequeue(struct Qdisc *sch)
 			}
 		}
 
-		if (cake_dsrc(q->flow_mode))
-			host_load = max(host_load, srchost->srchost_bulk_flow_count);
-
-		if (cake_ddst(q->flow_mode))
-			host_load = max(host_load, dsthost->dsthost_bulk_flow_count);
-
-		WARN_ON(host_load > CAKE_QUEUES);
-
-		/* The get_random_u16() is a way to apply dithering to avoid
-		 * accumulating roundoff errors
-		 */
-		flow->deficit += (b->flow_quantum * quantum_div[host_load] +
-				  get_random_u16()) >> 16;
+		flow->deficit += cake_get_flow_quantum(b, flow, q->flow_mode);
 		list_move_tail(&flow->flowchain, &b->old_flows);
 
 		goto retry;
@@ -2107,11 +2124,8 @@  static struct sk_buff *cake_dequeue(struct Qdisc *sch)
 				if (flow->set == CAKE_SET_BULK) {
 					b->bulk_flow_count--;
 
-					if (cake_dsrc(q->flow_mode))
-						srchost->srchost_bulk_flow_count--;
-
-					if (cake_ddst(q->flow_mode))
-						dsthost->dsthost_bulk_flow_count--;
+					cake_dec_srchost_bulk_flow_count(b, flow, q->flow_mode);
+					cake_dec_dsthost_bulk_flow_count(b, flow, q->flow_mode);
 
 					b->decaying_flow_count++;
 				} else if (flow->set == CAKE_SET_SPARSE ||
@@ -2129,12 +2143,8 @@  static struct sk_buff *cake_dequeue(struct Qdisc *sch)
 				else if (flow->set == CAKE_SET_BULK) {
 					b->bulk_flow_count--;
 
-					if (cake_dsrc(q->flow_mode))
-						srchost->srchost_bulk_flow_count--;
-
-					if (cake_ddst(q->flow_mode))
-						dsthost->dsthost_bulk_flow_count--;
-
+					cake_dec_srchost_bulk_flow_count(b, flow, q->flow_mode);
+					cake_dec_dsthost_bulk_flow_count(b, flow, q->flow_mode);
 				} else
 					b->decaying_flow_count--;
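
The dithering comment carried into cake_get_flow_quantum() is the other half of the picture: host_load indexes the quantum_div[] reciprocal table, which is why an underflowed counter (a u16 wrapping to 65535) reads far past the end of the table, and why the increment helpers clamp at CAKE_QUEUES. The quantum is a 16.16 fixed-point product, and adding a random 16-bit value before the >> 16 makes the rounding probabilistic instead of always rounding down, so the truncation error does not accumulate over many dequeues. The following standalone sketch is an approximation only: the reciprocal table is rebuilt as roughly 65535 / host_load rather than copied from the kernel, and rand() stands in for get_random_u16().

/* Standalone sketch of dithered fixed-point division, approximating the
 * quantum calculation in cake_get_flow_quantum().
 */
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define QUEUES 1024

static uint16_t quantum_div[QUEUES + 1];

int main(void)
{
	uint32_t quantum = 1514;	/* MTU-sized quantum, for illustration */
	uint32_t host_load = 3;
	uint64_t plain = 0, dithered = 0;
	int i, iters = 1000000;

	for (i = 1; i <= QUEUES; i++)
		quantum_div[i] = 65535 / i;	/* reciprocal table */

	for (i = 0; i < iters; i++) {
		uint32_t prod = quantum * quantum_div[host_load];

		plain += prod >> 16;			/* always rounds down */
		dithered += (prod + (rand() & 0xffff)) >> 16;	/* crude 16-bit
								 * dither */
	}

	printf("exact   : %.3f\n", (double)quantum / host_load);
	printf("plain   : %.3f\n", (double)plain / iters);
	printf("dithered: %.3f\n", (double)dithered / iters);
	return 0;
}

With quantum = 1514 and host_load = 3, plain truncation yields 504 on every dequeue while the exact share is about 504.67; the dithered average lands close to the exact value, which is the bias the get_random_u16() term is there to remove.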