diff mbox series

[-next,v20,10/26] riscv: Add task switch support for vector

Message ID 20230518161949.11203-11-andy.chiu@sifive.com (mailing list archive)
State New, archived
Headers show
Series riscv: Add vector ISA support | expand

Commit Message

Andy Chiu May 18, 2023, 4:19 p.m. UTC
From: Greentime Hu <greentime.hu@sifive.com>

This patch adds task switch support for vector. It also supports all
lengths of vlen.

Suggested-by: Andrew Waterman <andrew@sifive.com>
Co-developed-by: Nick Knight <nick.knight@sifive.com>
Signed-off-by: Nick Knight <nick.knight@sifive.com>
Co-developed-by: Guo Ren <guoren@linux.alibaba.com>
Signed-off-by: Guo Ren <guoren@linux.alibaba.com>
Co-developed-by: Vincent Chen <vincent.chen@sifive.com>
Signed-off-by: Vincent Chen <vincent.chen@sifive.com>
Co-developed-by: Ruinland Tsai <ruinland.tsai@sifive.com>
Signed-off-by: Ruinland Tsai <ruinland.tsai@sifive.com>
Signed-off-by: Greentime Hu <greentime.hu@sifive.com>
Signed-off-by: Vineet Gupta <vineetg@rivosinc.com>
Signed-off-by: Andy Chiu <andy.chiu@sifive.com>
Reviewed-by: Conor Dooley <conor.dooley@microchip.com>
Reviewed-by: Björn Töpel <bjorn@rivosinc.com>
Reviewed-by: Heiko Stuebner <heiko.stuebner@vrull.eu>
Tested-by: Heiko Stuebner <heiko.stuebner@vrull.eu>
---
 arch/riscv/include/asm/processor.h   |  1 +
 arch/riscv/include/asm/switch_to.h   |  3 +++
 arch/riscv/include/asm/thread_info.h |  3 +++
 arch/riscv/include/asm/vector.h      | 38 ++++++++++++++++++++++++++++
 arch/riscv/kernel/process.c          | 18 +++++++++++++
 5 files changed, 63 insertions(+)

Comments

