diff mbox series

[v2,12/13] xen/device_tree: Introduce function to merge overlapping intervals

Message ID 20240409114543.3332150-13-luca.fancellu@arm.com (mailing list archive)
State Superseded
Headers show
Series Static shared memory followup v2 - pt1 | expand

Commit Message

Luca Fancellu April 9, 2024, 11:45 a.m. UTC
Introduce a function that given an array of cells containing
(address,size) intervals, merges the overlapping ones, returning
an array with no overlapping intervals.

The algorithm needs the intervals sorted in ascending address
order, so the sort() function already included in the codebase
is used. However, in this case the compare function needs
additional data to be able to extract the address from the
interval.
So add one argument to the sort() function and its compare
callback to have additional data and be able to pass, in this
case, the address length. In case the argument is not needed,
NULL can be provided.

Signed-off-by: Luca Fancellu <luca.fancellu@arm.com>
---
v2:
 - new patch
---
---
 xen/arch/arm/bootfdt.c        |   5 +-
 xen/arch/arm/io.c             |  11 ++-
 xen/arch/x86/extable.c        |   5 +-
 xen/common/device_tree.c      | 140 ++++++++++++++++++++++++++++++++++
 xen/include/xen/device_tree.h |  19 +++++
 xen/include/xen/sort.h        |  14 ++--
 6 files changed, 181 insertions(+), 13 deletions(-)

Comments

Luca Fancellu April 11, 2024, 9:50 a.m. UTC | #1
> On 9 Apr 2024, at 12:45, Luca Fancellu <Luca.Fancellu@arm.com> wrote:
> 
> Introduce a function that given an array of cells containing
> (address,size) intervals, merges the overlapping ones, returning
> an array with no overlapping intervals.
> 
> The algorithm needs to sort the intervals by ascending order
> address, so the sort() function already included in the codebase
> is used, however in this case additional data is needed for the
> compare function, to be able to extract the address from the
> interval.
> So add one argument to the sort() function and its compare
> callback to have additional data and be able to pass, in this
> case, the address length. In case the argument is not needed,
> NULL can be provided.
> 
> Signed-off-by: Luca Fancellu <luca.fancellu@arm.com>
> ---

Hi all,

I’ve just spotted an issue with the algorithm, the fix is this one:

