diff --git a/client.go b/client.go index fd6cba163..fb7306111 100644 --- a/client.go +++ b/client.go @@ -148,6 +148,16 @@ type ClientOptions struct { // PropagateTraceparent is used to control whether the W3C Trace Context HTTP traceparent header // is propagated on outgoing http requests. PropagateTraceparent bool + // StrictTraceContinuation is used to control trace continuation from 3rd party services that happen to be + // instrumented by Sentry. + // + // Enabling the option means that the SDK will require the org ids from baggage to match for continuing the trace. + StrictTraceContinuation bool + // OrgID configures the orgID used for trace propagation and features like StrictTraceContinuation. + // + // In most cases the orgID is already parsed from the DSN. This option should be used when non-standard Sentry DSNs + // are used, such as self-hosted or when using a local Relay. + OrgID uint64 // List of regexp strings that will be used to match against event's message // and if applicable, caught errors type and value. // If the match is found, then a whole event will be dropped. @@ -404,7 +414,9 @@ func NewClient(options ClientOptions) (*Client, error) { client.batchMeter = newMetricBatchProcessor(&client) client.batchMeter.Start() } - + if options.OrgID != 0 && client.dsn != nil { + client.dsn.SetOrgID(options.OrgID) + } client.setupIntegrations() return &client, nil diff --git a/client_test.go b/client_test.go index e90464103..6a2754333 100644 --- a/client_test.go +++ b/client_test.go @@ -900,6 +900,37 @@ func TestSampleRate(t *testing.T) { }) } } +func TestClient_ParseOrgID(t *testing.T) { + c, err := NewClient(ClientOptions{ + Dsn: "https://example@o1.ingest.us.sentry.io/1337", + }) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, uint64(1), c.dsn.GetOrgID(), "Custom org id should override the DSN parsed one") +} + +func TestClient_ParseOrgIDInvalid(t *testing.T) { + c, err := NewClient(ClientOptions{ + // org id is MaxUint64 + 1, should be considered empty + Dsn: "https://example@o18446744073709551616.ingest.us.sentry.io/1337", + }) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, uint64(0), c.dsn.GetOrgID(), "Custom org id should override the DSN parsed one") +} + +func TestClientOptions_OrgIDShouldOverrideParsed(t *testing.T) { + c, err := NewClient(ClientOptions{ + Dsn: "https://example@o1.ingest.us.sentry.io/1337", + OrgID: 2, + }) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, uint64(2), c.dsn.GetOrgID(), "Custom org id should override the DSN parsed one") +} func BenchmarkProcessEvent(b *testing.B) { c, err := NewClient(ClientOptions{ diff --git a/dynamic_sampling_context.go b/dynamic_sampling_context.go index 5ae38748e..a257f6f5d 100644 --- a/dynamic_sampling_context.go +++ b/dynamic_sampling_context.go @@ -63,6 +63,9 @@ func DynamicSamplingContextFromTransaction(span *Span) DynamicSamplingContext { if publicKey := dsn.GetPublicKey(); publicKey != "" { entries["public_key"] = publicKey } + if orgID := dsn.GetOrgID(); orgID != 0 { + entries["org_id"] = strconv.FormatUint(orgID, 10) + } } if release := client.options.Release; release != "" { entries["release"] = release @@ -113,7 +116,7 @@ func (d DynamicSamplingContext) String() string { return baggage.String() } -// Constructs a new DynamicSamplingContext using a scope and client. Accessing +// DynamicSamplingContextFromScope Constructs a new DynamicSamplingContext using a scope and client. Accessing // fields on the scope are not thread safe, and this function should only be // called within scope methods. func DynamicSamplingContextFromScope(scope *Scope, client *Client) DynamicSamplingContext { @@ -139,6 +142,9 @@ func DynamicSamplingContextFromScope(scope *Scope, client *Client) DynamicSampli if publicKey := dsn.GetPublicKey(); publicKey != "" { entries["public_key"] = publicKey } + if orgID := dsn.GetOrgID(); orgID != 0 { + entries["org_id"] = strconv.FormatUint(orgID, 10) + } } if release := client.options.Release; release != "" { entries["release"] = release diff --git a/internal/protocol/dsn.go b/internal/protocol/dsn.go index 42aff3142..b49e88c5e 100644 --- a/internal/protocol/dsn.go +++ b/internal/protocol/dsn.go @@ -49,6 +49,7 @@ type Dsn struct { port int path string projectID string + orgID uint64 } // NewDsn creates a Dsn by parsing rawURL. Most users will never call this @@ -90,6 +91,17 @@ func NewDsn(rawURL string) (*Dsn, error) { return nil, &DsnParseError{"empty host"} } + // OrgID (optional) + var orgID uint64 + parts := strings.Split(host, ".") + orgPart := parts[0] + if len(orgPart) >= 2 && orgPart[0] == 'o' { + parsedOrgID, err := strconv.ParseUint(orgPart[1:], 10, 64) + if err == nil { + orgID = parsedOrgID + } + } + // Port var port int if p := parsedURL.Port(); p != "" { @@ -126,6 +138,7 @@ func NewDsn(rawURL string) (*Dsn, error) { port: port, path: path, projectID: projectID, + orgID: orgID, }, nil } @@ -182,6 +195,18 @@ func (dsn Dsn) GetProjectID() string { return dsn.projectID } +// GetOrgID returns the orgID that was parsed from the DSN. +func (dsn Dsn) GetOrgID() uint64 { + return dsn.orgID +} + +// SetOrgID sets the orgID used for trace continuation. +// +// This function is used for overriding the orgID parsed from the DSN. +func (dsn *Dsn) SetOrgID(orgID uint64) { + dsn.orgID = orgID +} + // GetAPIURL returns the URL of the envelope endpoint of the project // associated with the DSN. func (dsn Dsn) GetAPIURL() *url.URL { diff --git a/tracing.go b/tracing.go index 70b146d5e..e393c327a 100644 --- a/tracing.go +++ b/tracing.go @@ -953,8 +953,15 @@ func WithSpanOrigin(origin SpanOrigin) SpanOption { func ContinueTrace(hub *Hub, traceparent, baggage string) SpanOption { scope := hub.Scope() propagationContext, _ := PropagationContextFromHeaders(traceparent, baggage) - scope.SetPropagationContext(propagationContext) + client := hub.Client() + + if !shouldContinueTrace(client, propagationContext.DynamicSamplingContext) { + propagationContext = NewPropagationContext() + traceparent = "" + baggage = "" + } + scope.SetPropagationContext(propagationContext) return ContinueFromHeaders(traceparent, baggage) } @@ -973,19 +980,35 @@ func ContinueFromRequest(r *http.Request) SpanOption { // an existing TraceID and propagates the Dynamic Sampling context. func ContinueFromHeaders(trace, baggage string) SpanOption { return func(s *Span) { - if trace != "" { - s.updateFromSentryTrace([]byte(trace)) + if trace == "" { + return + } - if baggage != "" { - s.updateFromBaggage([]byte(baggage)) + // Parse baggage first to get org_id for comparison + var dsc DynamicSamplingContext + if baggage != "" { + parsed, err := DynamicSamplingContextFromHeader([]byte(baggage)) + if err == nil { + dsc = parsed } + } - // In case a sentry-trace header is present but there are no sentry-related - // values in the baggage, create an empty, frozen DynamicSamplingContext. - if !s.dynamicSamplingContext.HasEntries() { - s.dynamicSamplingContext = DynamicSamplingContext{ - Frozen: true, - } + client := hubFromContext(s.ctx).Client() + if !shouldContinueTrace(client, dsc) { + return // leave span unchanged → behaves as head of trace + } + + s.updateFromSentryTrace([]byte(trace)) + + if baggage != "" { + s.updateFromBaggage([]byte(baggage)) + } + + // In case a sentry-trace header is present but there are no sentry-related + // values in the baggage, create an empty, frozen DynamicSamplingContext. + if !s.dynamicSamplingContext.HasEntries() { + s.dynamicSamplingContext = DynamicSamplingContext{ + Frozen: true, } } } @@ -998,6 +1021,10 @@ func ContinueFromTrace(trace string) SpanOption { if trace == "" { return } + client := hubFromContext(s.ctx).Client() + if !shouldContinueTrace(client, DynamicSamplingContext{}) { + return + } s.updateFromSentryTrace([]byte(trace)) } } @@ -1077,3 +1104,35 @@ func HTTPtoSpanStatus(code int) SpanStatus { } return SpanStatusUnknown } + +func shouldContinueTrace(client *Client, dsc DynamicSamplingContext) bool { + if client == nil { + return true + } + + var sdkOrgID uint64 + if client.dsn != nil { + sdkOrgID = client.dsn.GetOrgID() + } + + baggageOrgStr := dsc.Entries["org_id"] + baggageOrgID := uint64(0) + if baggageOrgStr != "" { + baggageOrgID, _ = strconv.ParseUint(baggageOrgStr, 10, 64) + } + + // we reject non-matching orgs regardless of strict mode + if sdkOrgID != 0 && baggageOrgID != 0 && sdkOrgID != baggageOrgID { + return false + } + + // If strict mode is on, both must be present and match + if client.options.StrictTraceContinuation { + if sdkOrgID == 0 && baggageOrgID == 0 { + return true + } + return sdkOrgID == baggageOrgID + } + + return true +} diff --git a/tracing_test.go b/tracing_test.go index 9a698e0b6..faff6da69 100644 --- a/tracing_test.go +++ b/tracing_test.go @@ -485,6 +485,7 @@ func TestContinueSpanFromRequest(t *testing.T) { sampled := sampled t.Run(sampled.String(), func(t *testing.T) { var s Span + s.ctx = context.Background() hkey := http.CanonicalHeaderKey("sentry-trace") hval := (&Span{ TraceID: traceID, @@ -585,12 +586,13 @@ func TestContinueTransactionFromHeaders(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &Span{} + s.ctx = context.Background() spanOption := ContinueFromHeaders(tt.traceStr, tt.baggageStr) spanOption(s) if diff := cmp.Diff(tt.wantSpan, s, cmp.Options{ cmp.AllowUnexported(Span{}), - cmpopts.IgnoreFields(Span{}, "mu", "finishOnce"), + cmpopts.IgnoreFields(Span{}, "ctx", "mu", "finishOnce"), }); diff != "" { t.Fatalf("Expected no difference on spans, got: %s", diff) } @@ -605,13 +607,14 @@ func TestContinueSpanFromTrace(t *testing.T) { for _, sampled := range []Sampled{SampledTrue, SampledFalse, SampledUndefined} { sampled := sampled t.Run(sampled.String(), func(t *testing.T) { - var s Span + s := &Span{} + s.ctx = context.Background() trace := (&Span{ TraceID: traceID, SpanID: spanID, Sampled: sampled, }).ToSentryTrace() - ContinueFromTrace(trace)(&s) + ContinueFromTrace(trace)(s) if s.TraceID != traceID { t.Errorf("got %q, want %q", s.TraceID, traceID) } @@ -1287,3 +1290,68 @@ func TestSpanScopeManagement(t *testing.T) { t.Errorf("expected SpanID %s, got %s", transaction.SpanID, spanID) } } + +func TestStrictTraceContinuation(t *testing.T) { + incomingTraceID := TraceIDFromHex("bc6d53f15eb88f4320054569b8c553d4") + sentryTrace := "bc6d53f15eb88f4320054569b8c553d4-b72fa28504b07285-1" + + baggageWithOrg := func(orgID string) string { + return "sentry-org_id=" + orgID + ",sentry-trace_id=bc6d53f15eb88f4320054569b8c553d4" + } + baggageWithoutOrg := "sentry-trace_id=bc6d53f15eb88f4320054569b8c553d4" + + tests := []struct { + name string + baggageOrgID string + sdkOrgID uint64 + strict bool + wantContinued bool + }{ + {"strict=false, baggage=1, sdk=1", "1", 1, false, true}, + {"strict=false, baggage=none, sdk=1", "", 1, false, true}, + {"strict=false, baggage=1, sdk=none", "1", 0, false, true}, + {"strict=false, baggage=none, sdk=none", "", 0, false, true}, + {"strict=false, baggage=1, sdk=2", "1", 2, false, false}, + + {"strict=true, baggage=1, sdk=1", "1", 1, true, true}, + {"strict=true, baggage=none, sdk=1", "", 1, true, false}, + {"strict=true, baggage=1, sdk=none", "1", 0, true, false}, + {"strict=true, baggage=none, sdk=none", "", 0, true, true}, + {"strict=true, baggage=1, sdk=2", "1", 2, true, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := &MockTransport{} + ctx := NewTestContext(ClientOptions{ + Dsn: testDsn, + EnableTracing: true, + TracesSampleRate: 1.0, + Transport: transport, + StrictTraceContinuation: tt.strict, + OrgID: tt.sdkOrgID, + }) + + baggage := baggageWithoutOrg + if tt.baggageOrgID != "" { + baggage = baggageWithOrg(tt.baggageOrgID) + } + + hub := GetHubFromContext(ctx) + transaction := StartTransaction(ctx, "test", + ContinueTrace(hub, sentryTrace, baggage), + ) + transaction.Finish() + + if tt.wantContinued { + if transaction.TraceID != incomingTraceID { + t.Errorf("expected trace to be continued, got new TraceID %s", transaction.TraceID) + } + } else { + if transaction.TraceID == incomingTraceID { + t.Errorf("expected new trace, but got continued TraceID %s", transaction.TraceID) + } + } + }) + } +}