Palmer Dabbelt May 24, 2023, 12:49 a.m. UTC | #1
On Thu, 18 May 2023 09:19:33 PDT (-0700), andy.chiu@sifive.com wrote:
> From: Greentime Hu <greentime.hu@sifive.com>
>
> This patch adds task switch support for vector. It also supports all
> lengths of vlen.
>
> Suggested-by: Andrew Waterman <andrew@sifive.com>
> Co-developed-by: Nick Knight <nick.knight@sifive.com>
> Signed-off-by: Nick Knight <nick.knight@sifive.com>
> Co-developed-by: Guo Ren <guoren@linux.alibaba.com>
> Signed-off-by: Guo Ren <guoren@linux.alibaba.com>
> Co-developed-by: Vincent Chen <vincent.chen@sifive.com>
> Signed-off-by: Vincent Chen <vincent.chen@sifive.com>
> Co-developed-by: Ruinland Tsai <ruinland.tsai@sifive.com>
> Signed-off-by: Ruinland Tsai <ruinland.tsai@sifive.com>
> Signed-off-by: Greentime Hu <greentime.hu@sifive.com>
> Signed-off-by: Vineet Gupta <vineetg@rivosinc.com>
> Signed-off-by: Andy Chiu <andy.chiu@sifive.com>
> Reviewed-by: Conor Dooley <conor.dooley@microchip.com>
> Reviewed-by: Björn Töpel <bjorn@rivosinc.com>
> Reviewed-by: Heiko Stuebner <heiko.stuebner@vrull.eu>
> Tested-by: Heiko Stuebner <heiko.stuebner@vrull.eu>
> ---
>  arch/riscv/include/asm/processor.h   |  1 +
>  arch/riscv/include/asm/switch_to.h   |  3 +++
>  arch/riscv/include/asm/thread_info.h |  3 +++
>  arch/riscv/include/asm/vector.h      | 38 ++++++++++++++++++++++++++++
>  arch/riscv/kernel/process.c          | 18 +++++++++++++
>  5 files changed, 63 insertions(+)
>
> diff --git a/arch/riscv/include/asm/processor.h b/arch/riscv/include/asm/processor.h
> index 94a0590c6971..f0ddf691ac5e 100644
> --- a/arch/riscv/include/asm/processor.h
> +++ b/arch/riscv/include/asm/processor.h
> @@ -39,6 +39,7 @@ struct thread_struct {
>  	unsigned long s[12];	/* s[0]: frame pointer */
>  	struct __riscv_d_ext_state fstate;
>  	unsigned long bad_cause;
> +	struct __riscv_v_ext_state vstate;
>  };
>
>  /* Whitelist the fstate from the task_struct for hardened usercopy */
> diff --git a/arch/riscv/include/asm/switch_to.h b/arch/riscv/include/asm/switch_to.h
> index 4b96b13dee27..a727be723c56 100644
> --- a/arch/riscv/include/asm/switch_to.h
> +++ b/arch/riscv/include/asm/switch_to.h
> @@ -8,6 +8,7 @@
>
>  #include <linux/jump_label.h>
>  #include <linux/sched/task_stack.h>
> +#include <asm/vector.h>
>  #include <asm/hwcap.h>
>  #include <asm/processor.h>
>  #include <asm/ptrace.h>
> @@ -78,6 +79,8 @@ do {							\
>  	struct task_struct *__next = (next);		\
>  	if (has_fpu())					\
>  		__switch_to_fpu(__prev, __next);	\
> +	if (has_vector())					\
> +		__switch_to_vector(__prev, __next);	\
>  	((last) = __switch_to(__prev, __next));		\
>  } while (0)
>
> diff --git a/arch/riscv/include/asm/thread_info.h b/arch/riscv/include/asm/thread_info.h
> index e0d202134b44..97e6f65ec176 100644
> --- a/arch/riscv/include/asm/thread_info.h
> +++ b/arch/riscv/include/asm/thread_info.h
> @@ -81,6 +81,9 @@ struct thread_info {
>  	.preempt_count	= INIT_PREEMPT_COUNT,	\
>  }
>
> +void arch_release_task_struct(struct task_struct *tsk);
> +int arch_dup_task_struct(struct task_struct *dst, struct task_struct *src);
> +
>  #endif /* !__ASSEMBLY__ */
>
>  /*
> diff --git a/arch/riscv/include/asm/vector.h b/arch/riscv/include/asm/vector.h
> index 3c29f4eb552a..ce6a75e9cf62 100644
> --- a/arch/riscv/include/asm/vector.h
> +++ b/arch/riscv/include/asm/vector.h
> @@ -12,6 +12,9 @@
>  #ifdef CONFIG_RISCV_ISA_V
>
>  #include <linux/stringify.h>
> +#include <linux/sched.h>
> +#include <linux/sched/task_stack.h>
> +#include <asm/ptrace.h>
>  #include <asm/hwcap.h>
>  #include <asm/csr.h>
>  #include <asm/asm.h>
> @@ -124,6 +127,38 @@ static inline void __riscv_v_vstate_restore(struct __riscv_v_ext_state *restore_
>  	riscv_v_disable();
>  }
>
> +static inline void riscv_v_vstate_save(struct task_struct *task,
> +				       struct pt_regs *regs)
> +{
> +	if ((regs->status & SR_VS) == SR_VS_DIRTY) {
> +		struct __riscv_v_ext_state *vstate = &task->thread.vstate;
> +
> +		__riscv_v_vstate_save(vstate, vstate->datap);
> +		__riscv_v_vstate_clean(regs);
> +	}
> +}
> +
> +static inline void riscv_v_vstate_restore(struct task_struct *task,
> +					  struct pt_regs *regs)
> +{
> +	if ((regs->status & SR_VS) != SR_VS_OFF) {
> +		struct __riscv_v_ext_state *vstate = &task->thread.vstate;
> +
> +		__riscv_v_vstate_restore(vstate, vstate->datap);
> +		__riscv_v_vstate_clean(regs);
> +	}
> +}
> +
> +static inline void __switch_to_vector(struct task_struct *prev,
> +				      struct task_struct *next)
> +{
> +	struct pt_regs *regs;
> +
> +	regs = task_pt_regs(prev);
> +	riscv_v_vstate_save(prev, regs);
> +	riscv_v_vstate_restore(next, task_pt_regs(next));
> +}
> +
>  #else /* ! CONFIG_RISCV_ISA_V  */
>
>  struct pt_regs;
> @@ -132,6 +167,9 @@ static inline int riscv_v_setup_vsize(void) { return -EOPNOTSUPP; }
>  static __always_inline bool has_vector(void) { return false; }
>  static inline bool riscv_v_vstate_query(struct pt_regs *regs) { return false; }
>  #define riscv_v_vsize (0)
> +#define riscv_v_vstate_save(task, regs)		do {} while (0)
> +#define riscv_v_vstate_restore(task, regs)	do {} while (0)
> +#define __switch_to_vector(__prev, __next)	do {} while (0)
>  #define riscv_v_vstate_off(regs)		do {} while (0)
>  #define riscv_v_vstate_on(regs)			do {} while (0)
>
> diff --git a/arch/riscv/kernel/process.c b/arch/riscv/kernel/process.c
> index e2a060066730..b7a10361ddc6 100644
> --- a/arch/riscv/kernel/process.c
> +++ b/arch/riscv/kernel/process.c
> @@ -24,6 +24,7 @@
>  #include <asm/switch_to.h>
>  #include <asm/thread_info.h>
>  #include <asm/cpuidle.h>
> +#include <asm/vector.h>
>
>  register unsigned long gp_in_global __asm__("gp");
>
> @@ -146,12 +147,28 @@ void flush_thread(void)
>  	fstate_off(current, task_pt_regs(current));
>  	memset(&current->thread.fstate, 0, sizeof(current->thread.fstate));
>  #endif
> +#ifdef CONFIG_RISCV_ISA_V
> +	/* Reset vector state */
> +	riscv_v_vstate_off(task_pt_regs(current));
> +	kfree(current->thread.vstate.datap);
> +	memset(&current->thread.vstate, 0, sizeof(struct __riscv_v_ext_state));
> +#endif
> +}
> +
> +void arch_release_task_struct(struct task_struct *tsk)
> +{
> +	/* Free the vector context of datap. */
> +	if (has_vector())
> +		kfree(tsk->thread.vstate.datap);
>  }
>
>  int arch_dup_task_struct(struct task_struct *dst, struct task_struct *src)
>  {
>  	fstate_save(src, task_pt_regs(src));
>  	*dst = *src;
> +	/* clear entire V context, including datap for a new task */
> +	memset(&dst->thread.vstate, 0, sizeof(struct __riscv_v_ext_state));
> +
>  	return 0;
>  }
>
> @@ -184,6 +201,7 @@ int copy_thread(struct task_struct *p, const struct kernel_clone_args *args)
>  		p->thread.s[0] = 0;
>  	}
>  	p->thread.ra = (unsigned long)ret_from_fork;
> +	riscv_v_vstate_off(childregs);