diff --git a/xen/common/device_tree.c b/xen/common/device_tree.c
index 24914a80d03b..262385a041a8 100644
--- a/xen/common/device_tree.c
+++ b/xen/common/device_tree.c
@@ -2360,6 +2360,10 @@ int __init dt_merge_overlapping_addr_size_intervals(__be32 *reg, int *nr_cells,
             __be32 *tmp_last_cell_size = last_cell + addrcells;
 
             dt_set_cell(&tmp_last_cell_size, sizecells, new_size);
+
+            /* Last interval updated, so the end is changed */
+            end_last = start_last + size_last;
+
             /*
              * This current interval is merged with the last one, so remove this
              * interval and shift left all the remaining elements


--------------------------------

Now, I would like to describe the algorithm to make the review easier.
The problem is that we have some intervals and we would like to merge the overlapping
ones; a simple algorithm can be found here: https://www.interviewbit.com/blog/merge-intervals/

The limitation is that, when merging the intervals, we must not exceed the space available
to store the merged size, for example:

sizecells: 1 (meaning one __be32, 4 bytes)
Int1: start 0x0,                 size 0xFFFFFFFF
Int2: start 0xFFFFFFFF,  size 0x1000

We can’t merge them because the new size (0x100000FFF) would not fit in 4 bytes.

While developing this algorithm I prototyped it in Python; I’m attaching my script here
so that it’s easier to understand:

#!/usr/bin/env python3

def merge_intervals_inplace(intervals, size_limit):
    """Merge overlapping (start, size) intervals, capping the merged size.

    The intervals must already be sorted by ascending start. An interval is
    folded into the last kept one when that one's end reaches the interval's
    start and the resulting size does not exceed ``size_limit``; otherwise it
    is kept and becomes the new merge target.

    Returns a new list of (start, size) entries; ``intervals`` is not modified.
    """
    merged = list(intervals)

    # Nothing to merge with zero or one interval (an empty list used to raise
    # IndexError on merged[0]).
    if len(merged) <= 1:
        return merged

    last_idx = 0
    start_last = merged[last_idx][0]
    end_last = start_last + merged[last_idx][1]

    i = 1
    while i < len(merged):
        start_current = merged[i][0]
        end_current = start_current + merged[i][1]
        overlap = end_last >= start_current
        new_size = max(end_last, end_current) - start_last

        if overlap and new_size <= size_limit:
            # Grow the last kept interval so that it covers the current one
            merged[last_idx] = (start_last, new_size)
            end_last = start_last + new_size
            # Drop the absorbed interval; index i now holds the next
            # candidate, so it must not be incremented.
            del merged[i]
        else:
            # No overlap, or the merged size would be over the limit: the
            # current interval becomes the new merge target.
            last_idx += 1
            start_last = merged[last_idx][0]
            end_last = start_last + merged[last_idx][1]
            i += 1

    return merged


def print_interval(intervals):
    """Print the intervals on one line, as (start,size) and then as (start,end)."""
    # First rendering: each entry exactly as stored
    as_sizes = "".join(f" ({entry[0]},{entry[1]}) " for entry in intervals)
    # Second rendering: the size replaced by the end (start + size)
    as_ends = "".join(
        f" ({entry[0]},{entry[0] + entry[1]}) " for entry in intervals
    )
    print(f"[{as_sizes}] -> [{as_ends}]")


def main(argv):
    """Run the demo: validate, sort, merge and print a fixed set of intervals.

    ``argv`` is accepted but not used.
    """
    size_limit = 20

    # (start address, size) pairs. Alternative data set for experiments:
    #   [(0,2), (5,1), (0,10), (10,7), (2,6)]
    raw = [(0,20), (20,5), (10,15), (5,15)]

    # No single interval may already be larger than the limit
    for entry in raw:
        if entry[1] > size_limit:
            raise Exception(f"{entry} size > limit ({size_limit})")

    # The merge step expects the intervals ordered by ascending start address
    ordered = sorted(raw, key=lambda entry: entry[0])

    print("IN (sorted) [(start,size)] -> [(start,end)]")
    print_interval(ordered)

    result = merge_intervals_inplace(ordered, size_limit)

    print("OUT [(start,size)] -> [(start,end)]")
    print_interval(result)


if __name__ == "__main__":
    # The script has no other import of sys; without this one the call below
    # fails with NameError.
    import sys

    main(sys.argv[1:])


Cheers,
Luca
Luca Fancellu April 11, 2024, 1:36 p.m. UTC | #2
> 
> I’ve just spotted an issue with the algorithm, the fix is this one:
> 
> diff --git a/xen/common/device_tree.c b/xen/common/device_tree.c
> index 24914a80d03b..262385a041a8 100644
> --- a/xen/common/device_tree.c
> +++ b/xen/common/device_tree.c
> @@ -2360,6 +2360,10 @@ int __init dt_merge_overlapping_addr_size_intervals(__be32 *reg, int *nr_cells,
>             __be32 *tmp_last_cell_size = last_cell + addrcells;
> 
>             dt_set_cell(&tmp_last_cell_size, sizecells, new_size);
> +
> +            /* Last interval updated, so the end is changed */
> +            end_last = start_last + size_last;
> +
>             /*
>              * This current interval is merged with the last one, so remove this
>              * interval and shift left all the remaining elements
> 

Apologies, this is the fix:

diff --git a/xen/common/device_tree.c b/xen/common/device_tree.c
index 24914a80d03b..9a2f5b27aa9b 100644
--- a/xen/common/device_tree.c
+++ b/xen/common/device_tree.c
@@ -2360,6 +2360,10 @@ int __init dt_merge_overlapping_addr_size_intervals(__be32 *reg, int *nr_cells,
             __be32 *tmp_last_cell_size = last_cell + addrcells;
 
             dt_set_cell(&tmp_last_cell_size, sizecells, new_size);
+
+            /* Last interval updated, so the end is changed */
+            end_last = start_last + new_size;
+
             /*
              * This current interval is merged with the last one, so remove this
              * interval and shift left all the remaining elements

So instead of “size_last” -> “new_size”.

Sorry for the noise.

Cheers,
Luca
Jan Beulich April 18, 2024, 6:28 a.m. UTC | #3
On 09.04.2024 13:45, Luca Fancellu wrote:
> --- a/xen/arch/x86/extable.c
> +++ b/xen/arch/x86/extable.c
> @@ -23,7 +23,8 @@ static inline unsigned long ex_cont(const struct exception_table_entry *x)
>  	return EX_FIELD(x, cont);
>  }
>  
> -static int init_or_livepatch cf_check cmp_ex(const void *a, const void *b)
> +static int init_or_livepatch cf_check cmp_ex(const void *a, const void *b,
> +                                             const void *data)
>  {
>  	const struct exception_table_entry *l = a, *r = b;
>  	unsigned long lip = ex_addr(l);
> @@ -53,7 +54,7 @@ void init_or_livepatch sort_exception_table(struct exception_table_entry *start,
>                                   const struct exception_table_entry *stop)
>  {
>      sort(start, stop - start,
> -         sizeof(struct exception_table_entry), cmp_ex, swap_ex);
> +         sizeof(struct exception_table_entry), cmp_ex, swap_ex, NULL);
>  }

Not the least because of this addition of an entirely useless parameter / argument
I'm not in favor of ...

> --- a/xen/include/xen/sort.h
> +++ b/xen/include/xen/sort.h
> @@ -23,8 +23,8 @@
>  extern gnu_inline
>  #endif
>  void sort(void *base, size_t num, size_t size,
> -          int (*cmp)(const void *a, const void *b),
> -          void (*swap)(void *a, void *b, size_t size))
> +          int (*cmp)(const void *a, const void *b, const void *data),
> +          void (*swap)(void *a, void *b, size_t size), const void *cmp_data)
>  {

... this change. Consider you were doing this on a C library you cannot change.
You'd have to find a different solution anyway. And the way we have sort()
right now is matching the C spec. The change to do renders things unexpected to
anyone wanting to use this function in a spec-compliant way. One approach may
be to make an adjustment to data representation, such that the extra reference
data is accessible through the pointers already being passed.

Jan
Luca Fancellu April 18, 2024, 7:44 a.m. UTC | #4
> On 18 Apr 2024, at 07:28, Jan Beulich <jbeulich@suse.com> wrote:
> 
> On 09.04.2024 13:45, Luca Fancellu wrote:
>> --- a/xen/arch/x86/extable.c
>> +++ b/xen/arch/x86/extable.c
>> @@ -23,7 +23,8 @@ static inline unsigned long ex_cont(const struct exception_table_entry *x)
>> return EX_FIELD(x, cont);
>> }
>> 
>> -static int init_or_livepatch cf_check cmp_ex(const void *a, const void *b)
>> +static int init_or_livepatch cf_check cmp_ex(const void *a, const void *b,
>> +                                             const void *data)
>> {
>> const struct exception_table_entry *l = a, *r = b;
>> unsigned long lip = ex_addr(l);
>> @@ -53,7 +54,7 @@ void init_or_livepatch sort_exception_table(struct exception_table_entry *start,
>>                                  const struct exception_table_entry *stop)
>> {
>>     sort(start, stop - start,
>> -         sizeof(struct exception_table_entry), cmp_ex, swap_ex);
>> +         sizeof(struct exception_table_entry), cmp_ex, swap_ex, NULL);
>> }
> 
> Not the least because of this addition of an entirely useless parameter / argument

Well it’s not useless in this patch, given that without it I couldn’t know the size of the address
element, however ...

> I'm not in favor of ...
> 
>> --- a/xen/include/xen/sort.h
>> +++ b/xen/include/xen/sort.h
>> @@ -23,8 +23,8 @@
>> extern gnu_inline
>> #endif
>> void sort(void *base, size_t num, size_t size,
>> -          int (*cmp)(const void *a, const void *b),
>> -          void (*swap)(void *a, void *b, size_t size))
>> +          int (*cmp)(const void *a, const void *b, const void *data),
>> +          void (*swap)(void *a, void *b, size_t size), const void *cmp_data)
>> {
> 
> ... this change. Consider you were doing this on a C library you cannot change.
> You'd have to find a different solution anyway.

I get your point here, we should not change standard functions.

> And the way we have sort()
> right now is matching the C spec. The change to do renders things unexpected to
> anyone wanting to use this function in a spec-compliant way. One approach may
> be to make an adjustment to data representation, such that the extra reference
> data is accessible through the pointers already being passed.
> 
> Jan
> 

Anyway in the end this patch was dropped for other reasons.

Cheers,
Luca
diff mbox series

Patch

diff --git a/xen/arch/arm/bootfdt.c b/xen/arch/arm/bootfdt.c
index 4d708442a19e..a2aba67b45e7 100644
--- a/xen/arch/arm/bootfdt.c
+++ b/xen/arch/arm/bootfdt.c
@@ -521,7 +521,8 @@  static void __init early_print_info(void)
 }
 
 /* This function assumes that memory regions are not overlapped */
-static int __init cmp_memory_node(const void *key, const void *elem)
+static int __init cmp_memory_node(const void *key, const void *elem,
+                                  const void *data)
 {
     const struct membank *handler0 = key;
     const struct membank *handler1 = elem;
@@ -569,7 +570,7 @@  size_t __init boot_fdt_info(const void *fdt, paddr_t paddr)
      * the banks sorted in ascending order. So sort them through.
      */
     sort(mem->bank, mem->nr_banks, sizeof(struct membank),
-         cmp_memory_node, swap_memory_node);
+         cmp_memory_node, swap_memory_node, NULL);
 
     early_print_info();
 
diff --git a/xen/arch/arm/io.c b/xen/arch/arm/io.c
index 96c740d5636c..c1814491fec4 100644
--- a/xen/arch/arm/io.c
+++ b/xen/arch/arm/io.c
@@ -57,7 +57,7 @@  static enum io_state handle_write(const struct mmio_handler *handler,
 }
 
 /* This function assumes that mmio regions are not overlapped */
-static int cmp_mmio_handler(const void *key, const void *elem)
+static int cmp_mmio_handler(const void *key, const void *elem, const void *data)
 {
     const struct mmio_handler *handler0 = key;
     const struct mmio_handler *handler1 = elem;
@@ -71,6 +71,11 @@  static int cmp_mmio_handler(const void *key, const void *elem)
     return 0;
 }
 
+static int bsearch_cmp_mmio_handler(const void *key, const void *elem)
+{
+    return cmp_mmio_handler(key, elem, NULL); /* bsearch() passes no data */
+}
+
 static void swap_mmio_handler(void *_a, void *_b, size_t size)
 {
     struct mmio_handler *a = _a, *b = _b;
@@ -87,7 +92,7 @@  static const struct mmio_handler *find_mmio_handler(struct domain *d,
 
     read_lock(&vmmio->lock);
     handler = bsearch(&key, vmmio->handlers, vmmio->num_entries,
-                      sizeof(*handler), cmp_mmio_handler);
+                      sizeof(*handler), bsearch_cmp_mmio_handler);
     read_unlock(&vmmio->lock);
 
     return handler;
@@ -219,7 +224,7 @@  void register_mmio_handler(struct domain *d,
 
     /* Sort mmio handlers in ascending order based on base address */
     sort(vmmio->handlers, vmmio->num_entries, sizeof(struct mmio_handler),
-         cmp_mmio_handler, swap_mmio_handler);
+         cmp_mmio_handler, swap_mmio_handler, NULL);
 
     write_unlock(&vmmio->lock);
 }
diff --git a/xen/arch/x86/extable.c b/xen/arch/x86/extable.c
index 8415cd1fa249..589e251b29b9 100644
--- a/xen/arch/x86/extable.c
+++ b/xen/arch/x86/extable.c
@@ -23,7 +23,8 @@  static inline unsigned long ex_cont(const struct exception_table_entry *x)
 	return EX_FIELD(x, cont);
 }
 
