@@ -86,7 +86,7 @@
struct hmm {
struct mm_struct *mm;
struct kref kref;
- struct mutex lock;
+ spinlock_t ranges_lock;
struct list_head ranges;
struct list_head mirrors;
struct mmu_notifier mmu_notifier;
@@ -64,7 +64,7 @@ static struct hmm *hmm_get_or_create(struct mm_struct *mm)
init_rwsem(&hmm->mirrors_sem);
hmm->mmu_notifier.ops = NULL;
INIT_LIST_HEAD(&hmm->ranges);
- mutex_init(&hmm->lock);
+ spin_lock_init(&hmm->ranges_lock);
kref_init(&hmm->kref);
hmm->notifiers = 0;
hmm->mm = mm;
@@ -144,6 +144,25 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
hmm_put(hmm);
}
+static void notifiers_decrement(struct hmm *hmm)
+{
+ unsigned long flags;
+
+ spin_lock_irqsave(&hmm->ranges_lock, flags);
+ hmm->notifiers--;
+ if (!hmm->notifiers) {
+ struct hmm_range *range;
+
+ list_for_each_entry(range, &hmm->ranges, list) {
+ if (range->valid)
+ continue;
+ range->valid = true;
+ }
+ wake_up_all(&hmm->wq);
+ }
+ spin_unlock_irqrestore(&hmm->ranges_lock, flags);
+}
+
static int hmm_invalidate_range_start(struct mmu_notifier *mn,
const struct mmu_notifier_range *nrange)
{
@@ -151,6 +170,7 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
struct hmm_mirror *mirror;
struct hmm_update update;
struct hmm_range *range;
+ unsigned long flags;
int ret = 0;
if (!kref_get_unless_zero(&hmm->kref))
@@ -161,12 +181,7 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
update.event = HMM_UPDATE_INVALIDATE;
update.blockable = mmu_notifier_range_blockable(nrange);
- if (mmu_notifier_range_blockable(nrange))
- mutex_lock(&hmm->lock);
- else if (!mutex_trylock(&hmm->lock)) {
- ret = -EAGAIN;
- goto out;
- }
+ spin_lock_irqsave(&hmm->ranges_lock, flags);
hmm->notifiers++;
list_for_each_entry(range, &hmm->ranges, list) {
if (update.end < range->start || update.start >= range->end)
@@ -174,7 +189,7 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
range->valid = false;
}
- mutex_unlock(&hmm->lock);
+ spin_unlock_irqrestore(&hmm->ranges_lock, flags);
if (mmu_notifier_range_blockable(nrange))
down_read(&hmm->mirrors_sem);
@@ -182,16 +197,23 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
ret = -EAGAIN;
goto out;
}
+
list_for_each_entry(mirror, &hmm->mirrors, list) {
- int ret;
+ int rc;
- ret = mirror->ops->sync_cpu_device_pagetables(mirror, &update);
- if (!update.blockable && ret == -EAGAIN)
+ rc = mirror->ops->sync_cpu_device_pagetables(mirror, &update);
+ if (rc) {
+ if (WARN_ON(update.blockable || rc != -EAGAIN))
+ continue;
+ ret = -EAGAIN;
break;
+ }
}
up_read(&hmm->mirrors_sem);
out:
+ if (ret)
+ notifiers_decrement(hmm);
hmm_put(hmm);
return ret;
}
@@ -204,20 +226,7 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
if (!kref_get_unless_zero(&hmm->kref))
return;
- mutex_lock(&hmm->lock);
- hmm->notifiers--;
- if (!hmm->notifiers) {
- struct hmm_range *range;
-
- list_for_each_entry(range, &hmm->ranges, list) {
- if (range->valid)
- continue;
- range->valid = true;
- }
- wake_up_all(&hmm->wq);
- }
- mutex_unlock(&hmm->lock);
-
+ notifiers_decrement(hmm);
hmm_put(hmm);
}
@@ -868,6 +877,7 @@ int hmm_range_register(struct hmm_range *range,
{
unsigned long mask = ((1UL << page_shift) - 1UL);
struct hmm *hmm = mirror->hmm;
+ unsigned long flags;
range->valid = false;
range->hmm = NULL;
@@ -886,7 +896,7 @@ int hmm_range_register(struct hmm_range *range,
return -EFAULT;
/* Initialize range to track CPU page table updates. */
- mutex_lock(&hmm->lock);
+ spin_lock_irqsave(&hmm->ranges_lock, flags);
range->hmm = hmm;
kref_get(&hmm->kref);
@@ -898,7 +908,7 @@ int hmm_range_register(struct hmm_range *range,
*/
if (!hmm->notifiers)
range->valid = true;
- mutex_unlock(&hmm->lock);
+ spin_unlock_irqrestore(&hmm->ranges_lock, flags);
return 0;
}
@@ -914,10 +924,11 @@ EXPORT_SYMBOL(hmm_range_register);
void hmm_range_unregister(struct hmm_range *range)
{
struct hmm *hmm = range->hmm;
+ unsigned long flags;
- mutex_lock(&hmm->lock);
+ spin_lock_irqsave(&hmm->ranges_lock, flags);
list_del_init(&range->list);
- mutex_unlock(&hmm->lock);
+ spin_unlock_irqrestore(&hmm->ranges_lock, flags);
/* Drop reference taken by hmm_range_register() */
mmput(hmm->mm);