When is V still on at this point?  If we got here via clone() (or any 
other syscall) it should be off already, so if we need to turn it off 
here then we must have arrived via something that's not a syscall.  I 
don't know what that case is, so it's not clear we can just throw away 
the V state.

>  	p->thread.sp = (unsigned long)childregs; /* kernel sp */
>  	return 0;
>  }
Andy Chiu May 30, 2023, 10:11 a.m. UTC | #2
On Wed, May 24, 2023 at 8:49 AM Palmer Dabbelt <palmer@dabbelt.com> wrote:
>
> On Thu, 18 May 2023 09:19:33 PDT (-0700), andy.chiu@sifive.com wrote:
> > From: Greentime Hu <greentime.hu@sifive.com>
> >
> > This patch adds task switch support for vector. It also supports all
> > lengths of vlen.
> >
> > Suggested-by: Andrew Waterman <andrew@sifive.com>
> > Co-developed-by: Nick Knight <nick.knight@sifive.com>
> > Signed-off-by: Nick Knight <nick.knight@sifive.com>
> > Co-developed-by: Guo Ren <guoren@linux.alibaba.com>
> > Signed-off-by: Guo Ren <guoren@linux.alibaba.com>
> > Co-developed-by: Vincent Chen <vincent.chen@sifive.com>
> > Signed-off-by: Vincent Chen <vincent.chen@sifive.com>
> > Co-developed-by: Ruinland Tsai <ruinland.tsai@sifive.com>
> > Signed-off-by: Ruinland Tsai <ruinland.tsai@sifive.com>
> > Signed-off-by: Greentime Hu <greentime.hu@sifive.com>
> > Signed-off-by: Vineet Gupta <vineetg@rivosinc.com>
> > Signed-off-by: Andy Chiu <andy.chiu@sifive.com>
> > Reviewed-by: Conor Dooley <conor.dooley@microchip.com>
> > Reviewed-by: Björn Töpel <bjorn@rivosinc.com>
> > Reviewed-by: Heiko Stuebner <heiko.stuebner@vrull.eu>
> > Tested-by: Heiko Stuebner <heiko.stuebner@vrull.eu>
> > ---
> >  arch/riscv/include/asm/processor.h   |  1 +
> >  arch/riscv/include/asm/switch_to.h   |  3 +++
> >  arch/riscv/include/asm/thread_info.h |  3 +++
> >  arch/riscv/include/asm/vector.h      | 38 ++++++++++++++++++++++++++++
> >  arch/riscv/kernel/process.c          | 18 +++++++++++++
> >  5 files changed, 63 insertions(+)
> >
> > diff --git a/arch/riscv/include/asm/processor.h b/arch/riscv/include/asm/processor.h
> > index 94a0590c6971..f0ddf691ac5e 100644
> > --- a/arch/riscv/include/asm/processor.h
> > +++ b/arch/riscv/include/asm/processor.h
> > @@ -39,6 +39,7 @@ struct thread_struct {
> >       unsigned long s[12];    /* s[0]: frame pointer */
> >       struct __riscv_d_ext_state fstate;
> >       unsigned long bad_cause;
> > +     struct __riscv_v_ext_state vstate;
> >  };
> >
> >  /* Whitelist the fstate from the task_struct for hardened usercopy */
> > diff --git a/arch/riscv/include/asm/switch_to.h b/arch/riscv/include/asm/switch_to.h
> > index 4b96b13dee27..a727be723c56 100644
> > --- a/arch/riscv/include/asm/switch_to.h
> > +++ b/arch/riscv/include/asm/switch_to.h
> > @@ -8,6 +8,7 @@
> >
> >  #include <linux/jump_label.h>
> >  #include <linux/sched/task_stack.h>
> > +#include <asm/vector.h>
> >  #include <asm/hwcap.h>
> >  #include <asm/processor.h>
> >  #include <asm/ptrace.h>
> > @@ -78,6 +79,8 @@ do {                                                        \
> >       struct task_struct *__next = (next);            \
> >       if (has_fpu())                                  \
> >               __switch_to_fpu(__prev, __next);        \
> > +     if (has_vector())                                       \
> > +             __switch_to_vector(__prev, __next);     \
> >       ((last) = __switch_to(__prev, __next));         \
> >  } while (0)
> >
> > diff --git a/arch/riscv/include/asm/thread_info.h b/arch/riscv/include/asm/thread_info.h
> > index e0d202134b44..97e6f65ec176 100644
> > --- a/arch/riscv/include/asm/thread_info.h
> > +++ b/arch/riscv/include/asm/thread_info.h
> > @@ -81,6 +81,9 @@ struct thread_info {
> >       .preempt_count  = INIT_PREEMPT_COUNT,   \
> >  }
> >
> > +void arch_release_task_struct(struct task_struct *tsk);
> > +int arch_dup_task_struct(struct task_struct *dst, struct task_struct *src);
> > +
> >  #endif /* !__ASSEMBLY__ */
> >
> >  /*
> > diff --git a/arch/riscv/include/asm/vector.h b/arch/riscv/include/asm/vector.h
> > index 3c29f4eb552a..ce6a75e9cf62 100644
> > --- a/arch/riscv/include/asm/vector.h
> > +++ b/arch/riscv/include/asm/vector.h
> > @@ -12,6 +12,9 @@
> >  #ifdef CONFIG_RISCV_ISA_V
> >
> >  #include <linux/stringify.h>
> > +#include <linux/sched.h>
> > +#include <linux/sched/task_stack.h>
> > +#include <asm/ptrace.h>
> >  #include <asm/hwcap.h>
> >  #include <asm/csr.h>
> >  #include <asm/asm.h>
> > @@ -124,6 +127,38 @@ static inline void __riscv_v_vstate_restore(struct __riscv_v_ext_state *restore_
> >       riscv_v_disable();
> >  }
> >
> > +static inline void riscv_v_vstate_save(struct task_struct *task,
> > +                                    struct pt_regs *regs)
> > +{
> > +     if ((regs->status & SR_VS) == SR_VS_DIRTY) {
> > +             struct __riscv_v_ext_state *vstate = &task->thread.vstate;
> > +
> > +             __riscv_v_vstate_save(vstate, vstate->datap);
> > +             __riscv_v_vstate_clean(regs);
> > +     }
> > +}
> > +
> > +static inline void riscv_v_vstate_restore(struct task_struct *task,
> > +                                       struct pt_regs *regs)
> > +{
> > +     if ((regs->status & SR_VS) != SR_VS_OFF) {
> > +             struct __riscv_v_ext_state *vstate = &task->thread.vstate;
> > +
> > +             __riscv_v_vstate_restore(vstate, vstate->datap);
> > +             __riscv_v_vstate_clean(regs);
> > +     }
> > +}
> > +
> > +static inline void __switch_to_vector(struct task_struct *prev,
> > +                                   struct task_struct *next)
> > +{
> > +     struct pt_regs *regs;
> > +
> > +     regs = task_pt_regs(prev);
> > +     riscv_v_vstate_save(prev, regs);
> > +     riscv_v_vstate_restore(next, task_pt_regs(next));
> > +}
> > +
> >  #else /* ! CONFIG_RISCV_ISA_V  */
> >
> >  struct pt_regs;
> > @@ -132,6 +167,9 @@ static inline int riscv_v_setup_vsize(void) { return -EOPNOTSUPP; }
> >  static __always_inline bool has_vector(void) { return false; }
> >  static inline bool riscv_v_vstate_query(struct pt_regs *regs) { return false; }
> >  #define riscv_v_vsize (0)
> > +#define riscv_v_vstate_save(task, regs)              do {} while (0)
> > +#define riscv_v_vstate_restore(task, regs)   do {} while (0)
> > +#define __switch_to_vector(__prev, __next)   do {} while (0)
> >  #define riscv_v_vstate_off(regs)             do {} while (0)
> >  #define riscv_v_vstate_on(regs)                      do {} while (0)
> >
> > diff --git a/arch/riscv/kernel/process.c b/arch/riscv/kernel/process.c
> > index e2a060066730..b7a10361ddc6 100644
> > --- a/arch/riscv/kernel/process.c
> > +++ b/arch/riscv/kernel/process.c
> > @@ -24,6 +24,7 @@
> >  #include <asm/switch_to.h>
> >  #include <asm/thread_info.h>
> >  #include <asm/cpuidle.h>
> > +#include <asm/vector.h>
> >
> >  register unsigned long gp_in_global __asm__("gp");
> >
> > @@ -146,12 +147,28 @@ void flush_thread(void)
> >       fstate_off(current, task_pt_regs(current));
> >       memset(&current->thread.fstate, 0, sizeof(current->thread.fstate));
> >  #endif
> > +#ifdef CONFIG_RISCV_ISA_V
> > +     /* Reset vector state */
> > +     riscv_v_vstate_off(task_pt_regs(current));
> > +     kfree(current->thread.vstate.datap);
> > +     memset(&current->thread.vstate, 0, sizeof(struct __riscv_v_ext_state));
> > +#endif
> > +}
> > +
> > +void arch_release_task_struct(struct task_struct *tsk)
> > +{
> > +     /* Free the vector context of datap. */
> > +     if (has_vector())
> > +             kfree(tsk->thread.vstate.datap);
> >  }
> >
> >  int arch_dup_task_struct(struct task_struct *dst, struct task_struct *src)
> >  {
> >       fstate_save(src, task_pt_regs(src));
> >       *dst = *src;
> > +     /* clear entire V context, including datap for a new task */
> > +     memset(&dst->thread.vstate, 0, sizeof(struct __riscv_v_ext_state));
> > +
> >       return 0;
> >  }
> >
> > @@ -184,6 +201,7 @@ int copy_thread(struct task_struct *p, const struct kernel_clone_args *args)
> >               p->thread.s[0] = 0;
> >       }
> >       p->thread.ra = (unsigned long)ret_from_fork;
> > +     riscv_v_vstate_off(childregs);
>
> When is V still on at this point?  If we got here via clone() (or any
> other syscall) it should be off already, so if we need to turn it off
> here then we must have arrived via something that's not a syscall.  I
> don't know what that case is, so it's not clear we can just throw away
> the V state.