-static int init_or_livepatch cf_check cmp_ex(const void *a, const void *b)
+static int init_or_livepatch cf_check cmp_ex(const void *a, const void *b,
+                                             const void *data)
 {
 	const struct exception_table_entry *l = a, *r = b;
 	unsigned long lip = ex_addr(l);
@@ -53,7 +54,7 @@  void init_or_livepatch sort_exception_table(struct exception_table_entry *start,
                                  const struct exception_table_entry *stop)
 {
     sort(start, stop - start,
-         sizeof(struct exception_table_entry), cmp_ex, swap_ex);
+         sizeof(struct exception_table_entry), cmp_ex, swap_ex, NULL);
 }
 
 void __init sort_exception_tables(void)
diff --git a/xen/common/device_tree.c b/xen/common/device_tree.c
index 8d1017a49d80..24914a80d03b 100644
--- a/xen/common/device_tree.c
+++ b/xen/common/device_tree.c
@@ -18,6 +18,7 @@ 
 #include <xen/lib.h>
 #include <xen/libfdt/libfdt.h>
 #include <xen/mm.h>
+#include <xen/sort.h>
 #include <xen/stdarg.h>
 #include <xen/string.h>
 #include <xen/cpumask.h>
@@ -2243,6 +2244,145 @@  int dt_get_pci_domain_nr(struct dt_device_node *node)
     return (u16)domain;
 }
 
