Skip to content

Commit b5de0c6

Browse files
committed
RDMA/cma: Fix use after free race in roce multicast join
The roce path triggers a work queue that continues to touch the id_priv but doesn't hold any reference on it. Further, unlike in the IB case, the work queue is not fenced during rdma_destroy_id(). This can trigger a use after free if a destroy is triggered in the incredibly narrow window after the queue_work and the work starting and obtaining the handler_mutex. The only purpose of this work queue is to run the ULP event callback from the standard context, so switch the design to use the existing cma_work_handler() scheme. This simplifies quite a lot of the flow: - Use the cma_work_handler() callback to launch the work for roce. This requires generating the event synchronously inside the rdma_join_multicast(), which in turn means the dummy struct ib_sa_multicast can become a simple stack variable. - cma_work_handler() used the id_priv kref, so we can entirely eliminate the kref inside struct cma_multicast. Since the cma_multicast never leaks into an unprotected work queue the kfree can be done at the same time as for IB. - Eliminating the general multicast.ib requires using cma_set_mgid() in a few places to recompute the mgid. Fixes: 3c86aa7 ("RDMA/cm: Add RDMA CM support for IBoE devices") Link: https://lore.kernel.org/r/[email protected] Signed-off-by: Leon Romanovsky <[email protected]> Signed-off-by: Jason Gunthorpe <[email protected]>
1 parent 3788d29 commit b5de0c6

File tree

1 file changed

+88
-108
lines changed
  • drivers/infiniband/core

1 file changed

+88
-108
lines changed

drivers/infiniband/core/cma.c

Lines changed: 88 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ static const char * const cma_events[] = {
6868
[RDMA_CM_EVENT_TIMEWAIT_EXIT] = "timewait exit",
6969
};
7070

