diff mbox series

[v2,22/37] arm64: mte: Add in-kernel MTE helpers

Message ID 4ac1ed624dd1b0851d8cf2861b4f4aac4d2dbc83.1600204505.git.andreyknvl@google.com (mailing list archive)
State New, archived
Headers show
Series kasan: add hardware tag-based mode for arm64 | expand

Commit Message

Andrey Konovalov Sept. 15, 2020, 9:16 p.m. UTC
From: Vincenzo Frascino <vincenzo.frascino@arm.com>

Provide helper functions to manipulate allocation and pointer tags for
kernel addresses.

Low-level helper functions (mte_assign_*, written in assembly) operate on
tag values from the [0x0, 0xF] range. High-level helper functions
(mte_get/set_*) use the [0xF0, 0xFF] range to preserve compatibility
with normal kernel pointers that have 0xFF in their top byte.

MTE_GRANULE_SIZE and related definitions are moved to the new
asm/mte-helpers.h header, which also declares the high-level helpers and
is meant to be safe to include from low-level headers.

Signed-off-by: Vincenzo Frascino <vincenzo.frascino@arm.com>
Co-developed-by: Andrey Konovalov <andreyknvl@google.com>
Signed-off-by: Andrey Konovalov <andreyknvl@google.com>
---
Change-Id: I1b5230254f90dc21a913447cb17f07fea7944ece
---
 arch/arm64/include/asm/esr.h         |  1 +
 arch/arm64/include/asm/mte-helpers.h | 48 ++++++++++++++++++++++++++++
 arch/arm64/include/asm/mte.h         | 17 ++++++----
 arch/arm64/kernel/mte.c              | 48 ++++++++++++++++++++++++++++
 arch/arm64/lib/mte.S                 | 17 ++++++++++
 5 files changed, 125 insertions(+), 6 deletions(-)
 create mode 100644 arch/arm64/include/asm/mte-helpers.h

Comments

Catalin Marinas Sept. 17, 2020, 1:46 p.m. UTC | #1
On Tue, Sep 15, 2020 at 11:16:04PM +0200, Andrey Konovalov wrote:
> diff --git a/arch/arm64/include/asm/mte-helpers.h b/arch/arm64/include/asm/mte-helpers.h
> new file mode 100644
> index 000000000000..5dc2d443851b
> --- /dev/null
> +++ b/arch/arm64/include/asm/mte-helpers.h
> @@ -0,0 +1,48 @@
> +/* SPDX-License-Identifier: GPL-2.0 */
> +/*
> + * Copyright (C) 2020 ARM Ltd.
> + */
> +#ifndef __ASM_MTE_ASM_H
> +#define __ASM_MTE_ASM_H
> +
> +#define __MTE_PREAMBLE		".arch armv8.5-a\n.arch_extension memtag\n"

