diff --git a/tsdb/engine/tsm1/engine.go b/tsdb/engine/tsm1/engine.go index 493fe09ee67..b22e76cff45 100644 --- a/tsdb/engine/tsm1/engine.go +++ b/tsdb/engine/tsm1/engine.go @@ -1290,7 +1290,7 @@ func (e *Engine) addToIndexFromKey(keys [][]byte, fieldTypes []influxql.DataType keys[i], field = SeriesAndFieldFromCompositeKey(keys[i]) name := models.ParseName(keys[i]) mf := e.fieldset.CreateFieldsIfNotExists(name) - if err := mf.CreateFieldIfNotExists(field, fieldTypes[i]); err != nil { + if _, err := mf.CreateFieldIfNotExists(field, fieldTypes[i]); err != nil { return err } diff --git a/tsdb/field_validator.go b/tsdb/field_validator.go index c428e6976a9..887ed9d36e7 100644 --- a/tsdb/field_validator.go +++ b/tsdb/field_validator.go @@ -13,9 +13,11 @@ const MaxFieldValueLength = 1048576 // ValidateFields will return a PartialWriteError if: // - the point has inconsistent fields, or // - the point has fields that are too long -func ValidateFields(mf *MeasurementFields, point models.Point, skipSizeValidation bool) error { +func ValidateFields(mf *MeasurementFields, point models.Point, skipSizeValidation bool) ([]*FieldCreate, error) { pointSize := point.StringSize() iter := point.FieldIterator() + var fieldsToCreate []*FieldCreate + for iter.Next() { if !skipSizeValidation { // Check for size of field too large. Note it is much cheaper to check the whole point size @@ -23,7 +25,7 @@ func ValidateFields(mf *MeasurementFields, point models.Point, skipSizeValidatio // unescape the string, and must at least parse the string) if pointSize > MaxFieldValueLength && iter.Type() == models.String { if sz := len(iter.StringValue()); sz > MaxFieldValueLength { - return PartialWriteError{ + return nil, PartialWriteError{ Reason: fmt.Sprintf( "input field \"%s\" on measurement \"%s\" is too long, %d > %d", iter.FieldKey(), point.Name(), sz, MaxFieldValueLength), @@ -33,14 +35,9 @@ func ValidateFields(mf *MeasurementFields, point models.Point, skipSizeValidatio } } + fieldKey := iter.FieldKey() // Skip fields name "time", they are illegal. - if bytes.Equal(iter.FieldKey(), timeBytes) { - continue - } - - // If the fields is not present, there cannot be a conflict. - f := mf.FieldBytes(iter.FieldKey()) - if f == nil { + if bytes.Equal(fieldKey, timeBytes) { continue } @@ -49,18 +46,26 @@ func ValidateFields(mf *MeasurementFields, point models.Point, skipSizeValidatio continue } - // If the types are not the same, there is a conflict. - if f.Type != dataType { - return PartialWriteError{ + // If the field is not present, remember to create it. + f := mf.FieldBytes(fieldKey) + if f == nil { + fieldsToCreate = append(fieldsToCreate, &FieldCreate{ + Measurement: point.Name(), + Field: &Field{ + Name: string(fieldKey), + Type: dataType, + }}) + } else if f.Type != dataType { + // If the types are not the same, there is a conflict. + return nil, PartialWriteError{ Reason: fmt.Sprintf( "%s: input field \"%s\" on measurement \"%s\" is type %s, already exists as type %s", - ErrFieldTypeConflict, iter.FieldKey(), point.Name(), dataType, f.Type), + ErrFieldTypeConflict, fieldKey, point.Name(), dataType, f.Type), Dropped: 1, } } } - - return nil + return fieldsToCreate, nil } // dataTypeFromModelsFieldType returns the influxql.DataType that corresponds to the diff --git a/tsdb/shard.go b/tsdb/shard.go index d52e2b5c461..6b6e9c4da79 100644 --- a/tsdb/shard.go +++ b/tsdb/shard.go @@ -572,13 +572,13 @@ func (s *Shard) WritePoints(points []models.Point, tracker StatsTracker) error { // to the caller, but continue on writing the remaining points. writeError = err } - atomic.AddInt64(&s.stats.FieldsCreated, int64(len(fieldsToCreate))) // add any new fields and keep track of what needs to be saved - if err := s.createFieldsAndMeasurements(fieldsToCreate); err != nil { + if numFieldsCreated, err := s.createFieldsAndMeasurements(fieldsToCreate); err != nil { return err + } else { + atomic.AddInt64(&s.stats.FieldsCreated, int64(numFieldsCreated)) } - engineTracker := tracker engineTracker.AddedPoints = func(points, values int64) { if tracker.AddedPoints != nil { @@ -697,61 +697,44 @@ func (s *Shard) validateSeriesAndFields(points []models.Point, tracker StatsTrac continue } - // Skip any points whos keys have been dropped. Dropped has already been incremented for them. + // Skip any points whose keys have been dropped. Dropped has already been incremented for them. if len(droppedKeys) > 0 && bytesutil.Contains(droppedKeys, keys[i]) { continue } - name := p.Name() - mf := engine.MeasurementFields(name) - - // Check with the field validator. - if err := ValidateFields(mf, p, s.options.Config.SkipFieldSizeValidation); err != nil { - switch err := err.(type) { - case PartialWriteError: - if reason == "" { - reason = err.Reason + err := func(p models.Point, iter models.FieldIterator) error { + var newFields []*FieldCreate + var validateErr error + name := p.Name() + mf := engine.MeasurementFields(name) + mf.mu.RLock() + defer mf.mu.RUnlock() + // Check with the field validator. + if newFields, validateErr = ValidateFields(mf, p, s.options.Config.SkipFieldSizeValidation); validateErr != nil { + var err PartialWriteError + switch { + case errors.As(validateErr, &err): + // This will turn into an error later, outside this lambda + if reason == "" { + reason = err.Reason + } + dropped += err.Dropped + atomic.AddInt64(&s.stats.WritePointsDropped, int64(err.Dropped)) + default: + return err } - dropped += err.Dropped - atomic.AddInt64(&s.stats.WritePointsDropped, int64(err.Dropped)) - default: - return nil, nil, err + return nil } - continue - } - - points[j] = points[i] - j++ - // Create any fields that are missing. - iter.Reset() - for iter.Next() { - fieldKey := iter.FieldKey() - - // Skip fields named "time". They are illegal. - if bytes.Equal(fieldKey, timeBytes) { - continue - } - - if mf.FieldBytes(fieldKey) != nil { - continue - } - - dataType := dataTypeFromModelsFieldType(iter.Type()) - if dataType == influxql.Unknown { - continue - } - - fieldsToCreate = append(fieldsToCreate, &FieldCreate{ - Measurement: name, - Field: &Field{ - Name: string(fieldKey), - Type: dataType, - }, - }) + points[j] = points[i] + j++ + fieldsToCreate = append(fieldsToCreate, newFields...) + return nil + }(p, iter) + if err != nil { + return nil, nil, err } } - if dropped > 0 { err = PartialWriteError{Reason: reason, Dropped: dropped, Database: s.database, RetentionPolicy: s.retentionPolicy} } @@ -781,31 +764,33 @@ func makePrintable(s string) string { return b.String() } -func (s *Shard) createFieldsAndMeasurements(fieldsToCreate []*FieldCreate) error { +func (s *Shard) createFieldsAndMeasurements(fieldsToCreate []*FieldCreate) (int, error) { if len(fieldsToCreate) == 0 { - return nil + return 0, nil } engine, err := s.engineNoLock() if err != nil { - return err + return 0, err } - + numCreated := 0 // add fields changes := make([]*FieldChange, 0, len(fieldsToCreate)) for _, f := range fieldsToCreate { mf := engine.MeasurementFields(f.Measurement) - if err := mf.CreateFieldIfNotExists([]byte(f.Field.Name), f.Field.Type); err != nil { - return err + if created, err := mf.CreateFieldIfNotExists([]byte(f.Field.Name), f.Field.Type); err != nil { + return 0, err + } else if created { + numCreated++ + s.index.SetFieldName(f.Measurement, f.Field.Name) + changes = append(changes, &FieldChange{ + FieldCreate: *f, + ChangeType: AddMeasurementField, + }) } - s.index.SetFieldName(f.Measurement, f.Field.Name) - changes = append(changes, &FieldChange{ - FieldCreate: *f, - ChangeType: AddMeasurementField, - }) } - return engine.MeasurementFieldSet().Save(changes) + return numCreated, engine.MeasurementFieldSet().Save(changes) } // DeleteSeriesRange deletes all values from for seriesKeys between min and max (inclusive) @@ -1577,7 +1562,7 @@ func (a Shards) ExpandSources(sources influxql.Sources) (influxql.Sources, error // MeasurementFields holds the fields of a measurement and their codec. type MeasurementFields struct { - mu sync.Mutex + mu sync.RWMutex fields atomic.Value // map[string]*Field } @@ -1616,15 +1601,15 @@ func (m *MeasurementFields) bytes() int { // CreateFieldIfNotExists creates a new field with an autoincrementing ID. // Returns an error if 255 fields have already been created on the measurement or // the fields already exists with a different type. -func (m *MeasurementFields) CreateFieldIfNotExists(name []byte, typ influxql.DataType) error { +func (m *MeasurementFields) CreateFieldIfNotExists(name []byte, typ influxql.DataType) (bool, error) { fields := m.fields.Load().(map[string]*Field) // Ignore if the field already exists. if f := fields[string(name)]; f != nil { if f.Type != typ { - return ErrFieldTypeConflict + return false, ErrFieldTypeConflict } - return nil + return false, nil } m.mu.Lock() @@ -1634,9 +1619,9 @@ func (m *MeasurementFields) CreateFieldIfNotExists(name []byte, typ influxql.Dat // Re-check field and type under write lock. if f := fields[string(name)]; f != nil { if f.Type != typ { - return ErrFieldTypeConflict + return false, ErrFieldTypeConflict } - return nil + return false, nil } fieldsUpdate := make(map[string]*Field, len(fields)+1) @@ -1652,7 +1637,7 @@ func (m *MeasurementFields) CreateFieldIfNotExists(name []byte, typ influxql.Dat fieldsUpdate[string(name)] = f m.fields.Store(fieldsUpdate) - return nil + return true, nil } func (m *MeasurementFields) FieldN() int { @@ -2325,7 +2310,7 @@ func (fs *MeasurementFieldSet) ApplyChanges() error { fs.Delete(string(fc.Measurement)) } else { mf := fs.CreateFieldsIfNotExists(fc.Measurement) - if err := mf.CreateFieldIfNotExists([]byte(fc.Field.Name), fc.Field.Type); err != nil { + if _, err := mf.CreateFieldIfNotExists([]byte(fc.Field.Name), fc.Field.Type); err != nil { err = fmt.Errorf("failed creating %q.%q: %w", fc.Measurement, fc.Field.Name, err) log.Error("field creation", zap.Error(err)) return err diff --git a/tsdb/shard_test.go b/tsdb/shard_test.go index 693affa3008..9add6af35e0 100644 --- a/tsdb/shard_test.go +++ b/tsdb/shard_test.go @@ -7,6 +7,7 @@ import ( "fmt" "math" "os" + "path" "path/filepath" "reflect" "regexp" @@ -14,12 +15,10 @@ import ( "sort" "strings" "sync" + "sync/atomic" "testing" "time" - assert2 "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/davecgh/go-spew/spew" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -30,9 +29,12 @@ import ( "github.com/influxdata/influxdb/query" "github.com/influxdata/influxdb/tsdb" _ "github.com/influxdata/influxdb/tsdb/engine" + "github.com/influxdata/influxdb/tsdb/engine/tsm1" _ "github.com/influxdata/influxdb/tsdb/index" "github.com/influxdata/influxdb/tsdb/index/inmem" "github.com/influxdata/influxql" + assert2 "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestShardWriteAndIndex(t *testing.T) { @@ -52,7 +54,8 @@ func TestShardWriteAndIndex(t *testing.T) { // Calling WritePoints when the engine is not open will return // ErrEngineClosed. - if got, exp := sh.WritePoints(nil, tsdb.NoopStatsTracker()), tsdb.ErrEngineClosed; got != exp { + got := sh.WritePoints(nil, tsdb.NoopStatsTracker()) + if exp := tsdb.ErrEngineClosed; got != exp { t.Fatalf("got %v, expected %v", got, exp) } @@ -122,7 +125,8 @@ func TestShard_Open_CorruptFieldsIndex(t *testing.T) { // Calling WritePoints when the engine is not open will return // ErrEngineClosed. - if got, exp := sh.WritePoints(nil, tsdb.NoopStatsTracker()), tsdb.ErrEngineClosed; got != exp { + got := sh.WritePoints(nil, tsdb.NoopStatsTracker()) + if exp := tsdb.ErrEngineClosed; got != exp { t.Fatalf("got %v, expected %v", got, exp) } @@ -1687,7 +1691,7 @@ func TestMeasurementFieldSet_SaveLoad(t *testing.T) { } defer checkMeasurementFieldSetClose(t, mf) fields := mf.CreateFieldsIfNotExists([]byte(measurement)) - if err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil { + if _, err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil { t.Fatalf("create field error: %v", err) } change := tsdb.FieldChange{ @@ -1739,7 +1743,7 @@ func TestMeasurementFieldSet_Corrupt(t *testing.T) { measurement := []byte("cpu") fields := mf.CreateFieldsIfNotExists(measurement) fieldName := "value" - if err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil { + if _, err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil { t.Fatalf("create field error: %v", err) } change := tsdb.FieldChange{ @@ -1810,7 +1814,7 @@ func TestMeasurementFieldSet_CorruptChangeFile(t *testing.T) { defer checkMeasurementFieldSetClose(t, mf) for _, f := range testFields { fields := mf.CreateFieldsIfNotExists([]byte(f.Measurement)) - if err := fields.CreateFieldIfNotExists([]byte(f.Field), f.FieldType); err != nil { + if _, err := fields.CreateFieldIfNotExists([]byte(f.Field), f.FieldType); err != nil { t.Fatalf("create field error: %v", err) } change := tsdb.FieldChange{ @@ -1872,7 +1876,7 @@ func TestMeasurementFieldSet_DeleteEmpty(t *testing.T) { defer checkMeasurementFieldSetClose(t, mf) fields := mf.CreateFieldsIfNotExists([]byte(measurement)) - if err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil { + if _, err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil { t.Fatalf("create field error: %v", err) } @@ -2005,7 +2009,7 @@ func testFieldMaker(t *testing.T, wg *sync.WaitGroup, mf *tsdb.MeasurementFieldS fields := mf.CreateFieldsIfNotExists([]byte(measurement)) for _, fieldName := range fieldNames { - if err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil { + if _, err := fields.CreateFieldIfNotExists([]byte(fieldName), influxql.Float); err != nil { t.Logf("create field error: %v", err) t.Fail() return @@ -2655,3 +2659,161 @@ func (a seriesIDSets) ForEach(f func(ids *tsdb.SeriesIDSet)) error { } return nil } + +// Tests concurrently writing to the same shard with different field types which +// can trigger a panic when the shard is snapshotted to TSM files. +func TestShard_WritePoints_ForceFieldConflictConcurrent(t *testing.T) { + const Runs = 50 + if testing.Short() || runtime.GOOS == "windows" { + t.Skip("Skipping on short or windows") + } + for i := 0; i < Runs; i++ { + conflictShard(t, i) + } +} + +func conflictShard(t *testing.T, run int) { + const measurement = "cpu" + const field = "value" + const numTypes = 4 // float, int, bool, string + const pointCopies = 10 + const trialsPerShard = 10 + + tmpDir, _ := os.MkdirTemp("", "shard_test") + defer func() { + require.NoError(t, os.RemoveAll(tmpDir), "removing %s", tmpDir) + }() + tmpShard := filepath.Join(tmpDir, "shard") + tmpWal := filepath.Join(tmpDir, "wal") + + sfile := MustOpenSeriesFile() + defer func() { + require.NoError(t, sfile.Close(), "closing series file") + require.NoError(t, os.RemoveAll(sfile.Path()), "removing series file %s", sfile.Path()) + }() + + opts := tsdb.NewEngineOptions() + opts.Config.WALDir = tmpWal + opts.InmemIndex = inmem.NewIndex(filepath.Base(tmpDir), sfile.SeriesFile) + opts.SeriesIDSets = seriesIDSets([]*tsdb.SeriesIDSet{}) + sh := tsdb.NewShard(1, tmpShard, tmpWal, sfile.SeriesFile, opts) + require.NoError(t, sh.Open(), "opening shard: %s", sh.Path()) + defer func() { + require.NoError(t, sh.Close(), "closing shard %s", tmpShard) + }() + var wg sync.WaitGroup + mu := sync.RWMutex{} + maxConcurrency := atomic.Int64{} + + currentTime := time.Now() + + points := make([]models.Point, 0, pointCopies*numTypes) + + for i := 0; i < pointCopies; i++ { + points = append(points, models.MustNewPoint( + measurement, + models.NewTags(map[string]string{"host": "server"}), + map[string]interface{}{field: 1.0}, + currentTime.Add(time.Duration(i)*time.Second), + )) + points = append(points, models.MustNewPoint( + measurement, + models.NewTags(map[string]string{"host": "server"}), + map[string]interface{}{field: int64(1)}, + currentTime.Add(time.Duration(i)*time.Second), + )) + points = append(points, models.MustNewPoint( + measurement, + models.NewTags(map[string]string{"host": "server"}), + map[string]interface{}{field: "one"}, + currentTime.Add(time.Duration(i)*time.Second), + )) + points = append(points, models.MustNewPoint( + measurement, + models.NewTags(map[string]string{"host": "server"}), + map[string]interface{}{field: true}, + currentTime.Add(time.Duration(i)*time.Second), + )) + } + concurrency := atomic.Int64{} + + for i := 0; i < trialsPerShard; i++ { + mu.Lock() + wg.Add(len(points)) + // Write points concurrently + for i := 0; i < pointCopies; i++ { + for j := 0; j < numTypes; j++ { + concurrency.Add(1) + go func(mp models.Point) { + mu.RLock() + defer concurrency.Add(-1) + defer mu.RUnlock() + defer wg.Done() + if err := sh.WritePoints([]models.Point{mp}, tsdb.NoopStatsTracker()); err == nil { + fs, err := mp.Fields() + require.NoError(t, err, "getting fields") + require.Equal(t, + sh.MeasurementFields([]byte(measurement)).Field(field).Type, + influxql.InspectDataType(fs[field]), + "field types mismatch on run %d: types exp: %s, got: %s", run+1, sh.MeasurementFields([]byte(measurement)).Field(field).Type.String(), influxql.InspectDataType(fs[field]).String()) + } else { + require.ErrorContains(t, err, tsdb.ErrFieldTypeConflict.Error(), "unexpected error") + } + if c := concurrency.Load(); maxConcurrency.Load() < c { + maxConcurrency.Store(c) + } + }(points[i*numTypes+j]) + } + } + mu.Unlock() + wg.Wait() + dir, err := sh.CreateSnapshot(false) + require.NoError(t, err, "creating snapshot: %s", sh.Path()) + require.NoError(t, os.RemoveAll(dir), "removing snapshot directory %s", dir) + } + keyType := map[string]byte{} + files, err := os.ReadDir(tmpShard) + require.NoError(t, err, "reading shard directory %s", tmpShard) + for i, file := range files { + if !strings.HasSuffix(path.Ext(file.Name()), tsm1.TSMFileExtension) { + continue + } + ffile := path.Join(tmpShard, file.Name()) + fh, err := os.Open(ffile) + require.NoError(t, err, "opening snapshot file %s", ffile) + tr, err := tsm1.NewTSMReader(fh) + require.NoError(t, err, "creating TSM reader for %s", ffile) + key, typ := tr.KeyAt(0) + if oldTyp, ok := keyType[string(key)]; ok { + require.Equal(t, oldTyp, typ, + "field type mismatch in run %d TSM file %d -- %q in %s\nfirst seen: %s, newest: %s, field type: %s", + run+1, + i+1, + string(key), + ffile, + blockTypeString(oldTyp), + blockTypeString(typ), + sh.MeasurementFields([]byte(measurement)).Field(field).Type.String()) + } else { + keyType[string(key)] = typ + } + // Must close after all uses of key (mapped memory) + require.NoError(t, tr.Close(), "closing TSM reader") + } + // t.Logf("Type %s wins run %d with concurrency: %d", sh.MeasurementFields([]byte(measurement)).Field(field).Type.String(), run+1, maxConcurrency.Load()) +} + +func blockTypeString(typ byte) string { + switch typ { + case tsm1.BlockFloat64: + return "float64" + case tsm1.BlockInteger: + return "int64" + case tsm1.BlockBoolean: + return "bool" + case tsm1.BlockString: + return "string" + default: + return "unknown" + } +}