Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 80 additions & 4 deletions storage/grpc_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ func (c *grpcStorageClient) OpenWriter(params *openWriterParams, opts ...storage
appendGen: params.appendGen,
finalizeOnClose: params.finalizeOnClose,

buf: make([]byte, 0, chunkSize),
buf: nil, // Allocated lazily on first buffered write.
chunkSize: chunkSize,
writeQuantum: writeQuantum,
lastSegmentStart: lastSegmentStart,
sendableUnits: sendableUnits,
Expand Down Expand Up @@ -256,7 +257,8 @@ type gRPCWriter struct {
appendGen int64
finalizeOnClose bool

buf []byte
buf []byte
chunkSize int
// A writeQuantum is the largest quantity of data which can be sent to the
// service in a single message.
writeQuantum int
Expand Down Expand Up @@ -557,8 +559,10 @@ type gRPCWriterCommand interface {
}

// gRPCWriterCommandWrite carries one user Write call's payload through the
// writer's command loop. It supports retries across reconnects by tracking
// how far into p the server has already persisted.
//
// NOTE: the pasted diff had left the pre-change field lines duplicated
// alongside the new ones; this is the post-change struct only.
type gRPCWriterCommandWrite struct {
	p    []byte        // unsent portion of the caller's payload
	done chan struct{} // closed once this command has fully consumed p
	// initialOffset records w.bufBaseOffset when this command first ran so
	// that retries can skip bytes the server already persisted
	// (see attemptZeroCopyWrite).
	initialOffset int64
	hasStarted bool // true once initialOffset has been captured
}

func (c *gRPCWriterCommandWrite) handle(w *gRPCWriter, cs gRPCWriterCommandHandleChans) error {
Expand All @@ -568,6 +572,20 @@ func (c *gRPCWriterCommandWrite) handle(w *gRPCWriter, cs gRPCWriterCommandHandl
return nil
}

// Try Zero-Copy send.
if len(w.buf) == 0 && (len(c.p) >= w.chunkSize || w.forceOneShot) {
done, err := c.attemptZeroCopyWrite(w, cs)
if err != nil {
return err
}
if done {
return nil
}
}

if w.buf == nil {
w.buf = make([]byte, 0, w.chunkSize)
}
wblen := len(w.buf)
allKnownBytes := wblen + len(c.p)
fullBufs := allKnownBytes / cap(w.buf)
Expand Down Expand Up @@ -691,6 +709,64 @@ func (c *gRPCWriterCommandWrite) handle(w *gRPCWriter, cs gRPCWriterCommandHandl
return nil
}

// attemptZeroCopyWrite tries to send c.p directly from the caller's slice,
// bypassing the writer's internal buffer. It returns done=true when the whole
// payload has been handed to the stream (c.done is closed) and false when the
// caller should fall back to buffering the (possibly advanced) remainder of
// c.p. A non-nil error aborts the command.
//
// NOTE(review): assumes it is only invoked when w.buf is empty and either
// len(c.p) >= w.chunkSize or w.forceOneShot holds — confirm at the call site.
func (c *gRPCWriterCommandWrite) attemptZeroCopyWrite(w *gRPCWriter, cs gRPCWriterCommandHandleChans) (bool, error) {
	// Initialize state on the first attempt of this command.
	if !c.hasStarted {
		c.initialOffset = w.bufBaseOffset
		c.hasStarted = true
	}

	// Calculate the offset delta. If w.bufBaseOffset > c.initialOffset,
	// the server persisted data from a previous attempt; we must skip those bytes.
	skip64 := w.bufBaseOffset - c.initialOffset
	if skip64 < 0 {
		skip64 = 0
	}
	// If we've already sent everything in c.p, we're done.
	if skip64 >= int64(len(c.p)) {
		close(c.done)
		return true, nil
	}
	skip := int(skip64)

	pending := c.p[skip:]
	n := len(pending)
	toSend := n

	// Unless forced, align the send size to the buffer capacity (chunk size)
	// to ensure we only do zero-copy on full chunks.
	// NOTE(review): the division assumes w.chunkSize > 0 whenever
	// forceOneShot is false — verify writer construction guarantees this.
	if !w.forceOneShot {
		toSend = (n / w.chunkSize) * w.chunkSize
	}

	if toSend == 0 {
		// Remaining data is smaller than a chunk; fall through to buffering.
		return false, nil
	}

	// Perform zero-copy send on the aligned slice.
	newOffset, ok := w.sendBufferToTarget(cs, pending[:toSend], w.bufBaseOffset, toSend, func(cmp gRPCBidiWriteCompletion) {
		w.handleCompletion(cmp)
	})
	if !ok {
		// The stream could not accept the buffer; surface the sender's error.
		return false, w.streamSender.err()
	}

	// Record the new base offset reported by the send.
	w.bufBaseOffset = newOffset

	// Everything pending was sent in one aligned piece; signal completion.
	if toSend == n {
		close(c.done)
		return true, nil
	}

	// Partial send complete. Advance the command's view of the data so the
	// caller can buffer the remaining tail without re-sending what we just sent.
	c.p = c.p[skip+toSend:]
	c.initialOffset = w.bufBaseOffset

	return false, nil
}

// gRPCWriterCommandFlush asks the write loop to flush buffered data.
// NOTE(review): the handler for this command is outside this view; the
// int64 delivered on done presumably is the persisted offset — confirm.
type gRPCWriterCommandFlush struct {
	done chan int64
}
Expand Down
233 changes: 233 additions & 0 deletions storage/grpc_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
package storage

import (
"context"
"sync"
"testing"

"cloud.google.com/go/storage/internal/apiv2/storagepb"
gax "github.com/googleapis/gax-go/v2"
"google.golang.org/protobuf/proto"
)

Expand Down Expand Up @@ -99,3 +102,233 @@ func TestGetObjectChecksums(t *testing.T) {
})
}
}

