Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ type ClientOptions struct {
// PropagateTraceparent is used to control whether the W3C Trace Context HTTP traceparent header
// is propagated on outgoing http requests.
PropagateTraceparent bool
// StrictTraceContinuation is used to control trace continuation from 3rd party services that happen to be
// instrumented by Sentry.
//
// Enabling the option means that the SDK will require the org ids from baggage to match for continuing the trace.
StrictTraceContinuation bool
// OrgID configures the orgID used for trace continuation.
//
// In most cases the orgID is already parsed from the DSN. This option should be used when non-standard Sentry DSNs
// are used, such as self-hosted or when using a local Relay.
OrgID uint64
// List of regexp strings that will be used to match against event's message
// and if applicable, caught errors type and value.
// If the match is found, then a whole event will be dropped.
Expand Down Expand Up @@ -404,7 +414,9 @@ func NewClient(options ClientOptions) (*Client, error) {
client.batchMeter = newMetricBatchProcessor(&client)
client.batchMeter.Start()
}

if options.OrgID != 0 && client.dsn != nil {
client.dsn.SetOrgID(options.OrgID)
}
client.setupIntegrations()

return &client, nil
Expand Down
3 changes: 3 additions & 0 deletions dynamic_sampling_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ func DynamicSamplingContextFromTransaction(span *Span) DynamicSamplingContext {
if publicKey := dsn.GetPublicKey(); publicKey != "" {
entries["public_key"] = publicKey
}
if orgID := dsn.GetOrgID(); orgID != 0 {
entries["org_id"] = strconv.FormatUint(orgID, 10)
}
}
if release := client.options.Release; release != "" {
entries["release"] = release
Expand Down
23 changes: 23 additions & 0 deletions internal/protocol/dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type Dsn struct {
port int
path string
projectID string
orgID uint64
}

// NewDsn creates a Dsn by parsing rawURL. Most users will never call this
Expand Down Expand Up @@ -90,6 +91,17 @@ func NewDsn(rawURL string) (*Dsn, error) {
return nil, &DsnParseError{"empty host"}
}

// OrgID (optional)
var orgID uint64
parts := strings.Split(host, ".")
orgPart := parts[0]
if len(orgPart) >= 2 && orgPart[0] == 'o' {
parsedOrgID, err := strconv.ParseUint(orgPart[1:], 10, 64)
if err == nil {
orgID = parsedOrgID
}
}

// Port
var port int
if p := parsedURL.Port(); p != "" {
Expand Down Expand Up @@ -126,6 +138,7 @@ func NewDsn(rawURL string) (*Dsn, error) {
port: port,
path: path,
projectID: projectID,
orgID: orgID,
}, nil
}

Expand Down Expand Up @@ -182,6 +195,16 @@ func (dsn Dsn) GetProjectID() string {
return dsn.projectID
}

// GetOrgID returns the orgID that was parsed from the DSN.
func (dsn Dsn) GetOrgID() uint64 {
return dsn.orgID
}

// SetOrgID sets the orgID used for trace continuation.
func (dsn *Dsn) SetOrgID(orgID uint64) {
dsn.orgID = orgID
}

// GetAPIURL returns the URL of the envelope endpoint of the project
// associated with the DSN.
func (dsn Dsn) GetAPIURL() *url.URL {
Expand Down
81 changes: 70 additions & 11 deletions tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -953,8 +953,15 @@ func WithSpanOrigin(origin SpanOrigin) SpanOption {
func ContinueTrace(hub *Hub, traceparent, baggage string) SpanOption {
scope := hub.Scope()
propagationContext, _ := PropagationContextFromHeaders(traceparent, baggage)
scope.SetPropagationContext(propagationContext)
client := hub.Client()

if !shouldContinueTrace(client, propagationContext.DynamicSamplingContext) {
propagationContext = NewPropagationContext()
traceparent = ""
baggage = ""
}

scope.SetPropagationContext(propagationContext)
return ContinueFromHeaders(traceparent, baggage)
}

Expand All @@ -973,19 +980,35 @@ func ContinueFromRequest(r *http.Request) SpanOption {
// an existing TraceID and propagates the Dynamic Sampling context.
func ContinueFromHeaders(trace, baggage string) SpanOption {
return func(s *Span) {
if trace != "" {
s.updateFromSentryTrace([]byte(trace))
if trace == "" {
return
}

if baggage != "" {
s.updateFromBaggage([]byte(baggage))
// Parse baggage first to get org_id for comparison
var dsc DynamicSamplingContext
if baggage != "" {
parsed, err := DynamicSamplingContextFromHeader([]byte(baggage))
if err == nil {
dsc = parsed
}
}

// In case a sentry-trace header is present but there are no sentry-related
// values in the baggage, create an empty, frozen DynamicSamplingContext.
if !s.dynamicSamplingContext.HasEntries() {
s.dynamicSamplingContext = DynamicSamplingContext{
Frozen: true,
}
client := hubFromContext(s.ctx).Client()
if !shouldContinueTrace(client, dsc) {
return // leave span unchanged → behaves as head of trace
}

s.updateFromSentryTrace([]byte(trace))

if baggage != "" {
s.updateFromBaggage([]byte(baggage))
}

// In case a sentry-trace header is present but there are no sentry-related
// values in the baggage, create an empty, frozen DynamicSamplingContext.
if !s.dynamicSamplingContext.HasEntries() {
s.dynamicSamplingContext = DynamicSamplingContext{
Frozen: true,
}
}
}
Expand All @@ -998,6 +1021,10 @@ func ContinueFromTrace(trace string) SpanOption {
if trace == "" {
return
}
client := hubFromContext(s.ctx).Client()
if !shouldContinueTrace(client, DynamicSamplingContext{}) {
return
}
s.updateFromSentryTrace([]byte(trace))
}
}
Expand Down Expand Up @@ -1077,3 +1104,35 @@ func HTTPtoSpanStatus(code int) SpanStatus {
}
return SpanStatusUnknown
}

func shouldContinueTrace(client *Client, dsc DynamicSamplingContext) bool {
if client == nil {
return true
}

sdkOrgID := client.options.OrgID
if sdkOrgID == 0 && client.dsn != nil {
sdkOrgID = client.dsn.GetOrgID()
}

baggageOrgStr := dsc.Entries["org_id"]
baggageOrgID := uint64(0)
if baggageOrgStr != "" {
baggageOrgID, _ = strconv.ParseUint(baggageOrgStr, 10, 64)
}

// we reject non-matching orgs regardless of strict mode
if sdkOrgID != 0 && baggageOrgID != 0 && sdkOrgID != baggageOrgID {
return false
}

// If strict mode is on, both must be present and match
if client.options.StrictTraceContinuation {
if sdkOrgID == 0 && baggageOrgID == 0 {
return true
}
return sdkOrgID == baggageOrgID
}

return true
}
73 changes: 70 additions & 3 deletions tracing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ func TestContinueSpanFromRequest(t *testing.T) {
sampled := sampled
t.Run(sampled.String(), func(t *testing.T) {
var s Span
s.ctx = context.Background()
hkey := http.CanonicalHeaderKey("sentry-trace")
hval := (&Span{
TraceID: traceID,
Expand Down Expand Up @@ -585,12 +586,13 @@ func TestContinueTransactionFromHeaders(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Span{}
s.ctx = context.Background()
spanOption := ContinueFromHeaders(tt.traceStr, tt.baggageStr)
spanOption(s)

if diff := cmp.Diff(tt.wantSpan, s, cmp.Options{
cmp.AllowUnexported(Span{}),
cmpopts.IgnoreFields(Span{}, "mu", "finishOnce"),
cmpopts.IgnoreFields(Span{}, "ctx", "mu", "finishOnce"),
}); diff != "" {
t.Fatalf("Expected no difference on spans, got: %s", diff)
}
Expand All @@ -605,13 +607,14 @@ func TestContinueSpanFromTrace(t *testing.T) {
for _, sampled := range []Sampled{SampledTrue, SampledFalse, SampledUndefined} {
sampled := sampled
t.Run(sampled.String(), func(t *testing.T) {
var s Span
s := &Span{}
s.ctx = context.Background()
trace := (&Span{
TraceID: traceID,
SpanID: spanID,
Sampled: sampled,
}).ToSentryTrace()
ContinueFromTrace(trace)(&s)
ContinueFromTrace(trace)(s)
if s.TraceID != traceID {
t.Errorf("got %q, want %q", s.TraceID, traceID)
}
Expand Down Expand Up @@ -1287,3 +1290,67 @@ func TestSpanScopeManagement(t *testing.T) {
t.Errorf("expected SpanID %s, got %s", transaction.SpanID, spanID)
}
}

func TestStrictTraceContinuation(t *testing.T) {
incomingTraceID := TraceIDFromHex("bc6d53f15eb88f4320054569b8c553d4")
sentryTrace := "bc6d53f15eb88f4320054569b8c553d4-b72fa28504b07285-1"

baggageWithOrg := func(orgID string) string {
return "sentry-org_id=" + orgID + ",sentry-trace_id=bc6d53f15eb88f4320054569b8c553d4"
}
baggageWithoutOrg := "sentry-trace_id=bc6d53f15eb88f4320054569b8c553d4"

tests := []struct {
name string
baggageOrgID string
sdkOrgID uint64
strict bool
wantContinued bool
}{
{"strict=false, baggage=1, sdk=1", "1", 1, false, true},
{"strict=false, baggage=none, sdk=1", "", 1, false, true},
{"strict=false, baggage=1, sdk=none", "1", 0, false, true},
{"strict=false, baggage=none, sdk=none", "", 0, false, true},
{"strict=false, baggage=1, sdk=2", "1", 2, false, false},

{"strict=true, baggage=1, sdk=1", "1", 1, true, true},
{"strict=true, baggage=none, sdk=1", "", 1, true, false},
{"strict=true, baggage=1, sdk=none", "1", 0, true, false},
{"strict=true, baggage=none, sdk=none", "", 0, true, true},
{"strict=true, baggage=1, sdk=2", "1", 2, true, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
transport := &MockTransport{}
ctx := NewTestContext(ClientOptions{
EnableTracing: true,
TracesSampleRate: 1.0,
Transport: transport,
StrictTraceContinuation: tt.strict,
OrgID: tt.sdkOrgID,
})

baggage := baggageWithoutOrg
if tt.baggageOrgID != "" {
baggage = baggageWithOrg(tt.baggageOrgID)
}

hub := GetHubFromContext(ctx)
transaction := StartTransaction(ctx, "test",
ContinueTrace(hub, sentryTrace, baggage),
)
transaction.Finish()

if tt.wantContinued {
if transaction.TraceID != incomingTraceID {
t.Errorf("expected trace to be continued, got new TraceID %s", transaction.TraceID)
}
} else {
if transaction.TraceID == incomingTraceID {
t.Errorf("expected new trace, but got continued TraceID %s", transaction.TraceID)
}
}
})
}
}
Loading