I think we should move this call site into the else clause of the
preceding if-else statement. We must clear status.VS for every newly
forked process in order to make the first-use trap work. A parent
process may fork a child after it has obtained a valid V context in
the first-use trap and therefore has status.VS enabled. In that case
all register contents, including status.VS, are copied into the
child's register state at L:196. If we do not explicitly turn V off
after the copy, we break both the first-use trap detection and the
scheduling routines, where the kernel assumes that a valid V context
exists whenever it sees status.VS not being "OFF".

For kernel threads we do not need to make this call, because their
status is always set to SR_PP | SR_PIE.

>
> >       p->thread.sp = (unsigned long)childregs; /* kernel sp */
> >       return 0;
> >  }

Cheers,
Andy
diff mbox series

Patch

diff --git a/arch/riscv/include/asm/processor.h b/arch/riscv/include/asm/processor.h
index 94a0590c6971..f0ddf691ac5e 100644
--- a/arch/riscv/include/asm/processor.h
+++ b/arch/riscv/include/asm/processor.h
@@ -39,6 +39,7 @@  struct thread_struct {
 	unsigned long s[12];	/* s[0]: frame pointer */
 	struct __riscv_d_ext_state fstate;
 	unsigned long bad_cause;
+	struct __riscv_v_ext_state vstate;
 };
 
 /* Whitelist the fstate from the task_struct for hardened usercopy */