+/* sort() comparator: order (address, size) entries by ascending address. */
+static int __init cmp_mem_reg_cell(const void *key, const void *elem,
+                                   const void *data)
+{
+    /* data points to the number of cells making up each address */
+    const int *nr_addr_cells = data;
+    const __be32 *lhs = key, *rhs = elem;
+    u64 lhs_addr, rhs_addr;
+
+    /* An entry always compares equal to itself */
+    if ( lhs == rhs )
+        return 0;
+
+    BUG_ON(!nr_addr_cells || !*nr_addr_cells || *nr_addr_cells > 2);
+
+    lhs_addr = dt_read_number(lhs, *nr_addr_cells);
+    rhs_addr = dt_read_number(rhs, *nr_addr_cells);
+
+    /* Equal addresses compare as equal, otherwise order by address */
+    if ( lhs_addr == rhs_addr )
+        return 0;
+
+    return (lhs_addr < rhs_addr) ? -1 : 1;
+}
+
+/* sort() swap callback: exchange two entries of at most four cells each. */
+static void __init swap_mem_reg_cell(void *_a, void *_b, size_t size)
+{
+    __be32 scratch[4];
+
+    /* An entry must fit in the bounce buffer */
+    BUG_ON(size > sizeof(scratch));
+
+    /* Nothing to exchange when both arguments are the same entry */
+    if ( _a == _b )
+        return;
+
+    /* Three-way copy through the bounce buffer */
+    memcpy(scratch, _a, size);
+    memcpy(_a, _b, size);
+    memcpy(_b, scratch, size);
+}
+
+/*
+ * Sort the (address, size) entries in reg by ascending address, then merge
+ * in place each entry that overlaps or touches the previous one, unless the
+ * merged size would not fit in sizecells. Once the merge completes,
+ * *nr_cells is updated to the number of cells left in reg.
+ * Returns 0 on success, -EINVAL if addrcells or sizecells is not 1 or 2,
+ * -ERANGE if address + size of an entry wraps around.
+ */
+int __init dt_merge_overlapping_addr_size_intervals(__be32 *reg, int *nr_cells,
+                                                    int addrcells,
+                                                    int sizecells)
+{
+    int reg_size = addrcells + sizecells;
+    u64 start_last, size_last, end_last;
+    unsigned int count;
+    unsigned int i = 1;
+    __be32 *last_cell = reg;
+
+    BUG_ON(!nr_cells || !reg);
+
+    if ( (addrcells < 1) || (addrcells > 2) || (sizecells < 1) ||
+         (sizecells > 2) )
+        return -EINVAL;
+
+    count = *nr_cells / reg_size;
+    /* Nothing to merge with fewer than two intervals in the array */
+    if ( count <= 1 )
+        return 0;
+
+    /* Sort cells by ascending address */
+    sort(reg, count, reg_size * sizeof(__be32), cmp_mem_reg_cell,
+         swap_mem_reg_cell, &addrcells);
+
+    /* "last" is the interval the following ones are merged into */
+    start_last = dt_read_number(last_cell, addrcells);
+    size_last = dt_read_number(last_cell + addrcells, sizecells);
+    end_last = start_last + size_last;
+
+    /* The sum is too big */
+    if ( end_last < start_last )
+        return -ERANGE;
+
+    while ( i < count )
+    {
+        __be32 *current_cell = &reg[i * reg_size];
+        u64 start_current = dt_read_number(current_cell, addrcells);
+        u64 size_current = dt_read_number(current_cell + addrcells, sizecells);
+        u64 end_current = start_current + size_current;
+        bool overlap = end_last >= start_current;
+        u64 new_size;
+
+        /* The sum is too big */
+        if ( end_current < start_current )
+            return -ERANGE;
+
+        new_size = MAX(end_last, end_current) - start_last;
+
+        /*
+         * Don't merge when the intervals are not connected, or when the
+         * merged size is not representable in a single size cell.
+         */
+        if ( !overlap || ((sizecells < 2) && (new_size > UINT32_MAX)) )
+        {
+            /* The current interval becomes the new merge target */
+            last_cell = last_cell + reg_size;
+            start_last = dt_read_number(last_cell, addrcells);
+            size_last = dt_read_number(last_cell + addrcells, sizecells);
+            end_last = start_last + size_last;
+        }
+        else
+        {
+            /* Temporary pointer because dt_set_cell modifies it */
+            __be32 *tmp_last_cell_size = last_cell + addrcells;
+
+            dt_set_cell(&tmp_last_cell_size, sizecells, new_size);
+            /* Last interval updated, so the end is changed */
+            end_last = start_last + new_size;
+            /*
+             * Drop the current interval by shifting left the count - i - 1
+             * intervals that follow it; moving more reads past the array.
+             */
+            memmove(current_cell, current_cell + reg_size,
+                    (reg_size * (count - i - 1)) * sizeof(__be32));
+            count--;
+            /* Stay on index i: it now holds the next interval */
+            continue;
+        }
+        /* Point to the next element in the array */
+        i++;
+    }
+
+    /* Now count holds the number of intervals in the array */
+    *nr_cells = count * reg_size;
+    return 0;
+}
+
 /*
  * Local variables:
  * mode: C
diff --git a/xen/include/xen/device_tree.h b/xen/include/xen/device_tree.h
index e6287305a7b5..95a88a0d3bc9 100644
--- a/xen/include/xen/device_tree.h
+++ b/xen/include/xen/device_tree.h
@@ -946,6 +946,25 @@  int dt_get_pci_domain_nr(struct dt_device_node *node);
 
 struct dt_device_node *dt_find_node_by_phandle(dt_phandle handle);
 
+/**
+ * dt_merge_overlapping_addr_size_intervals - Given an array of (address, size)
+ *   cells intervals, returns an array with the overlapping intervals merged.
+ * @reg: Array of (address, size) cells.
+ * @nr_cells: Total number of cells in the array.
+ * @addrcells: Size of the "address" in number of cells.
+ * @sizecells: Size of the "size" in number of cells.
+ *
+ * Return:
+ * * 0       - On success.
+ * * -ERANGE - The interval computation results are not representable.
+ *             (address + size results in truncation overflow).
+ * * -EINVAL - addrcells or sizecells are outside the interval [1, 2]
+ *
+ * Returns in nr_cells the new number of cells in the array.
+ */
+int dt_merge_overlapping_addr_size_intervals(__be32 *reg, int *nr_cells,
+                                             int addrcells, int sizecells);
+
 #ifdef CONFIG_DEVICE_TREE_DEBUG
 #define dt_dprintk(fmt, args...)  \
     printk(XENLOG_DEBUG fmt, ## args)
diff --git a/xen/include/xen/sort.h b/xen/include/xen/sort.h
index b95328628465..1bd4420457c0 100644
--- a/xen/include/xen/sort.h
+++ b/xen/include/xen/sort.h
@@ -23,8 +23,8 @@ 
 extern gnu_inline
 #endif
 void sort(void *base, size_t num, size_t size,
-          int (*cmp)(const void *a, const void *b),
-          void (*swap)(void *a, void *b, size_t size))
+          int (*cmp)(const void *a, const void *b, const void *data),
+          void (*swap)(void *a, void *b, size_t size), const void *cmp_data)
 {
     /* pre-scale counters for performance */
     size_t i = (num / 2) * size, n = num * size, c, r;
@@ -35,9 +35,10 @@  void sort(void *base, size_t num, size_t size,
         for ( r = i -= size; r * 2 + size < n; r = c )
         {
             c = r * 2 + size;
-            if ( (c < n - size) && (cmp(base + c, base + c + size) < 0) )
+            if ( (c < n - size) &&
+                 (cmp(base + c, base + c + size, cmp_data) < 0) )
                 c += size;
-            if ( cmp(base + r, base + c) >= 0 )
+            if ( cmp(base + r, base + c, cmp_data) >= 0 )
                 break;
             swap(base + r, base + c, size);
         }
@@ -51,9 +52,10 @@  void sort(void *base, size_t num, size_t size,
         for ( r = 0; r * 2 + size < i; r = c )
         {
             c = r * 2 + size;
-            if ( (c < i - size) && (cmp(base + c, base + c + size) < 0) )
+            if ( (c < i - size) &&
+                 (cmp(base + c, base + c + size, cmp_data) < 0) )
                 c += size;
-            if ( cmp(base + r, base + c) >= 0 )
+            if ( cmp(base + r, base + c, cmp_data) >= 0 )
                 break;
             swap(base + r, base + c, size);
         }