// TestGRPCWriter_OneShot_ZeroCopy verifies that one-shot mode (ChunkSize = 0)
// bypasses the internal buffer even for payloads smaller than a chunk, sending
// the caller's slice directly.
func TestGRPCWriter_OneShot_ZeroCopy(t *testing.T) {
	// One-shot mode (ChunkSize = 0) must bypass buffering even for small data.
	data := []byte("small-payload-for-oneshot")

	mockSender := &mockZeroCopySender{}
	w := &gRPCWriter{
		buf:           nil,
		chunkSize:     0,
		writeQuantum:  256 * 1024,
		sendableUnits: 1,
		writesChan:    make(chan gRPCWriterCommand, 1),
		donec:         make(chan struct{}),
		streamSender:  mockSender,
		settings:      &settings{},
		forceOneShot:  true, // Enable one-shot mode.
	}
	// Initialize required callbacks.
	w.progress = func(int64) {}
	w.setObj = func(*ObjectAttrs) {}
	w.setSize = func(int64) {}

	go func() {
		w.writeLoop(context.Background())
		close(w.donec)
	}()

	n, err := w.Write(data)
	if err != nil {
		t.Fatalf("Write failed: %v", err)
	}
	// Include the actual and expected byte counts in the failure message.
	if n != len(data) {
		t.Errorf("Short write: wrote %d bytes, want %d", n, len(data))
	}

	w.Close()
	mockSender.wg.Wait()

	mockSender.mu.Lock()
	defer mockSender.mu.Unlock()

	dataRequests := filterDataRequests(mockSender.requests)

	if len(dataRequests) != 1 {
		t.Fatalf("Expected 1 data request, got %d", len(dataRequests))
	}

	// Verify that the small payload was sent directly from the user's slice
	// without being copied into the internal buffer.
	if &dataRequests[0].buf[0] != &data[0] {
		t.Errorf("OneShot: Expected zero-copy for small payload, but buffer was copied")
	}
}

// TestGRPCWriter_DirtyBuffer_CopyFallback verifies that two sub-chunk writes
// are coalesced through the internal buffer (copy fallback) rather than being
// sent zero-copy from the callers' slices.
func TestGRPCWriter_DirtyBuffer_CopyFallback(t *testing.T) {
	chunkSize := 100
	part1 := make([]byte, 50)
	part2 := make([]byte, 50)

	mockSender := &mockZeroCopySender{}
	w := &gRPCWriter{
		buf:           nil,
		chunkSize:     chunkSize,
		writeQuantum:  chunkSize,
		sendableUnits: 1,
		writesChan:    make(chan gRPCWriterCommand, 1),
		donec:         make(chan struct{}),
		streamSender:  mockSender,
		settings:      &settings{},
	}
	// Initialize required callbacks.
	w.progress = func(int64) {}
	w.setObj = func(*ObjectAttrs) {}
	w.setSize = func(int64) {}

	go func() {
		w.writeLoop(context.Background())
		close(w.donec)
	}()

	// Check the Write errors instead of silently discarding them.
	if _, err := w.Write(part1); err != nil {
		t.Fatalf("Write(part1) failed: %v", err)
	}
	if _, err := w.Write(part2); err != nil {
		t.Fatalf("Write(part2) failed: %v", err)
	}

	w.Close()
	mockSender.wg.Wait()

	mockSender.mu.Lock()
	defer mockSender.mu.Unlock()

	dataRequests := filterDataRequests(mockSender.requests)

	if len(dataRequests) != 1 {
		t.Fatalf("Expected 1 combined data request, got %d", len(dataRequests))
	}

	// Verify that the internal buffer was used (copy fallback) because the
	// individual writes were too small to trigger the zero-copy path.
	sentBuf := dataRequests[0].buf

	if &sentBuf[0] == &part1[0] {
		t.Errorf("Expected copy (buffering), but got zero-copy of part1")
	}
	if &sentBuf[0] == &part2[0] {
		t.Errorf("Expected copy (buffering), but got zero-copy of part2")
	}

	if len(sentBuf) != 100 {
		t.Errorf("Expected 100 bytes sent, got %d", len(sentBuf))
	}
}