diff --git a/arch/riscv/include/asm/switch_to.h b/arch/riscv/include/asm/switch_to.h
index 4b96b13dee27..a727be723c56 100644
--- a/arch/riscv/include/asm/switch_to.h
+++ b/arch/riscv/include/asm/switch_to.h
@@ -8,6 +8,7 @@ 
 
 #include <linux/jump_label.h>
 #include <linux/sched/task_stack.h>
+#include <asm/vector.h>
 #include <asm/hwcap.h>
 #include <asm/processor.h>
 #include <asm/ptrace.h>
@@ -78,6 +79,8 @@  do {							\
 	struct task_struct *__next = (next);		\
 	if (has_fpu())					\
 		__switch_to_fpu(__prev, __next);	\
+	if (has_vector())					\
+		__switch_to_vector(__prev, __next);	\
 	((last) = __switch_to(__prev, __next));		\
 } while (0)
 
diff --git a/arch/riscv/include/asm/thread_info.h b/arch/riscv/include/asm/thread_info.h
index e0d202134b44..97e6f65ec176 100644
--- a/arch/riscv/include/asm/thread_info.h
+++ b/arch/riscv/include/asm/thread_info.h
@@ -81,6 +81,9 @@  struct thread_info {
 	.preempt_count	= INIT_PREEMPT_COUNT,	\
 }
 
