Skip to content
This repository was archived by the owner on Feb 1, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions providerquerymanager/providerquerymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,24 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context,
inProgressRequestChan: inProgressRequestChan,
}:
case <-pqm.ctx.Done():
return nil
ch := make(chan peer.ID)
close(ch)
return ch
case <-sessionCtx.Done():
return nil
ch := make(chan peer.ID)
close(ch)
return ch
}

// DO NOT select on sessionCtx. We only want to abort here if we're
// shutting down because we can't actually _cancel_ the request till we
// get to receiveProviders.
var receivedInProgressRequest inProgressRequest
select {
case <-pqm.ctx.Done():
return nil
case <-sessionCtx.Done():
return nil
ch := make(chan peer.ID)
close(ch)
return ch
case receivedInProgressRequest = <-inProgressRequestChan:
}

Expand Down Expand Up @@ -170,7 +177,9 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
case <-pqm.ctx.Done():
return
case <-sessionCtx.Done():
pqm.cancelProviderRequest(k, incomingProviders)
if incomingProviders != nil {
pqm.cancelProviderRequest(k, incomingProviders)
}
return
case provider, ok := <-incomingProviders:
if !ok {
Expand Down Expand Up @@ -228,7 +237,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
wg.Add(1)
go func(p peer.ID) {
defer wg.Done()
err := pqm.network.ConnectTo(pqm.ctx, p)
err := pqm.network.ConnectTo(findProviderCtx, p)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? I can't remember if we did this for a reason.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Historical note: this was not correct. See #226.

if err != nil {
log.Debugf("failed to connect to provider %s: %s", p, err)
return
Expand Down Expand Up @@ -397,12 +406,12 @@ func (crm *cancelRequestMessage) debugMessage() string {
func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) {
requestStatus, ok := pqm.inProgressRequestStatuses[crm.k]
if !ok {
log.Errorf("Attempt to cancel request for cid (%s) not in progress", crm.k.String())
// Request finished while queued.
return
}
_, ok = requestStatus.listeners[crm.incomingProviders]
if !ok {
log.Errorf("Attempt to cancel request for for cid (%s) this is not a listener", crm.k.String())
// Request finished and _restarted_ while queued.
return
}
delete(requestStatus.listeners, crm.incomingProviders)
Expand Down
57 changes: 57 additions & 0 deletions providerquerymanager/providerquerymanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,60 @@ func TestFindProviderTimeout(t *testing.T) {
t.Fatal("Find provider request should have timed out, did not")
}
}

func TestFindProviderPreCanceled(t *testing.T) {
peers := testutil.GeneratePeers(10)
fpn := &fakeProviderNetwork{
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := New(ctx, fpn)
providerQueryManager.Startup()
providerQueryManager.SetFindProviderTimeout(100 * time.Millisecond)
keys := testutil.GenerateCids(1)

sessionCtx, cancel := context.WithCancel(ctx)
cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0])
if firstRequestChan == nil {
t.Fatal("expected non-nil channel")
}
select {
case <-firstRequestChan:
case <-time.After(10 * time.Millisecond):
t.Fatal("shouldn't have blocked waiting on a closed context")
}
}

func TestCancelFindProvidersAfterCompletion(t *testing.T) {
peers := testutil.GeneratePeers(2)
fpn := &fakeProviderNetwork{
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := New(ctx, fpn)
providerQueryManager.Startup()
providerQueryManager.SetFindProviderTimeout(100 * time.Millisecond)
keys := testutil.GenerateCids(1)

sessionCtx, cancel := context.WithCancel(ctx)
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0])
<-firstRequestChan // wait for everything to start.
time.Sleep(10 * time.Millisecond) // wait for the incoming providres to stop.
cancel() // cancel the context.

timer := time.NewTimer(10 * time.Millisecond)
defer timer.Stop()
for {
select {
case _, ok := <-firstRequestChan:
if !ok {
return
}
case <-timer.C:
t.Fatal("should have finished receiving responses within timeout")
}
}
}