@@ -94,13 +94,13 @@ static void adhoc_sta_free(void *data)
eapol_sm_free(sta->sm);
if (sta->hs_sta)
- handshake_state_free(sta->hs_sta);
+ handshake_state_unref(sta->hs_sta);
if (sta->sm_a)
eapol_sm_free(sta->sm_a);
if (sta->hs_auth)
- handshake_state_free(sta->hs_auth);
+ handshake_state_unref(sta->hs_auth);
end:
l_free(sta);
@@ -230,7 +230,7 @@ static void ap_stop_handshake(struct sta_state *sta)
}
if (sta->hs) {
- handshake_state_free(sta->hs);
+ handshake_state_unref(sta->hs);
sta->hs = NULL;
}
@@ -103,7 +103,14 @@ void __handshake_set_install_ext_tk_func(handshake_install_ext_tk_func_t func)
install_ext_tk = func;
}
-void handshake_state_free(struct handshake_state *s)
+struct handshake_state *handshake_state_ref(struct handshake_state *s)
+{
+ __sync_fetch_and_add(&s->refcount, 1);
+
+ return s;
+}
+
+void handshake_state_unref(struct handshake_state *s)
{
__typeof__(s->free) destroy;
@@ -117,6 +124,9 @@ void handshake_state_free(struct handshake_state *s)
return;
}
+ if (__sync_sub_and_fetch(&s->refcount, 1))
+ return;
+
l_free(s->authenticator_ie);
l_free(s->supplicant_ie);
l_free(s->authenticator_rsnxe);
@@ -170,6 +170,8 @@ struct handshake_state {
bool in_event;
handshake_event_func_t event_func;
+
+ int refcount;
};
#define HSID(x) UNIQUE_ID(handshake_, x)
@@ -186,7 +188,7 @@ struct handshake_state {
##__VA_ARGS__); \
\
if (!HSID(hs)->in_event) { \
- handshake_state_free(HSID(hs)); \
+ handshake_state_unref(HSID(hs)); \
HSID(freed) = true; \
} else \
HSID(hs)->in_event = false; \
@@ -194,7 +196,8 @@ struct handshake_state {
HSID(freed); \
})
-void handshake_state_free(struct handshake_state *s);
+struct handshake_state *handshake_state_ref(struct handshake_state *s);
+void handshake_state_unref(struct handshake_state *s);
void handshake_state_set_supplicant_address(struct handshake_state *s,
const uint8_t *spa);
@@ -316,4 +319,4 @@ void handshake_util_build_gtk_kde(enum crypto_cipher cipher, const uint8_t *key,
void handshake_util_build_igtk_kde(enum crypto_cipher cipher, const uint8_t *key,
unsigned int key_index, uint8_t *to);
-DEFINE_CLEANUP_FUNC(handshake_state_free);
+DEFINE_CLEANUP_FUNC(handshake_state_unref);
@@ -376,6 +376,7 @@ struct handshake_state *netdev_handshake_state_new(struct netdev *netdev)
nhs->super.ifindex = netdev->index;
nhs->super.free = netdev_handshake_state_free;
+ nhs->super.refcount = 1;
nhs->netdev = netdev;
/*
@@ -828,7 +829,7 @@ static void netdev_connect_free(struct netdev *netdev)
eapol_preauth_cancel(netdev->index);
if (netdev->handshake) {
- handshake_state_free(netdev->handshake);
+ handshake_state_unref(netdev->handshake);
netdev->handshake = NULL;
}
@@ -4239,7 +4240,7 @@ int netdev_reassociate(struct netdev *netdev, const struct scan_bss *target_bss,
eapol_sm_free(old_sm);
if (old_hs)
- handshake_state_free(old_hs);
+ handshake_state_unref(old_hs);
return 0;
}
@@ -1497,7 +1497,7 @@ static void p2p_handshake_event(struct handshake_state *hs,
static void p2p_try_connect_group(struct p2p_device *dev)
{
struct scan_bss *bss = dev->conn_wsc_bss;
- _auto_(handshake_state_free) struct handshake_state *hs = NULL;
+ _auto_(handshake_state_unref) struct handshake_state *hs = NULL;
struct iovec ie_iov[16];
int ie_num = 0;
int r;
@@ -1394,7 +1394,7 @@ static struct handshake_state *station_handshake_setup(struct station *station,
return hs;
not_supported:
- handshake_state_free(hs);
+ handshake_state_unref(hs);
return NULL;
}
@@ -2484,7 +2484,7 @@ static void station_preauthenticate_cb(struct netdev *netdev,
}
if (station_transition_reassociate(station, bss, new_hs) < 0) {
- handshake_state_free(new_hs);
+ handshake_state_unref(new_hs);
station_roam_failed(station);
}
}
@@ -2687,7 +2687,7 @@ static bool station_try_next_transition(struct station *station,
}
if (station_transition_reassociate(station, bss, new_hs) < 0) {
- handshake_state_free(new_hs);
+ handshake_state_unref(new_hs);
return false;
}
@@ -3734,7 +3734,7 @@ int __station_connect_network(struct station *station, struct network *network,
station_netdev_event,
station_connect_cb, station);
if (r < 0) {
- handshake_state_free(hs);
+ handshake_state_unref(hs);
return r;
}
@@ -393,7 +393,7 @@ static int wsc_enrollee_connect(struct wsc_enrollee *wsce, struct scan_bss *bss,
return 0;
error:
- handshake_state_free(hs);
+ handshake_state_unref(hs);
return r;
}