71+
static void cma_set_mgid(struct rdma_id_private *id_priv, struct sockaddr *addr,
72+
union ib_gid *mgid);
73+
7174
const char *__attribute_const__ rdma_event_msg(enum rdma_cm_event_type event)
7275
{
7376
size_t index = event;
@@ -345,13 +348,10 @@ struct ib_device *cma_get_ib_dev(struct cma_device *cma_dev)
345348

346349
struct cma_multicast {
347350
struct rdma_id_private *id_priv;
348-
union {
349-
struct ib_sa_multicast *ib;
350-
} multicast;
351+
struct ib_sa_multicast *sa_mc;
351352
struct list_head list;
352353
void *context;
353354
struct sockaddr_storage addr;
354-
struct kref mcref;
355355
u8 join_state;
356356
};
357357

@@ -363,12 +363,6 @@ struct cma_work {
363363
struct rdma_cm_event event;
364364
};
365365

366-
struct iboe_mcast_work {
367-
struct work_struct work;
368-
struct rdma_id_private *id;
369-
struct cma_multicast *mc;
370-
};
371-
372366
union cma_ip_addr {
373367
struct in6_addr ip6;
374368
struct {
@@ -475,14 +469,6 @@ static void cma_attach_to_dev(struct rdma_id_private *id_priv,
475469
rdma_start_port(cma_dev->device)];
476470
}
477471

478-
static inline void release_mc(struct kref *kref)
479-
{
480-
struct cma_multicast *mc = container_of(kref, struct cma_multicast, mcref);
481-
482-
kfree(mc->multicast.ib);
483-
kfree(mc);
484-
}
485-
486472
static void cma_release_dev(struct rdma_id_private *id_priv)
487473
{
488474
mutex_lock(&lock);
@@ -1778,14 +1764,10 @@ static void cma_release_port(struct rdma_id_private *id_priv)
17781764
static void destroy_mc(struct rdma_id_private *id_priv,
17791765
struct cma_multicast *mc)
17801766
{
1781-
if (rdma_cap_ib_mcast(id_priv->id.device, id_priv->id.port_num)) {
1782-
ib_sa_free_multicast(mc->multicast.ib);
1783-
kfree(mc);
1784-
return;
1785-
}
1767+
if (rdma_cap_ib_mcast(id_priv->id.device, id_priv->id.port_num))
1768+
ib_sa_free_multicast(mc->sa_mc);
17861769

1787-
if (rdma_protocol_roce(id_priv->id.device,
1788-
id_priv->id.port_num)) {
1770+
if (rdma_protocol_roce(id_priv->id.device, id_priv->id.port_num)) {
17891771
struct rdma_dev_addr *dev_addr =
17901772
&id_priv->id.route.addr.dev_addr;
17911773
struct net_device *ndev = NULL;
@@ -1794,11 +1776,15 @@ static void destroy_mc(struct rdma_id_private *id_priv,
17941776
ndev = dev_get_by_index(dev_addr->net,
17951777
dev_addr->bound_dev_if);
17961778
if (ndev) {
1797-
cma_igmp_send(ndev, &mc->multicast.ib->rec.mgid, false);
1779+
union ib_gid mgid;
1780+
1781+
cma_set_mgid(id_priv, (struct sockaddr *)&mc->addr,
1782+
&mgid);
1783+
cma_igmp_send(ndev, &mgid, false);
17981784
dev_put(ndev);
17991785
}
1800-
kref_put(&mc->mcref, release_mc);
18011786
}
1787+
kfree(mc);
18021788
}
18031789

18041790
static void cma_leave_mc_groups(struct rdma_id_private *id_priv)
@@ -2664,6 +2650,8 @@ static void cma_work_handler(struct work_struct *_work)
26642650
mutex_unlock(&id_priv->handler_mutex);
26652651
cma_id_put(id_priv);
26662652
out_free:
2653+
if (work->event.event == RDMA_CM_EVENT_MULTICAST_JOIN)
2654+
rdma_destroy_ah_attr(&work->event.param.ud.ah_attr);
26672655
kfree(work);
26682656
}
26692657

@@ -4324,53 +4312,66 @@ int rdma_disconnect(struct rdma_cm_id *id)
43244312
}
43254313
EXPORT_SYMBOL(rdma_disconnect);
43264314

4315+
static void cma_make_mc_event(int status, struct rdma_id_private *id_priv,
4316+
struct ib_sa_multicast *multicast,
4317+
struct rdma_cm_event *event,
4318+
struct cma_multicast *mc)
4319+
{
4320+
struct rdma_dev_addr *dev_addr;
4321+
enum ib_gid_type gid_type;
4322+
struct net_device *ndev;
4323+
4324+
if (!status)
4325+
status = cma_set_qkey(id_priv, be32_to_cpu(multicast->rec.qkey));
4326+
else
4327+
pr_debug_ratelimited("RDMA CM: MULTICAST_ERROR: failed to join multicast. status %d\n",
4328+
status);
4329+
4330+
event->status = status;
4331+
event->param.ud.private_data = mc->context;
4332+
if (status) {
4333+
event->event = RDMA_CM_EVENT_MULTICAST_ERROR;
4334+
return;
4335+
}
4336+
4337+
dev_addr = &id_priv->id.route.addr.dev_addr;
4338+
ndev = dev_get_by_index(dev_addr->net, dev_addr->bound_dev_if);
4339+
gid_type =
4340+
id_priv->cma_dev
4341+
->default_gid_type[id_priv->id.port_num -
4342+
rdma_start_port(
4343+
id_priv->cma_dev->device)];
4344+
4345+
event->event = RDMA_CM_EVENT_MULTICAST_JOIN;
4346+
if (ib_init_ah_from_mcmember(id_priv->id.device, id_priv->id.port_num,
4347+
&multicast->rec, ndev, gid_type,
4348+
&event->param.ud.ah_attr)) {
4349+
event->event = RDMA_CM_EVENT_MULTICAST_ERROR;
4350+
goto out;
4351+
}
4352+
4353+
event->param.ud.qp_num = 0xFFFFFF;
4354+
event->param.ud.qkey = be32_to_cpu(multicast->rec.qkey);
4355+
4356+
out:
4357+
if (ndev)
4358+
dev_put(ndev);
4359+
}
4360+
43274361
static int cma_ib_mc_handler(int status, struct ib_sa_multicast *multicast)
43284362
{
4329-
struct rdma_id_private *id_priv;
43304363
struct cma_multicast *mc = multicast->context;
4364+
struct rdma_id_private *id_priv = mc->id_priv;
43314365
struct rdma_cm_event event = {};
43324366
int ret = 0;
43334367

4334-
id_priv = mc->id_priv;
43354368
mutex_lock(&id_priv->handler_mutex);
43364369
if (READ_ONCE(id_priv->state) == RDMA_CM_DEVICE_REMOVAL ||
43374370
READ_ONCE(id_priv->state) == RDMA_CM_DESTROYING)
43384371
goto out;
43394372

4340-
if (!status)
4341-
status = cma_set_qkey(id_priv, be32_to_cpu(multicast->rec.qkey));
4342-
else
4343-
pr_debug_ratelimited("RDMA CM: MULTICAST_ERROR: failed to join multicast. status %d\n",
4344-
status);
4345-
event.status = status;
4346-
event.param.ud.private_data = mc->context;
4347-
if (!status) {
4348-
struct rdma_dev_addr *dev_addr =
4349-
&id_priv->id.route.addr.dev_addr;
4350-
struct net_device *ndev =
4351-
dev_get_by_index(dev_addr->net, dev_addr->bound_dev_if);
4352-
enum ib_gid_type gid_type =
4353-
id_priv->cma_dev->default_gid_type[id_priv->id.port_num -
4354-
rdma_start_port(id_priv->cma_dev->device)];
4355-
4356-
event.event = RDMA_CM_EVENT_MULTICAST_JOIN;
4357-
ret = ib_init_ah_from_mcmember(id_priv->id.device,
4358-
id_priv->id.port_num,
4359-
&multicast->rec,
4360-
ndev, gid_type,
4361-
&event.param.ud.ah_attr);
4362-
if (ret)
4363-
event.event = RDMA_CM_EVENT_MULTICAST_ERROR;
4364-
4365-
event.param.ud.qp_num = 0xFFFFFF;
4366-
event.param.ud.qkey = be32_to_cpu(multicast->rec.qkey);
4367-
if (ndev)
4368-
dev_put(ndev);
4369-
} else
4370-
event.event = RDMA_CM_EVENT_MULTICAST_ERROR;
4371-
4373+
cma_make_mc_event(status, id_priv, multicast, &event, mc);
43724374
ret = cma_cm_event_handler(id_priv, &event);
4373-
43744375
rdma_destroy_ah_attr(&event.param.ud.ah_attr);
43754376
if (ret) {
43764377
destroy_id_handler_unlock(id_priv);
@@ -4460,23 +4461,10 @@ static int cma_join_ib_multicast(struct rdma_id_private *id_priv,
44604461
IB_SA_MCMEMBER_REC_MTU |
44614462
IB_SA_MCMEMBER_REC_HOP_LIMIT;
44624463

4463-
mc->multicast.ib = ib_sa_join_multicast(&sa_client, id_priv->id.device,
4464-
id_priv->id.port_num, &rec,
4465-
comp_mask, GFP_KERNEL,
4466-
cma_ib_mc_handler, mc);
4467-
return PTR_ERR_OR_ZERO(mc->multicast.ib);
4468-
}
4469-
4470-
static void iboe_mcast_work_handler(struct work_struct *work)
4471-
{
4472-
struct iboe_mcast_work *mw = container_of(work, struct iboe_mcast_work, work);
4473-
struct cma_multicast *mc = mw->mc;
4474-
struct ib_sa_multicast *m = mc->multicast.ib;
4475-
4476-
mc->multicast.ib->context = mc;
4477-
cma_ib_mc_handler(0, m);
4478-
kref_put(&mc->mcref, release_mc);
4479-
kfree(mw);
4464+
mc->sa_mc = ib_sa_join_multicast(&sa_client, id_priv->id.device,
4465+
id_priv->id.port_num, &rec, comp_mask,
4466+
GFP_KERNEL, cma_ib_mc_handler, mc);
4467+
return PTR_ERR_OR_ZERO(mc->sa_mc);
44804468
}
44814469

44824470
static void cma_iboe_set_mgid(struct sockaddr *addr, union ib_gid *mgid,
@@ -4511,52 +4499,47 @@ static void cma_iboe_set_mgid(struct sockaddr *addr, union ib_gid *mgid,
45114499
static int cma_iboe_join_multicast(struct rdma_id_private *id_priv,
45124500
struct cma_multicast *mc)
45134501
{
4514-
struct iboe_mcast_work *work;
4502+
struct cma_work *work;
45154503
struct rdma_dev_addr *dev_addr = &id_priv->id.route.addr.dev_addr;
45164504
int err = 0;
45174505
struct sockaddr *addr = (struct sockaddr *)&mc->addr;
45184506
struct net_device *ndev = NULL;
4507+
struct ib_sa_multicast ib;
45194508
enum ib_gid_type gid_type;
45204509
bool send_only;
45214510

45224511
send_only = mc->join_state == BIT(SENDONLY_FULLMEMBER_JOIN);
45234512

4524-
if (cma_zero_addr((struct sockaddr *)&mc->addr))
4513+
if (cma_zero_addr(addr))
45254514
return -EINVAL;
45264515

45274516
work = kzalloc(sizeof *work, GFP_KERNEL);
45284517
if (!work)
45294518
return -ENOMEM;
45304519

4531-
mc->multicast.ib = kzalloc(sizeof(struct ib_sa_multicast), GFP_KERNEL);
4532-
if (!mc->multicast.ib) {
4533-
err = -ENOMEM;
4534-
goto out1;
4535-
}
4536-
45374520
gid_type = id_priv->cma_dev->default_gid_type[id_priv->id.port_num -
45384521
rdma_start_port(id_priv->cma_dev->device)];
4539-
cma_iboe_set_mgid(addr, &mc->multicast.ib->rec.mgid, gid_type);
4522+
cma_iboe_set_mgid(addr, &ib.rec.mgid, gid_type);
45404523

4541-
mc->multicast.ib->rec.pkey = cpu_to_be16(0xffff);
4524+
ib.rec.pkey = cpu_to_be16(0xffff);
45424525
if (id_priv->id.ps == RDMA_PS_UDP)
4543-
mc->multicast.ib->rec.qkey = cpu_to_be32(RDMA_UDP_QKEY);
4526+
ib.rec.qkey = cpu_to_be32(RDMA_UDP_QKEY);
45444527

45454528
if (dev_addr->bound_dev_if)
45464529
ndev = dev_get_by_index(dev_addr->net, dev_addr->bound_dev_if);
45474530
if (!ndev) {
45484531
err = -ENODEV;
4549-
goto out2;
4532+
goto err_free;
45504533
}
4551-
mc->multicast.ib->rec.rate = iboe_get_rate(ndev);
4552-
mc->multicast.ib->rec.hop_limit = 1;
4553-
mc->multicast.ib->rec.mtu = iboe_get_mtu(ndev->mtu);
4534+
ib.rec.rate = iboe_get_rate(ndev);
4535+
ib.rec.hop_limit = 1;
4536+
ib.rec.mtu = iboe_get_mtu(ndev->mtu);
45544537

45554538
if (addr->sa_family == AF_INET) {
45564539
if (gid_type == IB_GID_TYPE_ROCE_UDP_ENCAP) {
4557-
mc->multicast.ib->rec.hop_limit = IPV6_DEFAULT_HOPLIMIT;
4540+
ib.rec.hop_limit = IPV6_DEFAULT_HOPLIMIT;
45584541
if (!send_only) {
4559-
err = cma_igmp_send(ndev, &mc->multicast.ib->rec.mgid,
4542+
err = cma_igmp_send(ndev, &ib.rec.mgid,
45604543
true);
45614544
}
45624545
}
@@ -4565,24 +4548,22 @@ static int cma_iboe_join_multicast(struct rdma_id_private *id_priv,
45654548
err = -ENOTSUPP;
45664549
}
45674550
dev_put(ndev);
4568-
if (err || !mc->multicast.ib->rec.mtu) {
4551+
if (err || !ib.rec.mtu) {
45694552
if (!err)
45704553
err = -EINVAL;
4571-
goto out2;
4554+
goto err_free;
45724555
}
45734556
rdma_ip2gid((struct sockaddr *)&id_priv->id.route.addr.src_addr,
4574-
&mc->multicast.ib->rec.port_gid);
4557+
&ib.rec.port_gid);
45754558
work->id = id_priv;
4576-
work->mc = mc;
4577-
INIT_WORK(&work->work, iboe_mcast_work_handler);
4578-
kref_get(&mc->mcref);
4559+
INIT_WORK(&work->work, cma_work_handler);
4560+
cma_make_mc_event(0, id_priv, &ib, &work->event, mc);
4561+
/* Balances with cma_id_put() in cma_work_handler */
4562+
cma_id_get(id_priv);
45794563
queue_work(cma_wq, &work->work);
4580-
45814564
return 0;
45824565

4583-
out2:
4584-
kfree(mc->multicast.ib);
4585-
out1:
4566+
err_free:
45864567
kfree(work);
45874568
return err;
45884569
}
@@ -4604,7 +4585,7 @@ int rdma_join_multicast(struct rdma_cm_id *id, struct sockaddr *addr,
46044585
READ_ONCE(id_priv->state) != RDMA_CM_ADDR_RESOLVED))
46054586
return -EINVAL;
46064587

4607-
mc = kmalloc(sizeof *mc, GFP_KERNEL);
4588+
mc = kzalloc(sizeof(*mc), GFP_KERNEL);
46084589
if (!mc)
46094590
return -ENOMEM;
46104591

@@ -4614,7 +4595,6 @@ int rdma_join_multicast(struct rdma_cm_id *id, struct sockaddr *addr,
46144595
mc->join_state = join_state;
46154596

46164597
if (rdma_protocol_roce(id->device, id->port_num)) {
4617-
kref_init(&mc->mcref);
46184598
ret = cma_iboe_join_multicast(id_priv, mc);
46194599
if (ret)
46204600
goto out_err;

0 commit comments

Comments
 (0)