Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions channels/channels.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ func (c *Channels) CreateNew(selfPeer peer.ID, tid datatransfer.TransferID, base
log.Errorw("failed to create new tracking channel for data-transfer", "channelID", chid, "err", err)
return datatransfer.ChannelID{}, err
}
log.Debugw("created tracking channel for data-transfer, emitting channel Open event", "channelID", chid)
return chid, c.stateMachines.Send(chid, datatransfer.Open)
log.Debugw("created tracking channel for data-transfer", "channelID", chid)
return chid, nil
}

// InProgress returns a list of in progress channels
Expand Down Expand Up @@ -169,6 +169,10 @@ func (c *Channels) GetByID(ctx context.Context, chid datatransfer.ChannelID) (da
return c.fromInternalChannelState(internalChannel), nil
}

func (c *Channels) Open(chid datatransfer.ChannelID) error {
return c.send(chid, datatransfer.Open)
}

// Accept marks a data transfer as accepted
func (c *Channels) Accept(chid datatransfer.ChannelID) error {
return c.send(chid, datatransfer.Accept)
Expand Down
11 changes: 9 additions & 2 deletions channels/channels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func TestChannels(t *testing.T) {
t.Run("adding channels", func(t *testing.T) {
chid, err := channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[0], peers[0], peers[1])
require.NoError(t, err)
require.NoError(t, channelList.Open(chid))
require.Equal(t, peers[0], chid.Initiator)
require.Equal(t, tid1, chid.ID)

Expand All @@ -65,6 +66,7 @@ func TestChannels(t *testing.T) {
// can add for different id
chid, err = channelList.CreateNew(peers[2], tid2, cids[1], selector, fv2, peers[3], peers[2], peers[3])
require.NoError(t, err)
require.NoError(t, channelList.Open(chid))
require.Equal(t, peers[3], chid.Initiator)
require.Equal(t, tid2, chid.ID)
state = checkEvent(ctx, t, received, datatransfer.Open)
Expand Down Expand Up @@ -139,6 +141,7 @@ func TestChannels(t *testing.T) {

chid, err := channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[0], peers[0], peers[1])
require.NoError(t, err)
require.NoError(t, channelList.Open(chid))
checkEvent(ctx, t, received, datatransfer.Open)
require.NoError(t, channelList.Accept(chid))
checkEvent(ctx, t, received, datatransfer.Accept)
Expand Down Expand Up @@ -169,8 +172,9 @@ func TestChannels(t *testing.T) {
err = channelList.Start(ctx)
require.NoError(t, err)

_, err = channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[0], peers[0], peers[1])
chid, err := channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[0], peers[0], peers[1])
require.NoError(t, err)
require.NoError(t, channelList.Open(chid))
state := checkEvent(ctx, t, received, datatransfer.Open)
require.Equal(t, datatransfer.Requested, state.Status())
require.Equal(t, uint64(0), state.Received())
Expand Down Expand Up @@ -235,8 +239,9 @@ func TestChannels(t *testing.T) {
err = channelList.Start(ctx)
require.NoError(t, err)

_, err = channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[1], peers[0], peers[1])
chid, err := channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[1], peers[0], peers[1])
require.NoError(t, err)
require.NoError(t, channelList.Open(chid))
state := checkEvent(ctx, t, received, datatransfer.Open)

err = channelList.TransferInitiated(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1})
Expand Down Expand Up @@ -381,6 +386,7 @@ func TestChannels(t *testing.T) {

chid, err := channelList.CreateNew(peers[0], tid2, cids[1], selector, fv2, peers[2], peers[1], peers[2])
require.NoError(t, err)
require.NoError(t, channelList.Open(chid))
require.Equal(t, peers[2], chid.Initiator)
require.Equal(t, tid2, chid.ID)
state = checkEvent(ctx, t, received, datatransfer.Open)
Expand Down Expand Up @@ -425,6 +431,7 @@ func TestChannels(t *testing.T) {

chid, err := channelList.CreateNew(peers[3], tid1, cids[0], selector, fv1, peers[3], peers[0], peers[3])
require.NoError(t, err)
require.NoError(t, channelList.Open(chid))
state := checkEvent(ctx, t, received, datatransfer.Open)
require.Equal(t, datatransfer.Requested, state.Status())

Expand Down
51 changes: 51 additions & 0 deletions channelsubscriptions/channelsubscriptions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package channelsubscriptions

import (
"sync"

datatransfer "github.com/filecoin-project/go-data-transfer/v2"
"github.com/filecoin-project/go-data-transfer/v2/channels"
)

type ChannelSubscriptions struct {
subscriptions map[datatransfer.ChannelID][]datatransfer.Subscriber
subscriptionsLk sync.RWMutex
unsub datatransfer.Unsubscribe
}

type SubscriptionAPI interface {
SubscribeToEvents(subscriber datatransfer.Subscriber) datatransfer.Unsubscribe
}

func NewChannelSubscriptions(subscriptionAPI SubscriptionAPI) *ChannelSubscriptions {
cs := &ChannelSubscriptions{
subscriptions: make(map[datatransfer.ChannelID][]datatransfer.Subscriber),
}
cs.unsub = subscriptionAPI.SubscribeToEvents(cs.subscriber)
return cs
}

func (cs *ChannelSubscriptions) Stop() {
cs.unsub()
}

func (cs *ChannelSubscriptions) Subscribe(chid datatransfer.ChannelID, cb datatransfer.Subscriber) {
cs.subscriptionsLk.Lock()
defer cs.subscriptionsLk.Unlock()
cs.subscriptions[chid] = append(cs.subscriptions[chid], cb)
}

func (cs *ChannelSubscriptions) subscriber(evt datatransfer.Event, state datatransfer.ChannelState) {
cs.subscriptionsLk.RLock()
cbs := cs.subscriptions[state.ChannelID()]
for _, cb := range cbs {
cb(evt, state)
}
cs.subscriptionsLk.RUnlock()

if channels.IsChannelTerminated(state.Status()) {
cs.subscriptionsLk.Lock()
delete(cs.subscriptions, state.ChannelID())
cs.subscriptionsLk.Unlock()
}
}
78 changes: 78 additions & 0 deletions channelsubscriptions/channelsubscriptions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package channelsubscriptions_test

import (
"testing"

"github.com/stretchr/testify/require"

datatransfer "github.com/filecoin-project/go-data-transfer/v2"
"github.com/filecoin-project/go-data-transfer/v2/channelsubscriptions"
"github.com/filecoin-project/go-data-transfer/v2/testutil"
)

func TestChannelSubscriptions(t *testing.T) {
peers := testutil.GeneratePeers(2)
tid1 := datatransfer.TransferID(0)
tid2 := datatransfer.TransferID(1)

chid1 := datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}
chid2 := datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid2}

ms := &mockSubscriptionAPI{}
cs := channelsubscriptions.NewChannelSubscriptions(ms)
require.NotNil(t, ms.subscriber)
var events []datatransfer.EventCode
// no events while not subscribed
ms.subscriber(datatransfer.Event{Code: datatransfer.Open}, testutil.NewMockChannelState(testutil.MockChannelStateParams{
ChannelID: chid1,
}))
require.Empty(t, events)
cs.Subscribe(chid1, func(evt datatransfer.Event, state datatransfer.ChannelState) {
events = append(events, evt.Code)
})
// receives events after subscription
ms.subscriber(datatransfer.Event{Code: datatransfer.Open}, testutil.NewMockChannelState(testutil.MockChannelStateParams{
ChannelID: chid1,
}))
require.Len(t, events, 1)
require.Equal(t, events[0], datatransfer.Open)
ms.subscriber(datatransfer.Event{Code: datatransfer.Accept}, testutil.NewMockChannelState(testutil.MockChannelStateParams{
ChannelID: chid1,
}))
require.Len(t, events, 2)
require.Equal(t, events[1], datatransfer.Accept)
// does not receive events for other channels
ms.subscriber(datatransfer.Event{Code: datatransfer.TransferInitiated}, testutil.NewMockChannelState(testutil.MockChannelStateParams{
ChannelID: chid2,
}))
require.Len(t, events, 2)

// send final event
ms.subscriber(datatransfer.Event{Code: datatransfer.CleanupComplete}, testutil.NewMockChannelState(testutil.MockChannelStateParams{
ChannelID: chid1,
Complete: true,
}))
require.Len(t, events, 3)
require.Equal(t, events[2], datatransfer.CleanupComplete)

// receives no more events after complete
ms.subscriber(datatransfer.Event{Code: datatransfer.EventCode(datatransfer.DataSent)}, testutil.NewMockChannelState(testutil.MockChannelStateParams{
ChannelID: chid1,
}))
require.Len(t, events, 3)

// verify stop unsubscribes
cs.Stop()
require.Nil(t, ms.subscriber)
}

type mockSubscriptionAPI struct {
subscriber datatransfer.Subscriber
}

func (ms *mockSubscriptionAPI) SubscribeToEvents(subscriber datatransfer.Subscriber) datatransfer.Unsubscribe {
ms.subscriber = subscriber
return func() {
ms.subscriber = nil
}
}
3 changes: 3 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ const ErrRejected = errorType("response rejected")

// ErrUnsupported indicates an operation is not supported by the transport protocol
const ErrUnsupported = errorType("unsupported")

// ErrAlreadySubscribed indicates a subscription to events exists for the given channel
const ErrAlreadySubscribed = errorType("already subscribed to events for given channel id")
36 changes: 31 additions & 5 deletions impl/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
datatransfer "github.com/filecoin-project/go-data-transfer/v2"
"github.com/filecoin-project/go-data-transfer/v2/channelmonitor"
"github.com/filecoin-project/go-data-transfer/v2/channels"
"github.com/filecoin-project/go-data-transfer/v2/channelsubscriptions"
"github.com/filecoin-project/go-data-transfer/v2/message"
"github.com/filecoin-project/go-data-transfer/v2/message/types"
"github.com/filecoin-project/go-data-transfer/v2/network"
Expand All @@ -45,6 +46,7 @@ type manager struct {
channelMonitorCfg *channelmonitor.Config
transferIDGen *timeCounter
spansIndex *tracing.SpansIndex
channelSubscriptions *channelsubscriptions.ChannelSubscriptions
}

type internalEvent struct {
Expand All @@ -59,7 +61,7 @@ func dispatcher(evt pubsub.Event, subscriberFn pubsub.SubscriberFn) error {
}
cb, ok := subscriberFn.(datatransfer.Subscriber)
if !ok {
return errors.New("wrong type of event")
return errors.New("wrong type of subscriber")
}
cb(ie.evt, ie.state)
return nil
Expand All @@ -72,7 +74,7 @@ func readyDispatcher(evt pubsub.Event, fn pubsub.SubscriberFn) error {
}
cb, ok := fn.(datatransfer.ReadyFunc)
if !ok {
return errors.New("wrong type of event")
return errors.New("wrong type of event subscriber")
}
cb(migrateErr)
return nil
Expand Down Expand Up @@ -117,7 +119,7 @@ func NewDataTransfer(ds datastore.Batching, dataTransferNetwork network.DataTran
// Create push / pull channel monitor after applying config options as the config
// options may apply to the monitor
m.channelMonitor = channelmonitor.NewMonitor(m, m.channelMonitorCfg)

m.channelSubscriptions = channelsubscriptions.NewChannelSubscriptions(m)
return m, nil
}

Expand Down Expand Up @@ -158,6 +160,7 @@ func (m *manager) Stop(ctx context.Context) error {
log.Info("stop data-transfer module")
m.channelMonitor.Shutdown()
m.spansIndex.EndAll()
m.channelSubscriptions.Stop()
return m.transport.Shutdown(ctx)
}

Expand All @@ -176,9 +179,11 @@ func (m *manager) RegisterVoucherType(voucherType datatransfer.TypeIdentifier, v

// OpenPushDataChannel opens a data transfer that will send data to the recipient peer and
// transfer parts of the piece that match the selector
func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (datatransfer.ChannelID, error) {
func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.TypedVoucher, baseCid cid.Cid, selector datamodel.Node, options ...datatransfer.TransferOption) (datatransfer.ChannelID, error) {
log.Infof("open push channel to %s with base cid %s", requestTo, baseCid)

tc := datatransfer.FromOptions(options)

req, err := m.newRequest(ctx, selector, false, voucher, baseCid, requestTo)
if err != nil {
return datatransfer.ChannelID{}, err
Expand All @@ -189,6 +194,15 @@ func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, vo
if err != nil {
return chid, err
}

if eventsCb := tc.EventsCb(); eventsCb != nil {
m.channelSubscriptions.Subscribe(chid, eventsCb)
}

if err := m.channels.Open(chid); err != nil {
return chid, err
}

ctx, span := m.spansIndex.SpanForChannel(ctx, chid)
processor, has := m.transportConfigurers.Processor(voucher.Type)
if has {
Expand Down Expand Up @@ -217,19 +231,31 @@ func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, vo

// OpenPullDataChannel opens a data transfer that will request data from the sending peer and
// transfer parts of the piece that match the selector
func (m *manager) OpenPullDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.TypedVoucher, baseCid cid.Cid, selector datamodel.Node) (datatransfer.ChannelID, error) {
func (m *manager) OpenPullDataChannel(ctx context.Context, requestTo peer.ID, voucher datatransfer.TypedVoucher, baseCid cid.Cid, selector datamodel.Node, options ...datatransfer.TransferOption) (datatransfer.ChannelID, error) {
log.Infof("open pull channel to %s with base cid %s", requestTo, baseCid)

tc := datatransfer.FromOptions(options)

req, err := m.newRequest(ctx, selector, true, voucher, baseCid, requestTo)
if err != nil {
return datatransfer.ChannelID{}, err
}

// initiator = us, sender = them, receiver = us
chid, err := m.channels.CreateNew(m.peerID, req.TransferID(), baseCid, selector, voucher,
m.peerID, requestTo, m.peerID)
if err != nil {
return chid, err
}

if eventsCb := tc.EventsCb(); eventsCb != nil {
m.channelSubscriptions.Subscribe(chid, eventsCb)
}

if err := m.channels.Open(chid); err != nil {
return chid, err
}

ctx, span := m.spansIndex.SpanForChannel(ctx, chid)
processor, has := m.transportConfigurers.Processor(voucher.Type)
if has {
Expand Down
8 changes: 6 additions & 2 deletions impl/receiving_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,14 @@ func (m *manager) acceptRequest(chid datatransfer.ChannelID, incoming datatransf
dataReceiver,
)
if err != nil {
log.Errorw("failed to create and start tracking channel", "channelID", chid, "err", err)
log.Errorw("failed to create tracking channel", "channelID", chid, "err", err)
return result, err
}
err = m.channels.Open(chid)
if err != nil {
log.Errorw("failed to start tracking channel", "channelID", chid, "err", err)
return result, err
}

// record that the channel was accepted
log.Debugw("successfully created and started tracking channel", "channelID", chid)
if err := m.channels.Accept(chid); err != nil {
Expand Down
Loading