+void arch_release_task_struct(struct task_struct *tsk);
+int arch_dup_task_struct(struct task_struct *dst, struct task_struct *src);
+
 #endif /* !__ASSEMBLY__ */
 
 /*
diff --git a/arch/riscv/include/asm/vector.h b/arch/riscv/include/asm/vector.h
index 3c29f4eb552a..ce6a75e9cf62 100644
--- a/arch/riscv/include/asm/vector.h
+++ b/arch/riscv/include/asm/vector.h
@@ -12,6 +12,9 @@ 
 #ifdef CONFIG_RISCV_ISA_V
 
 #include <linux/stringify.h>
+#include <linux/sched.h>
+#include <linux/sched/task_stack.h>
+#include <asm/ptrace.h>
 #include <asm/hwcap.h>
 #include <asm/csr.h>
 #include <asm/asm.h>
@@ -124,6 +127,38 @@  static inline void __riscv_v_vstate_restore(struct __riscv_v_ext_state *restore_
 	riscv_v_disable();
 }
 
+static inline void riscv_v_vstate_save(struct task_struct *task,
+				       struct pt_regs *regs)
+{
+	if ((regs->status & SR_VS) == SR_VS_DIRTY) {
+		struct __riscv_v_ext_state *vstate = &task->thread.vstate;
+
+		__riscv_v_vstate_save(vstate, vstate->datap);
+		__riscv_v_vstate_clean(regs);
+	}
+}
+
+static inline void riscv_v_vstate_restore(struct task_struct *task,
+					  struct pt_regs *regs)
+{
+	if ((regs->status & SR_VS) != SR_VS_OFF) {
+		struct __riscv_v_ext_state *vstate = &task->thread.vstate;
+
+		__riscv_v_vstate_restore(vstate, vstate->datap);
+		__riscv_v_vstate_clean(regs);
+	}
+}
+
+static inline void __switch_to_vector(struct task_struct *prev,
+				      struct task_struct *next)
+{
+	struct pt_regs *regs;
+
+	regs = task_pt_regs(prev);
+	riscv_v_vstate_save(prev, regs);
+	riscv_v_vstate_restore(next, task_pt_regs(next));
+}
+
 #else /* ! CONFIG_RISCV_ISA_V  */
 
 struct pt_regs;