// TestGRPCWriter_ZeroCopyOptimization verifies that a large write is split
// into zero-copy sends for each full chunk and a copied (buffered) tail for
// the sub-chunk remainder.
func TestGRPCWriter_ZeroCopyOptimization(t *testing.T) {
	chunkSize := 256 * 1024
	// Data size is 2 full chunks + 100 bytes.
	dataSize := (chunkSize * 2) + 100
	data := make([]byte, dataSize)
	data[0] = 1
	data[chunkSize] = 2

	mockSender := &mockZeroCopySender{}
	w := &gRPCWriter{
		buf:           nil,
		chunkSize:     chunkSize,
		writeQuantum:  chunkSize,
		sendableUnits: 10,
		writesChan:    make(chan gRPCWriterCommand, 1),
		donec:         make(chan struct{}),
		streamSender:  mockSender,
		settings:      &settings{},
	}
	// Initialize required callbacks.
	w.progress = func(int64) {}
	w.setObj = func(*ObjectAttrs) {}
	w.setSize = func(int64) {}

	go func() {
		w.writeLoop(context.Background())
		close(w.donec)
	}()

	// Check the Write error instead of silently discarding it.
	if _, err := w.Write(data); err != nil {
		t.Fatalf("Write failed: %v", err)
	}
	w.Close()
	mockSender.wg.Wait()

	mockSender.mu.Lock()
	defer mockSender.mu.Unlock()

	dataRequests := filterDataRequests(mockSender.requests)

	// Expect 3 requests: two zero-copy full chunks and one copied tail.
	if len(dataRequests) != 3 {
		t.Fatalf("Expected 3 data requests, got %d", len(dataRequests))
	}

	// Verify zero-copy on the first chunk.
	if &dataRequests[0].buf[0] != &data[0] {
		t.Errorf("Chunk 1: Zero-copy optimization failed (buffer copied)")
	}

	// Verify zero-copy on the second chunk.
	if &dataRequests[1].buf[0] != &data[chunkSize] {
		t.Errorf("Chunk 2: Zero-copy optimization failed (buffer copied)")
	}

	// Verify copy on the tail.
	if &dataRequests[2].buf[0] == &data[chunkSize*2] {
		t.Errorf("Tail: Expected buffer copy for small tail, but got zero-copy")
	}
}

// mockZeroCopySender is a fake stream sender that records every request it
// receives so tests can inspect whether payloads were sent zero-copy (by
// comparing slice backing-array addresses) or copied through the buffer.
type mockZeroCopySender struct {
	mu        sync.Mutex             // guards requests
	requests  []gRPCBidiWriteRequest // every request observed by connect
	errResult error                  // value returned by err()
	wg        sync.WaitGroup // Waits for all async operations to complete.
}

// connect simulates the bidi write stream. It launches a goroutine that
// records every request from cs.requests, acks requests that ask for one,
// and asynchronously emits a flush completion for each flush request. The
// goroutine runs until cs.requests is closed or ctx is done, then closes
// cs.completions after all in-flight completion sends have finished.
func (m *mockZeroCopySender) connect(ctx context.Context, cs gRPCBufSenderChans, opts ...gax.CallOption) {
	m.wg.Add(1)
	go func() {
		defer m.wg.Done()

		// Track active flush goroutines to prevent closing the channel prematurely.
		var completionWg sync.WaitGroup

		// Close completions only after every async flush send is done;
		// closing earlier would race with the sends below.
		defer func() {
			completionWg.Wait()
			close(cs.completions)
		}()

		for req := range cs.requests {
			// Record the request under the lock for later inspection.
			m.mu.Lock()
			m.requests = append(m.requests, req)
			m.mu.Unlock()

			if req.requestAck {
				select {
				case cs.requestAcks <- struct{}{}:
				case <-ctx.Done():
					return
				}
			}

			if req.flush {
				completionWg.Add(1)
				// Send completions asynchronously to avoid blocking the request loop.
				go func(offset int64) {
					defer completionWg.Done()
					select {
					case cs.completions <- gRPCBidiWriteCompletion{
						flushOffset: offset,
					}:
					case <-ctx.Done():
					}
				}(req.offset + int64(len(req.buf)))
			}
		}
	}()
}

func (m *mockZeroCopySender) err() error { return m.errResult }

// filterDataRequests keeps only the requests that carry payload bytes,
// dropping pure protocol messages (acks, finalizes, empty flushes).
func filterDataRequests(reqs []gRPCBidiWriteRequest) []gRPCBidiWriteRequest {
	var withData []gRPCBidiWriteRequest
	for i := range reqs {
		if len(reqs[i].buf) == 0 {
			continue
		}
		withData = append(withData, reqs[i])
	}
	return withData
}
Loading