@@ -86,7 +86,7 @@
struct hmm {
struct mm_struct *mm;
struct kref kref;
- struct mutex lock;
+ spinlock_t ranges_lock;
struct list_head ranges;
struct list_head mirrors;
struct mmu_notifier mmu_notifier;
@@ -67,7 +67,7 @@ static struct hmm *hmm_get_or_create(struct mm_struct *mm)
init_rwsem(&hmm->mirrors_sem);
hmm->mmu_notifier.ops = NULL;
INIT_LIST_HEAD(&hmm->ranges);
- mutex_init(&hmm->lock);
+ spin_lock_init(&hmm->ranges_lock);
kref_init(&hmm->kref);
hmm->notifiers = 0;
hmm->mm = mm;
@@ -124,18 +124,19 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
struct hmm_mirror *mirror;
+ unsigned long flags;
/* Bail out if hmm is in the process of being freed */
if (!kref_get_unless_zero(&hmm->kref))
return;
- mutex_lock(&hmm->lock);
+ spin_lock_irqsave(&hmm->ranges_lock, flags);
/*
* Since hmm_range_register() holds a mmget() reference, hmm_release()
* cannot run as long as any range is registered.
*/
WARN_ON(!list_empty(&hmm->ranges));
- mutex_unlock(&hmm->lock);
+ spin_unlock_irqrestore(&hmm->ranges_lock, flags);
down_read(&hmm->mirrors_sem);
list_for_each_entry(mirror, &hmm->mirrors, list) {
@@ -151,6 +152,23 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
hmm_put(hmm);
}
+static void notifiers_decrement(struct hmm *hmm)
+{
+ lockdep_assert_held(&hmm->ranges_lock);
+
+ hmm->notifiers--;
+ if (!hmm->notifiers) {
+ struct hmm_range *range;
+
+ list_for_each_entry(range, &hmm->ranges, list) {
+ if (range->valid)
+ continue;
+ range->valid = true;
+ }
+ wake_up_all(&hmm->wq);
+ }
+}
+
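Pulling the decrement into notifiers_decrement() keeps the wake-up logic in one place: when the last in-flight notifier finishes, every affected range is revalidated and all sleepers on hmm->wq are woken. On the driver side that wake-up is normally consumed through hmm_range_wait_until_valid(); a minimal sketch of the pattern (my_range and MY_TIMEOUT_MS are invented names, not part of this patch):

	/*
	 * Hypothetical driver-side wait: sleeps on hmm->wq until
	 * notifiers_decrement() sets range->valid and calls wake_up_all().
	 */
	if (!hmm_range_wait_until_valid(&my_range, MY_TIMEOUT_MS))
		return -EBUSY;	/* an invalidation is still in flight */
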
static int hmm_invalidate_range_start(struct mmu_notifier *mn,
const struct mmu_notifier_range *nrange)
{
@@ -158,6 +176,7 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
struct hmm_mirror *mirror;
struct hmm_update update;
struct hmm_range *range;
+ unsigned long flags;
int ret = 0;
if (!kref_get_unless_zero(&hmm->kref))
@@ -168,12 +187,7 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
update.event = HMM_UPDATE_INVALIDATE;
update.blockable = mmu_notifier_range_blockable(nrange);
- if (mmu_notifier_range_blockable(nrange))
- mutex_lock(&hmm->lock);
- else if (!mutex_trylock(&hmm->lock)) {
- ret = -EAGAIN;
- goto out;
- }
+ spin_lock_irqsave(&hmm->ranges_lock, flags);
hmm->notifiers++;
list_for_each_entry(range, &hmm->ranges, list) {
if (update.end < range->start || update.start >= range->end)
@@ -181,7 +195,7 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
range->valid = false;
}
- mutex_unlock(&hmm->lock);
+ spin_unlock_irqrestore(&hmm->ranges_lock, flags);
if (mmu_notifier_range_blockable(nrange))
down_read(&hmm->mirrors_sem);
@@ -189,16 +203,26 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
ret = -EAGAIN;
goto out;
}
+
list_for_each_entry(mirror, &hmm->mirrors, list) {
- int ret;
-
- ret = mirror->ops->sync_cpu_device_pagetables(mirror, &update);
- if (!update.blockable && ret == -EAGAIN)
+ int rc;
+
+ rc = mirror->ops->sync_cpu_device_pagetables(mirror, &update);
+ if (rc) {
+ if (WARN_ON(update.blockable || rc != -EAGAIN))
+ continue;
+ ret = -EAGAIN;
break;
+ }
}
up_read(&hmm->mirrors_sem);
out:
+ if (ret) {
+ spin_lock_irqsave(&hmm->ranges_lock, flags);
+ notifiers_decrement(hmm);
+ spin_unlock_irqrestore(&hmm->ranges_lock, flags);
+ }
hmm_put(hmm);
return ret;
}
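The reworked mirror loop turns what used to be a silent failure into an enforced contract: sync_cpu_device_pagetables() may only fail with -EAGAIN, and only when the invalidation is non-blockable; anything else trips the WARN_ON and is skipped. A sketch of a callback honoring that contract (all my_* names are hypothetical, assuming the struct hmm_update of this series):

	/* Hypothetical mirror callback obeying the blockable/-EAGAIN rule. */
	static int my_sync_cpu_device_pagetables(struct hmm_mirror *mirror,
						 const struct hmm_update *update)
	{
		struct my_dev *dev = container_of(mirror, struct my_dev, mirror);

		if (update->blockable)
			mutex_lock(&dev->pt_lock);
		else if (!mutex_trylock(&dev->pt_lock))
			return -EAGAIN;	/* the only legal failure */

		/* Shoot down device mappings covering [start, end). */
		my_invalidate_device_ptes(dev, update->start, update->end);
		mutex_unlock(&dev->pt_lock);
		return 0;
	}
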
@@ -207,23 +231,14 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
const struct mmu_notifier_range *nrange)
{
struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
+ unsigned long flags;
if (!kref_get_unless_zero(&hmm->kref))
return;
- mutex_lock(&hmm->lock);
- hmm->notifiers--;
- if (!hmm->notifiers) {
- struct hmm_range *range;
-
- list_for_each_entry(range, &hmm->ranges, list) {
- if (range->valid)
- continue;
- range->valid = true;
- }
- wake_up_all(&hmm->wq);
- }
- mutex_unlock(&hmm->lock);
+ spin_lock_irqsave(&hmm->ranges_lock, flags);
+ notifiers_decrement(hmm);
+ spin_unlock_irqrestore(&hmm->ranges_lock, flags);
hmm_put(hmm);
}
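For orientation: all three callbacks touched here hang off the same mmu_notifier ops table in mm/hmm.c, so each invalidate_range_start/end pair brackets exactly one increment and one decrement of hmm->notifiers (as the table stood at the time of this change):

	static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
		.release		= hmm_release,
		.invalidate_range_start	= hmm_invalidate_range_start,
		.invalidate_range_end	= hmm_invalidate_range_end,
	};
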
@@ -876,6 +891,7 @@ int hmm_range_register(struct hmm_range *range,
{
unsigned long mask = ((1UL << page_shift) - 1UL);
struct hmm *hmm = mirror->hmm;
+ unsigned long flags;
range->valid = false;
range->hmm = NULL;
@@ -894,7 +910,7 @@ int hmm_range_register(struct hmm_range *range,
return -EFAULT;
/* Initialize range to track CPU page table updates. */
- mutex_lock(&hmm->lock);
+ spin_lock_irqsave(&hmm->ranges_lock, flags);
range->hmm = hmm;
kref_get(&hmm->kref);
@@ -906,7 +922,7 @@ int hmm_range_register(struct hmm_range *range,
*/
if (!hmm->notifiers)
range->valid = true;
- mutex_unlock(&hmm->lock);
+ spin_unlock_irqrestore(&hmm->ranges_lock, flags);
return 0;
}
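Registration now publishes the range and samples hmm->notifiers under the same spinlock, which is what keeps the valid/wake-up handshake race-free. A rough caller-side sketch (my_mirror and MY_TIMEOUT_MS are hypothetical; hmm_range_register() signature as used in this series):

	/* Hypothetical caller: register, wait for a quiescent view, use it. */
	struct hmm_range range = {};
	int ret;

	ret = hmm_range_register(&range, &my_mirror, start, end, PAGE_SHIFT);
	if (ret)
		return ret;
	if (!hmm_range_wait_until_valid(&range, MY_TIMEOUT_MS)) {
		hmm_range_unregister(&range);
		return -EBUSY;
	}
	/* ... snapshot/fault the pages, e.g. via hmm_range_fault() ... */
	hmm_range_unregister(&range);
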
@@ -922,10 +938,11 @@ EXPORT_SYMBOL(hmm_range_register);
void hmm_range_unregister(struct hmm_range *range)
{
struct hmm *hmm = range->hmm;
+ unsigned long flags;
- mutex_lock(&hmm->lock);
+ spin_lock_irqsave(&hmm->ranges_lock, flags);
list_del(&range->list);
- mutex_unlock(&hmm->lock);
+ spin_unlock_irqrestore(&hmm->ranges_lock, flags);
/* Drop reference taken by hmm_range_register() */
mmput(hmm->mm);