diff mbox series

[2/4] target/arm: Fix UMOPA/UMOPS of 16-bit values

Message ID 20240722172957.1041231-3-peter.maydell@linaro.org (mailing list archive)
State New, archived
Headers show
Series target/arm: Various minor SME bugfixes | expand

Commit Message

Peter Maydell July 22, 2024, 5:29 p.m. UTC
The UMOPA/UMOPS instructions are supposed to multiply unsigned 8 or
16 bit elements and accumulate the products into a 64-bit element.
In the Arm ARM pseudocode, this is done with the usual
infinite-precision signed arithmetic.  However our implementation
doesn't quite get it right, because in the DEF_IMOP_64() macro we do:
  sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0);

where NTYPE and MTYPE are uint16_t or int16_t.  In the uint16_t case,
the C usual arithmetic conversions mean the values are converted to
"int" type and the multiply is done as a 32-bit multiply.  This means
that if the inputs are, for example, 0xffff and 0xffff then the
result is 0xFFFE0001 as an int, which is then promoted to uint64_t
for the accumulation into sum; this promotion incorrectly sign
extends the multiply.

Avoid the incorrect sign extension by casting to int64_t before
the multiply, so we do the multiply as 64-bit signed arithmetic,
which is a type large enough that the multiply can never
overflow into the sign bit.

(The equivalent 8-bit operations in DEF_IMOP_32() are fine, because
the 8-bit multiplies can never overflow into the sign bit of a
32-bit integer.)

Cc: qemu-stable@nongnu.org
Resolves: https://gitlab.com/qemu-project/qemu/-/issues/2372
Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
---
 target/arm/tcg/sme_helper.c | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)
diff mbox series

Patch

diff --git a/target/arm/tcg/sme_helper.c b/target/arm/tcg/sme_helper.c
index 5a6dd76489f..f9001f5213a 100644
--- a/target/arm/tcg/sme_helper.c
+++ b/target/arm/tcg/sme_helper.c
@@ -1146,10 +1146,10 @@  static uint64_t NAME(uint64_t n, uint64_t m, uint64_t a, uint8_t p, bool neg) \
     uint64_t sum = 0;                                                       \
     /* Apply P to N as a mask, making the inactive elements 0. */           \
     n &= expand_pred_h(p);                                                  \
-    sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0);                               \
-    sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16);                             \
-    sum += (NTYPE)(n >> 32) * (MTYPE)(m >> 32);                             \
-    sum += (NTYPE)(n >> 48) * (MTYPE)(m >> 48);                             \
+    sum += (int64_t)(NTYPE)(n >> 0) * (MTYPE)(m >> 0);                      \
+    sum += (int64_t)(NTYPE)(n >> 16) * (MTYPE)(m >> 16);                    \
+    sum += (int64_t)(NTYPE)(n >> 32) * (MTYPE)(m >> 32);                    \
+    sum += (int64_t)(NTYPE)(n >> 48) * (MTYPE)(m >> 48);                    \
     return neg ? a - sum : a + sum;                                         \
 }