@@ -132,6 +167,9 @@  static inline int riscv_v_setup_vsize(void) { return -EOPNOTSUPP; }
 static __always_inline bool has_vector(void) { return false; }
 static inline bool riscv_v_vstate_query(struct pt_regs *regs) { return false; }
 #define riscv_v_vsize (0)
+#define riscv_v_vstate_save(task, regs)		do {} while (0)
+#define riscv_v_vstate_restore(task, regs)	do {} while (0)
+#define __switch_to_vector(__prev, __next)	do {} while (0)
 #define riscv_v_vstate_off(regs)		do {} while (0)
 #define riscv_v_vstate_on(regs)			do {} while (0)
 
diff --git a/arch/riscv/kernel/process.c b/arch/riscv/kernel/process.c
index e2a060066730..b7a10361ddc6 100644
--- a/arch/riscv/kernel/process.c
+++ b/arch/riscv/kernel/process.c
@@ -24,6 +24,7 @@ 
 #include <asm/switch_to.h>
 #include <asm/thread_info.h>
 #include <asm/cpuidle.h>
+#include <asm/vector.h>
 
 register unsigned long gp_in_global __asm__("gp");
 
@@ -146,12 +147,28 @@  void flush_thread(void)
 	fstate_off(current, task_pt_regs(current));
 	memset(&current->thread.fstate, 0, sizeof(current->thread.fstate));
 #endif
+#ifdef CONFIG_RISCV_ISA_V
+	/* Reset vector state */
+	riscv_v_vstate_off(task_pt_regs(current));
+	kfree(current->thread.vstate.datap);
+	memset(&current->thread.vstate, 0, sizeof(struct __riscv_v_ext_state));
+#endif
+}
+
+void arch_release_task_struct(struct task_struct *tsk)
+{
+	/* Free the vector context of datap. */
+	if (has_vector())
+		kfree(tsk->thread.vstate.datap);
 }
 
 int arch_dup_task_struct(struct task_struct *dst, struct task_struct *src)
 {
 	fstate_save(src, task_pt_regs(src));
 	*dst = *src;
+	/* clear entire V context, including datap for a new task */
+	memset(&dst->thread.vstate, 0, sizeof(struct __riscv_v_ext_state));
+
 	return 0;
 }
 
@@ -184,6 +201,7 @@  int copy_thread(struct task_struct *p, const struct kernel_clone_args *args)
 		p->thread.s[0] = 0;
 	}
 	p->thread.ra = (unsigned long)ret_from_fork;
+	riscv_v_vstate_off(childregs);
 	p->thread.sp = (unsigned long)childregs; /* kernel sp */
 	return 0;
 }