@@ -132,6 +132,9 @@ static int __set_sregs2(struct kvm_vcpu *vcpu, struct kvm_sregs2 *sregs2);
static void __get_sregs2(struct kvm_vcpu *vcpu, struct kvm_sregs2 *sregs2);
static DEFINE_MUTEX(vendor_module_lock);
+static void kvm_load_guest_fpu(struct kvm_vcpu *vcpu);
+static void kvm_put_guest_fpu(struct kvm_vcpu *vcpu);
+
struct kvm_x86_ops kvm_x86_ops __read_mostly;
#define KVM_X86_OP(func) \
@@ -4346,6 +4349,21 @@ int kvm_get_msr_common(struct kvm_vcpu *vcpu, struct msr_data *msr_info)
}
EXPORT_SYMBOL_GPL(kvm_get_msr_common);
+static const u32 xstate_msrs[] = {
+ MSR_IA32_U_CET, MSR_IA32_PL3_SSP,
+};
+
+static bool is_xstate_msr(u32 index)
+{
+ int i;
+
+ for (i = 0; i < ARRAY_SIZE(xstate_msrs); i++) {
+ if (index == xstate_msrs[i])
+ return true;
+ }
+ return false;
+}
+
/*
* Read or write a bunch of msrs. All parameters are kernel addresses.
*
@@ -4356,11 +4374,20 @@ static int __msr_io(struct kvm_vcpu *vcpu, struct kvm_msrs *msrs,
int (*do_msr)(struct kvm_vcpu *vcpu,
unsigned index, u64 *data))
{
+ bool fpu_loaded = false;
int i;
- for (i = 0; i < msrs->nmsrs; ++i)
+ for (i = 0; i < msrs->nmsrs; ++i) {
+ if (vcpu && !fpu_loaded && kvm_caps.supported_xss &&
+ is_xstate_msr(entries[i].index)) {
+ kvm_load_guest_fpu(vcpu);
+ fpu_loaded = true;
+ }
if (do_msr(vcpu, entries[i].index, &entries[i].data))
break;
+ }
+ if (fpu_loaded)
+ kvm_put_guest_fpu(vcpu);
return i;
}