Because of how the .arch overrides a previous .arch, we should follow
the ARM64_ASM_PREAMBLE introduced in commit 1764c3edc668 ("arm64: use a
common .arch preamble for inline assembly"). The above should be
something like:

#define __MTE_PREAMBLE	ARM64_ASM_PREAMBLE ".arch_extension memtag"

with the ARM64_ASM_PREAMBLE adjusted to armv8.5-a if available.

> +#define MTE_GRANULE_SIZE	UL(16)
> +#define MTE_GRANULE_MASK	(~(MTE_GRANULE_SIZE - 1))
> +#define MTE_TAG_SHIFT		56
> +#define MTE_TAG_SIZE		4
> +#define MTE_TAG_MASK		GENMASK((MTE_TAG_SHIFT + (MTE_TAG_SIZE - 1)), MTE_TAG_SHIFT)
> +#define MTE_TAG_MAX		(MTE_TAG_MASK >> MTE_TAG_SHIFT)

In v1 I suggested we keep those definitions in mte-def.h (or
mte-hwdef.h) so that they can be included in cache.h. Anything else
should go in mte.h, I don't see the point of two headers for various MTE
function prototypes.

> +
> +#ifndef __ASSEMBLY__
> +
> +#include <linux/types.h>
> +
> +#ifdef CONFIG_ARM64_MTE
> +
> +#define mte_get_ptr_tag(ptr)	((u8)(((u64)(ptr)) >> MTE_TAG_SHIFT))

I wonder whether this could also be an inline function that takes a void
*ptr.

> diff --git a/arch/arm64/kernel/mte.c b/arch/arm64/kernel/mte.c
> index 52a0638ed967..e238ffde2679 100644
> --- a/arch/arm64/kernel/mte.c
> +++ b/arch/arm64/kernel/mte.c
> @@ -72,6 +74,52 @@ int memcmp_pages(struct page *page1, struct page *page2)
>  	return ret;
>  }
>  
> +u8 mte_get_mem_tag(void *addr)
> +{
> +	if (system_supports_mte())
> +		asm volatile(ALTERNATIVE("ldr %0, [%0]",
> +					 __MTE_PREAMBLE "ldg %0, [%0]",
> +					 ARM64_MTE)
> +			     : "+r" (addr));

This doesn't do what you think it does. LDG indeed reads the tag from
memory but LDR loads the actual data at that address. Instead of the
first LDR, you may want something like "mov %0, #0xf << 56" (and use
some macros to avoid the hard-coded 56).

> +
> +	return 0xF0 | mte_get_ptr_tag(addr);
> +}
> +
> +u8 mte_get_random_tag(void)
> +{
> +	u8 tag = 0xF;
> +	u64 addr = 0;
> +
> +	if (system_supports_mte()) {
> +		asm volatile(ALTERNATIVE("add %0, %0, %0",
> +					 __MTE_PREAMBLE "irg %0, %0",
> +					 ARM64_MTE)
> +			     : "+r" (addr));

What was the intention here? The first ADD doubles the pointer value and
gets a tag out of it (possibly doubled as well, depends on the carry
from bit 55). Better use something like "orr %0, %0, #0xf << 56".

> +
> +		tag = mte_get_ptr_tag(addr);
> +	}
> +
> +	return 0xF0 | tag;

This function return seems inconsistent with the previous one. I'd
prefer the return line to be the same in both.

> +}
> +
> +void *mte_set_mem_tag_range(void *addr, size_t size, u8 tag)
> +{
> +	void *ptr = addr;
> +
> +	if ((!system_supports_mte()) || (size == 0))
> +		return addr;
> +
> +	/* Make sure that size is aligned. */
> +	WARN_ON(size & (MTE_GRANULE_SIZE - 1));
> +
> +	tag = 0xF0 | (tag & 0xF);

No point in tag & 0xf, the top nibble doesn't matter as you or 0xf0 in.

> +	ptr = (void *)__tag_set(ptr, tag);
> +
> +	mte_assign_mem_tag_range(ptr, size);
> +
> +	return ptr;
> +}
> +
>  static void update_sctlr_el1_tcf0(u64 tcf0)
>  {
>  	/* ISB required for the kernel uaccess routines */
> diff --git a/arch/arm64/lib/mte.S b/arch/arm64/lib/mte.S
> index 03ca6d8b8670..cc2c3a378c00 100644
> --- a/arch/arm64/lib/mte.S
> +++ b/arch/arm64/lib/mte.S
> @@ -149,3 +149,20 @@ SYM_FUNC_START(mte_restore_page_tags)
>  
>  	ret
>  SYM_FUNC_END(mte_restore_page_tags)
> +
> +/*
> + * Assign allocation tags for a region of memory based on the pointer tag
> + *   x0 - source pointer
> + *   x1 - size
> + *
> + * Note: size must be non-zero and MTE_GRANULE_SIZE aligned
> + */
> +SYM_FUNC_START(mte_assign_mem_tag_range)
> +	/* if (src == NULL) return; */
> +	cbz	x0, 2f
> +1:	stg	x0, [x0]
> +	add	x0, x0, #MTE_GRANULE_SIZE
> +	sub	x1, x1, #MTE_GRANULE_SIZE
> +	cbnz	x1, 1b
> +2:	ret
> +SYM_FUNC_END(mte_assign_mem_tag_range)

I thought Vincenzo agreed to my comments on the previous version w.r.t.
the first cbz and the last cbnz:

https://lore.kernel.org/linux-arm-kernel/921c4ed0-b5b5-bc01-5418-c52d80f1af59@arm.com/
Vincenzo Frascino Sept. 17, 2020, 2:21 p.m. UTC | #2
On 9/17/20 2:46 PM, Catalin Marinas wrote:
> On Tue, Sep 15, 2020 at 11:16:04PM +0200, Andrey Konovalov wrote:
>> diff --git a/arch/arm64/include/asm/mte-helpers.h b/arch/arm64/include/asm/mte-helpers.h
>> new file mode 100644
>> index 000000000000..5dc2d443851b
>> --- /dev/null
>> +++ b/arch/arm64/include/asm/mte-helpers.h
>> @@ -0,0 +1,48 @@
>> +/* SPDX-License-Identifier: GPL-2.0 */
>> +/*
>> + * Copyright (C) 2020 ARM Ltd.
>> + */
>> +#ifndef __ASM_MTE_ASM_H
>> +#define __ASM_MTE_ASM_H
>> +
>> +#define __MTE_PREAMBLE		".arch armv8.5-a\n.arch_extension memtag\n"
> 
> Because of how the .arch overrides a previous .arch, we should follow
> the ARM64_ASM_PREAMBLE introduced in commit 1764c3edc668 ("arm64: use a
> common .arch preamble for inline assembly"). The above should be
> something like:
> 
> #define __MTE_PREAMBLE	ARM64_ASM_PREAMBLE ".arch_extension memtag"
> 
> with the ARM64_ASM_PREAMBLE adjusted to armv8.5-a if available.

Good idea, I was not aware of commit 1764c3edc668. I will fix it accordingly.

> 
>> +#define MTE_GRANULE_SIZE	UL(16)
>> +#define MTE_GRANULE_MASK	(~(MTE_GRANULE_SIZE - 1))
>> +#define MTE_TAG_SHIFT		56
>> +#define MTE_TAG_SIZE		4
>> +#define MTE_TAG_MASK		GENMASK((MTE_TAG_SHIFT + (MTE_TAG_SIZE - 1)), MTE_TAG_SHIFT)
>> +#define MTE_TAG_MAX		(MTE_TAG_MASK >> MTE_TAG_SHIFT)
> 
> In v1 I suggested we keep those definitions in mte-def.h (or
> mte-hwdef.h) so that they can be included in cache.h. Anything else
> should go in mte.h, I don't see the point of two headers for various MTE
> function prototypes.
> 

This is what I did in the patches I shared with Andrey. I suppose that, since
this version introduces some functions that are needed in this file, he
reverted to the old name (mte-helpers.h).

>> +
>> +#ifndef __ASSEMBLY__
>> +
>> +#include <linux/types.h>
>> +
>> +#ifdef CONFIG_ARM64_MTE
>> +
>> +#define mte_get_ptr_tag(ptr)	((u8)(((u64)(ptr)) >> MTE_TAG_SHIFT))
> 
> I wonder whether this could also be an inline function that takes a void
> *ptr.
> 
>> diff --git a/arch/arm64/kernel/mte.c b/arch/arm64/kernel/mte.c
>> index 52a0638ed967..e238ffde2679 100644
>> --- a/arch/arm64/kernel/mte.c
>> +++ b/arch/arm64/kernel/mte.c
>> @@ -72,6 +74,52 @@ int memcmp_pages(struct page *page1, struct page *page2)
>>  	return ret;
>>  }
>>  
>> +u8 mte_get_mem_tag(void *addr)
>> +{
>> +	if (system_supports_mte())
>> +		asm volatile(ALTERNATIVE("ldr %0, [%0]",
>> +					 __MTE_PREAMBLE "ldg %0, [%0]",
>> +					 ARM64_MTE)
>> +			     : "+r" (addr));
> 
> This doesn't do what you think it does. LDG indeed reads the tag from
> memory but LDR loads the actual data at that address. Instead of the
> first LDR, you may want something like "mov %0, #0xf << 56" (and use
> some macros to avoid the hard-coded 56).
> 

The result of the load should never be used, since it is meaningful only if
system_supports_mte(). It is only needed for compilation purposes.

That said, I prefer your solution, so I am going to adopt it.

>> +
>> +	return 0xF0 | mte_get_ptr_tag(addr);
>> +}
>> +
>> +u8 mte_get_random_tag(void)
>> +{
>> +	u8 tag = 0xF;
>> +	u64 addr = 0;
>> +
>> +	if (system_supports_mte()) {
>> +		asm volatile(ALTERNATIVE("add %0, %0, %0",
>> +					 __MTE_PREAMBLE "irg %0, %0",
>> +					 ARM64_MTE)
>> +			     : "+r" (addr));
> 
> What was the intention here? The first ADD doubles the pointer value and
> gets a tag out of it (possibly doubled as well, depends on the carry
> from bit 55). Better use something like "orr %0, %0, #0xf << 56".
>

Same as above but I will use the orr in the next version.

>> +
>> +		tag = mte_get_ptr_tag(addr);
>> +	}
>> +
>> +	return 0xF0 | tag;
> 
> This function return seems inconsistent with the previous one. I'd
> prefer the return line to be the same in both.
> 

The reason why it is different is that in this function extracting the tag from
the address makes sense only if irg is executed.

I can initialize addr to 0xf << 56 and make them the same.

>> +}
>> +
>> +void *mte_set_mem_tag_range(void *addr, size_t size, u8 tag)
>> +{
>> +	void *ptr = addr;
>> +
>> +	if ((!system_supports_mte()) || (size == 0))
>> +		return addr;
>> +
>> +	/* Make sure that size is aligned. */
>> +	WARN_ON(size & (MTE_GRANULE_SIZE - 1));
>> +
>> +	tag = 0xF0 | (tag & 0xF);
> 
> No point in tag & 0xf, the top nibble doesn't matter as you or 0xf0 in.
> 

Agree, will remove in the next version.

>> +	ptr = (void *)__tag_set(ptr, tag);
>> +
>> +	mte_assign_mem_tag_range(ptr, size);
>> +
>> +	return ptr;
>> +}
>> +
>>  static void update_sctlr_el1_tcf0(u64 tcf0)
>>  {
>>  	/* ISB required for the kernel uaccess routines */
>> diff --git a/arch/arm64/lib/mte.S b/arch/arm64/lib/mte.S
>> index 03ca6d8b8670..cc2c3a378c00 100644
>> --- a/arch/arm64/lib/mte.S
>> +++ b/arch/arm64/lib/mte.S
>> @@ -149,3 +149,20 @@ SYM_FUNC_START(mte_restore_page_tags)
>>  
>>  	ret
>>  SYM_FUNC_END(mte_restore_page_tags)
>> +
>> +/*
>> + * Assign allocation tags for a region of memory based on the pointer tag
>> + *   x0 - source pointer
>> + *   x1 - size
>> + *
>> + * Note: size must be non-zero and MTE_GRANULE_SIZE aligned
>> + */
>> +SYM_FUNC_START(mte_assign_mem_tag_range)
>> +	/* if (src == NULL) return; */
>> +	cbz	x0, 2f
>> +1:	stg	x0, [x0]
>> +	add	x0, x0, #MTE_GRANULE_SIZE
>> +	sub	x1, x1, #MTE_GRANULE_SIZE
>> +	cbnz	x1, 1b
>> +2:	ret
>> +SYM_FUNC_END(mte_assign_mem_tag_range)
> 
> I thought Vincenzo agreed to my comments on the previous version w.r.t.
> the first cbz and the last cbnz:
> 
> https://lore.kernel.org/linux-arm-kernel/921c4ed0-b5b5-bc01-5418-c52d80f1af59@arm.com/
> 

Oops, this is my fault: I just realized I forgot to unstash this change. It will
be present in the next version.
Vincenzo Frascino Sept. 17, 2020, 4:17 p.m. UTC | #3
On 9/17/20 2:46 PM, Catalin Marinas wrote:
>> diff --git a/arch/arm64/kernel/mte.c b/arch/arm64/kernel/mte.c
>> index 52a0638ed967..e238ffde2679 100644
>> --- a/arch/arm64/kernel/mte.c
>> +++ b/arch/arm64/kernel/mte.c
>> @@ -72,6 +74,52 @@ int memcmp_pages(struct page *page1, struct page *page2)
>>  	return ret;
>>  }
>>  
>> +u8 mte_get_mem_tag(void *addr)
>> +{
>> +	if (system_supports_mte())
>> +		asm volatile(ALTERNATIVE("ldr %0, [%0]",
>> +					 __MTE_PREAMBLE "ldg %0, [%0]",
>> +					 ARM64_MTE)
>> +			     : "+r" (addr));
> This doesn't do what you think it does. LDG indeed reads the tag from
> memory but LDR loads the actual data at that address. Instead of the
> first LDR, you may want something like "mov %0, #0xf << 56" (and use
> some macros to avoid the hard-coded 56).
>

It seems I can encode a shift of 56 in neither mov nor orr. I propose to
replace both with an AND of the address with itself, which should not
change anything.

Thoughts?

>> +
>> +	return 0xF0 | mte_get_ptr_tag(addr);
>> +}
>> +
>> +u8 mte_get_random_tag(void)
>> +{
>> +	u8 tag = 0xF;
>> +	u64 addr = 0;
>> +
>> +	if (system_supports_mte()) {
>> +		asm volatile(ALTERNATIVE("add %0, %0, %0",
>> +					 __MTE_PREAMBLE "irg %0, %0",
>> +					 ARM64_MTE)
>> +			     : "+r" (addr));
> What was the intention here? The first ADD doubles the pointer value and
> gets a tag out of it (possibly doubled as well, depends on the carry
> from bit 55). Better use something like "orr %0, %0, #0xf << 56".
>
Catalin Marinas Sept. 17, 2020, 5:07 p.m. UTC | #4
On Thu, Sep 17, 2020 at 05:17:00PM +0100, Vincenzo Frascino wrote:
> On 9/17/20 2:46 PM, Catalin Marinas wrote:
> >> diff --git a/arch/arm64/kernel/mte.c b/arch/arm64/kernel/mte.c
> >> index 52a0638ed967..e238ffde2679 100644
> >> --- a/arch/arm64/kernel/mte.c
> >> +++ b/arch/arm64/kernel/mte.c
> >> @@ -72,6 +74,52 @@ int memcmp_pages(struct page *page1, struct page *page2)
> >>  	return ret;
> >>  }
> >>  
> >> +u8 mte_get_mem_tag(void *addr)
> >> +{
> >> +	if (system_supports_mte())
> >> +		asm volatile(ALTERNATIVE("ldr %0, [%0]",
> >> +					 __MTE_PREAMBLE "ldg %0, [%0]",
> >> +					 ARM64_MTE)
> >> +			     : "+r" (addr));
> > This doesn't do what you think it does. LDG indeed reads the tag from
> > memory but LDR loads the actual data at that address. Instead of the
> > first LDR, you may want something like "mov %0, #0xf << 56" (and use
> > some macros to avoid the hard-coded 56).
> >
> 
> It seems I can encode a shift of 56 in neither mov nor orr. I propose to
> replace both with an AND of the address with itself, which should not
> change anything.

Then use a NOP.
Catalin Marinas Sept. 18, 2020, 9:36 a.m. UTC | #5
On Thu, Sep 17, 2020 at 03:21:41PM +0100, Vincenzo Frascino wrote:
> On 9/17/20 2:46 PM, Catalin Marinas wrote:
> > On Tue, Sep 15, 2020 at 11:16:04PM +0200, Andrey Konovalov wrote:
> >> diff --git a/arch/arm64/kernel/mte.c b/arch/arm64/kernel/mte.c
> >> index 52a0638ed967..e238ffde2679 100644
> >> --- a/arch/arm64/kernel/mte.c
> >> +++ b/arch/arm64/kernel/mte.c
> >> @@ -72,6 +74,52 @@ int memcmp_pages(struct page *page1, struct page *page2)
> >>  	return ret;
> >>  }
> >>  
> >> +u8 mte_get_mem_tag(void *addr)
> >> +{
> >> +	if (system_supports_mte())
> >> +		asm volatile(ALTERNATIVE("ldr %0, [%0]",
> >> +					 __MTE_PREAMBLE "ldg %0, [%0]",
> >> +					 ARM64_MTE)
> >> +			     : "+r" (addr));
> > 
> > This doesn't do what you think it does. LDG indeed reads the tag from
> > memory but LDR loads the actual data at that address. Instead of the
> > first LDR, you may want something like "mov %0, #0xf << 56" (and use
> > some macros to avoid the hard-coded 56).
> 
> The result of the load should never be used, since it is meaningful only if
> system_supports_mte(). It is only needed for compilation purposes.
> 
> That said, I prefer your solution, so I am going to adopt it.

Forgot to mention, please remove the system_supports_mte() if you use
ALTERNATIVE, we don't need both. I think the first asm instruction can
be a NOP since the kernel addresses without KASAN_HW or ARM64_MTE have
the top byte 0xff.

> >> +
> >> +	return 0xF0 | mte_get_ptr_tag(addr);
> >> +}
> >> +
> >> +u8 mte_get_random_tag(void)
> >> +{
> >> +	u8 tag = 0xF;
> >> +	u64 addr = 0;
> >> +
> >> +	if (system_supports_mte()) {
> >> +		asm volatile(ALTERNATIVE("add %0, %0, %0",
> >> +					 __MTE_PREAMBLE "irg %0, %0",
> >> +					 ARM64_MTE)
> >> +			     : "+r" (addr));
> > 
> > What was the intention here? The first ADD doubles the pointer value and
> > gets a tag out of it (possibly doubled as well, depends on the carry
> > from bit 55). Better use something like "orr %0, %0, #0xf << 56".
> 
> Same as above but I will use the orr in the next version.

I wonder whether system_supports_mte() makes more sense here than the
alternative:

	if (!system_supports_mte())
		return 0xff;

	... mte irg stuff ...

(you could do the same for the mte_get_mem_tag() function)

> >> +
> >> +		tag = mte_get_ptr_tag(addr);
> >> +	}
> >> +
> >> +	return 0xF0 | tag;
> > 
> > This function return seems inconsistent with the previous one. I'd
> > prefer the return line to be the same in both.
> 
> The reason why it is different is that in this function extracting the tag from
> the address makes sense only if irg is executed.
> 
> I can initialize addr to 0xf << 56 and make them the same.

I think you are right, they can be different. But see my comment above
about not doing the unnecessary shifting when all you want is to return
0xff with !MTE.
Vincenzo Frascino Sept. 22, 2020, 10:16 a.m. UTC | #6
On 9/18/20 10:36 AM, Catalin Marinas wrote:
>> Same as above but I will use the orr in the next version.
> I wonder whether system_supports_mte() makes more sense here than the
> alternative:
> 
> 	if (!system_supports_mte())
> 		return 0xff;
> 
> 	... mte irg stuff ...
> 
> (you could do the same for the mte_get_mem_tag() function)
> 

This would have been my preference from the beginning but then you mentioned
alternatives ;)

Anyway, I'm more than happy to change the code this way; it seems cleaner and
easier to understand.
diff mbox series

Patch

diff --git a/arch/arm64/include/asm/esr.h b/arch/arm64/include/asm/esr.h
index 035003acfa87..bc0dc66a6a27 100644
--- a/arch/arm64/include/asm/esr.h
+++ b/arch/arm64/include/asm/esr.h
@@ -103,6 +103,7 @@ 
 #define ESR_ELx_FSC		(0x3F)
 #define ESR_ELx_FSC_TYPE	(0x3C)
 #define ESR_ELx_FSC_EXTABT	(0x10)
+#define ESR_ELx_FSC_MTE		(0x11)
 #define ESR_ELx_FSC_SERROR	(0x11)
 #define ESR_ELx_FSC_ACCESS	(0x08)
 #define ESR_ELx_FSC_FAULT	(0x04)
diff --git a/arch/arm64/include/asm/mte-helpers.h b/arch/arm64/include/asm/mte-helpers.h
new file mode 100644
index 000000000000..5dc2d443851b
--- /dev/null
+++ b/arch/arm64/include/asm/mte-helpers.h
@@ -0,0 +1,48 @@ 
+/* SPDX-License-Identifier: GPL-2.0 */
+/*
+ * Copyright (C) 2020 ARM Ltd.
+ */
+#ifndef __ASM_MTE_ASM_H
+#define __ASM_MTE_ASM_H
+
+#define __MTE_PREAMBLE		".arch armv8.5-a\n.arch_extension memtag\n"
+
+#define MTE_GRANULE_SIZE	UL(16)
+#define MTE_GRANULE_MASK	(~(MTE_GRANULE_SIZE - 1))
+#define MTE_TAG_SHIFT		56
+#define MTE_TAG_SIZE		4
+#define MTE_TAG_MASK		GENMASK((MTE_TAG_SHIFT + (MTE_TAG_SIZE - 1)), MTE_TAG_SHIFT)
+#define MTE_TAG_MAX		(MTE_TAG_MASK >> MTE_TAG_SHIFT)
+
+#ifndef __ASSEMBLY__
+
+#include <linux/types.h>
+
+#ifdef CONFIG_ARM64_MTE
+
+#define mte_get_ptr_tag(ptr)	((u8)(((u64)(ptr)) >> MTE_TAG_SHIFT))
+u8 mte_get_mem_tag(void *addr);
+u8 mte_get_random_tag(void);
+void *mte_set_mem_tag_range(void *addr, size_t size, u8 tag);
+
+#else /* CONFIG_ARM64_MTE */
+
+#define mte_get_ptr_tag(ptr)	0xFF
+static inline u8 mte_get_mem_tag(void *addr)
+{
+	return 0xFF;
+}
+static inline u8 mte_get_random_tag(void)
+{
+	return 0xFF;
+}
+static inline void *mte_set_mem_tag_range(void *addr, size_t size, u8 tag)
+{
+	return addr;
+}
+
+#endif /* CONFIG_ARM64_MTE */
+
+#endif /* __ASSEMBLY__ */
+
+#endif /* __ASM_MTE_ASM_H  */
diff --git a/arch/arm64/include/asm/mte.h b/arch/arm64/include/asm/mte.h
index 1c99fcadb58c..82cd7c89edec 100644
--- a/arch/arm64/include/asm/mte.h
+++ b/arch/arm64/include/asm/mte.h
@@ -5,14 +5,13 @@ 
 #ifndef __ASM_MTE_H
 #define __ASM_MTE_H
 
-#define MTE_GRANULE_SIZE	UL(16)
-#define MTE_GRANULE_MASK	(~(MTE_GRANULE_SIZE - 1))
-#define MTE_TAG_SHIFT		56
-#define MTE_TAG_SIZE		4
+#include <asm/mte-helpers.h>
 
 #ifndef __ASSEMBLY__
 
+#include <linux/bitfield.h>
 #include <linux/page-flags.h>
+#include <linux/types.h>
 
 #include <asm/pgtable-types.h>
 
@@ -45,7 +44,9 @@  long get_mte_ctrl(struct task_struct *task);
 int mte_ptrace_copy_tags(struct task_struct *child, long request,
 			 unsigned long addr, unsigned long data);
 
-#else
+void mte_assign_mem_tag_range(void *addr, size_t size);
+
+#else /* CONFIG_ARM64_MTE */
 
 /* unused if !CONFIG_ARM64_MTE, silence the compiler */
 #define PG_mte_tagged	0
@@ -80,7 +81,11 @@  static inline int mte_ptrace_copy_tags(struct task_struct *child,
 	return -EIO;
 }
 
-#endif
+static inline void mte_assign_mem_tag_range(void *addr, size_t size)
+{
+}
+
+#endif /* CONFIG_ARM64_MTE */
 
 #endif /* __ASSEMBLY__ */
 #endif /* __ASM_MTE_H  */
diff --git a/arch/arm64/kernel/mte.c b/arch/arm64/kernel/mte.c
index 52a0638ed967..e238ffde2679 100644
--- a/arch/arm64/kernel/mte.c
+++ b/arch/arm64/kernel/mte.c
@@ -13,8 +13,10 @@ 
 #include <linux/swap.h>
 #include <linux/swapops.h>
 #include <linux/thread_info.h>
+#include <linux/types.h>
 #include <linux/uio.h>
 
+#include <asm/barrier.h>
 #include <asm/cpufeature.h>
 #include <asm/mte.h>
 #include <asm/ptrace.h>
@@ -72,6 +74,52 @@  int memcmp_pages(struct page *page1, struct page *page2)
 	return ret;
 }
 
+u8 mte_get_mem_tag(void *addr)
+{
+	if (system_supports_mte())
+		asm volatile(ALTERNATIVE("ldr %0, [%0]",
+					 __MTE_PREAMBLE "ldg %0, [%0]",
+					 ARM64_MTE)
+			     : "+r" (addr));
+
+	return 0xF0 | mte_get_ptr_tag(addr);
+}
+
+u8 mte_get_random_tag(void)
+{
+	u8 tag = 0xF;
+	u64 addr = 0;
+
+	if (system_supports_mte()) {
+		asm volatile(ALTERNATIVE("add %0, %0, %0",
+					 __MTE_PREAMBLE "irg %0, %0",
+					 ARM64_MTE)
+			     : "+r" (addr));
+
+		tag = mte_get_ptr_tag(addr);
+	}
+
+	return 0xF0 | tag;
+}
+
+void *mte_set_mem_tag_range(void *addr, size_t size, u8 tag)
+{
+	void *ptr = addr;
+
+	if ((!system_supports_mte()) || (size == 0))
+		return addr;
+
+	/* Make sure that size is aligned. */
+	WARN_ON(size & (MTE_GRANULE_SIZE - 1));
+
+	tag = 0xF0 | (tag & 0xF);
+	ptr = (void *)__tag_set(ptr, tag);
+
+	mte_assign_mem_tag_range(ptr, size);
+
+	return ptr;
+}
+
 static void update_sctlr_el1_tcf0(u64 tcf0)
 {
 	/* ISB required for the kernel uaccess routines */
diff --git a/arch/arm64/lib/mte.S b/arch/arm64/lib/mte.S
index 03ca6d8b8670..cc2c3a378c00 100644
--- a/arch/arm64/lib/mte.S
+++ b/arch/arm64/lib/mte.S
@@ -149,3 +149,20 @@  SYM_FUNC_START(mte_restore_page_tags)
 
 	ret
 SYM_FUNC_END(mte_restore_page_tags)
+
+/*
+ * Assign allocation tags for a region of memory based on the pointer tag
+ *   x0 - source pointer
+ *   x1 - size
+ *
+ * Note: size must be non-zero and MTE_GRANULE_SIZE aligned
+ */
+SYM_FUNC_START(mte_assign_mem_tag_range)
+	/* if (src == NULL) return; */
+	cbz	x0, 2f
+1:	stg	x0, [x0]
+	add	x0, x0, #MTE_GRANULE_SIZE
+	sub	x1, x1, #MTE_GRANULE_SIZE
+	cbnz	x1, 1b
+2:	ret
+SYM_FUNC_END(mte_assign_mem_tag_range)