Skip to content

Commit 52598d5

Browse files
authored
Limit train samples (#2809)
* Limit training samples size to 2GB * simplified DISPLAYLEVEL() macro to use global variable instead of local. * refactored training samples loading * fixed compiler warning * addressed comments from the pull request * addressed @terrelln comments * missed some fixes * fixed type mismatch * Fixed bug passing estimated number of samples rather than the loaded number of samples. Changed unit conversion not to use bit-shifts. * fixed a declaration after code * fixed type conversion compile errors * fixed more type casting * fixed more type mismatching * changed sizes type to size_t * move type casting * more type cast fixes
1 parent 7868f38 commit 52598d5

File tree

5 files changed

+168
-83
lines changed

5 files changed

+168
-83
lines changed

lib/dictBuilder/cover.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@
4040
/*-*************************************
4141
* Constants
4242
***************************************/
43+
/**
44+
* There are 32bit indexes used to ref samples, so limit samples size to 4GB
45+
* on 64bit builds.
46+
* For 32bit builds we choose 1 GB.
47+
* Most 32bit platforms have 2GB user-mode addressable space and we allocate a large
48+
* contiguous buffer, so 1GB is already a high limit.
49+
*/
4350
#define COVER_MAX_SAMPLES_SIZE (sizeof(size_t) == 8 ? ((unsigned)-1) : ((unsigned)1 GB))
4451
#define COVER_DEFAULT_SPLITPOINT 1.0
4552

lib/dictBuilder/fastcover.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@
3232
/*-*************************************
3333
* Constants
3434
***************************************/
35+
/**
36+
* There are 32bit indexes used to ref samples, so limit samples size to 4GB
37+
* on 64bit builds.
38+
* For 32bit builds we choose 1 GB.
39+
* Most 32bit platforms have 2GB user-mode addressable space and we allocate a large
40+
* contiguous buffer, so 1GB is already a high limit.
41+
*/
3542
#define FASTCOVER_MAX_SAMPLES_SIZE (sizeof(size_t) == 8 ? ((unsigned)-1) : ((unsigned)1 GB))
3643
#define FASTCOVER_MAX_F 31
3744
#define FASTCOVER_MAX_ACCEL 10

programs/dibio.c

Lines changed: 149 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
static const size_t g_maxMemory = (sizeof(size_t) == 4) ? (2 GB - 64 MB) : ((size_t)(512 MB) << sizeof(size_t));
5050

5151
#define NOISELENGTH 32
52+
#define MAX_SAMPLES_SIZE (2 GB) /* training dataset limited to 2GB */
5253

5354

5455
/*-*************************************
@@ -88,6 +89,15 @@ static UTIL_time_t g_displayClock = UTIL_TIME_INITIALIZER;
8889
#undef MIN
8990
#define MIN(a,b) ((a) < (b) ? (a) : (b))
9091

92+
/**
93+
Returns the size of a file.
94+
If error returns -1.
95+
*/
96+
static S64 DiB_getFileSize (const char * fileName)
97+
{
98+
U64 const fileSize = UTIL_getFileSize(fileName);
99+
return (fileSize == UTIL_FILESIZE_UNKNOWN) ? -1 : (S64)fileSize;
100+
}
91101

92102
/* ********************************************************
93103
* File related operations
@@ -101,47 +111,67 @@ static UTIL_time_t g_displayClock = UTIL_TIME_INITIALIZER;
101111
* *bufferSizePtr is modified, it provides the amount data loaded within buffer.
102112
* sampleSizes is filled with the size of each sample.
103113
*/
104-
static unsigned DiB_loadFiles(void* buffer, size_t* bufferSizePtr,
105-
size_t* sampleSizes, unsigned sstSize,
106-
const char** fileNamesTable, unsigned nbFiles, size_t targetChunkSize,
107-
unsigned displayLevel)
114+
static int DiB_loadFiles(
115+
void* buffer, size_t* bufferSizePtr,
116+
size_t* sampleSizes, int sstSize,
117+
const char** fileNamesTable, int nbFiles,
118+
size_t targetChunkSize, int displayLevel )
108119
{
109120
char* const buff = (char*)buffer;
110-
size_t pos = 0;
111-
unsigned nbLoadedChunks = 0, fileIndex;
112-
113-
for (fileIndex=0; fileIndex<nbFiles; fileIndex++) {
114-
const char* const fileName = fileNamesTable[fileIndex];
115-
unsigned long long const fs64 = UTIL_getFileSize(fileName);
116-
unsigned long long remainingToLoad = (fs64 == UTIL_FILESIZE_UNKNOWN) ? 0 : fs64;
117-
U32 const nbChunks = targetChunkSize ? (U32)((fs64 + (targetChunkSize-1)) / targetChunkSize) : 1;
118-
U64 const chunkSize = targetChunkSize ? MIN(targetChunkSize, fs64) : fs64;
119-
size_t const maxChunkSize = (size_t)MIN(chunkSize, SAMPLESIZE_MAX);
120-
U32 cnb;
121-
FILE* const f = fopen(fileName, "rb");
122-
if (f==NULL) EXM_THROW(10, "zstd: dictBuilder: %s %s ", fileName, strerror(errno));
123-
DISPLAYUPDATE(2, "Loading %s... \r", fileName);
124-
for (cnb=0; cnb<nbChunks; cnb++) {
125-
size_t const toLoad = (size_t)MIN(maxChunkSize, remainingToLoad);
126-
if (toLoad > *bufferSizePtr-pos) break;
127-
{ size_t const readSize = fread(buff+pos, 1, toLoad, f);
128-
if (readSize != toLoad) EXM_THROW(11, "Pb reading %s", fileName);
129-
pos += readSize;
130-
sampleSizes[nbLoadedChunks++] = toLoad;
131-
remainingToLoad -= targetChunkSize;
132-
if (nbLoadedChunks == sstSize) { /* no more space left in sampleSizes table */
133-
fileIndex = nbFiles; /* stop there */
121+
size_t totalDataLoaded = 0;
122+
int nbSamplesLoaded = 0;
123+
int fileIndex = 0;
124+
FILE * f = NULL;
125+
126+
assert(targetChunkSize <= SAMPLESIZE_MAX);
127+
128+
while ( nbSamplesLoaded < sstSize && fileIndex < nbFiles ) {
129+
size_t fileDataLoaded;
130+
S64 const fileSize = DiB_getFileSize(fileNamesTable[fileIndex]);
131+
if (fileSize <= 0) /* skip if zero-size or file error */
132+
continue;
133+
134+
f = fopen( fileNamesTable[fileIndex], "rb");
135+
if (f == NULL)
136+
EXM_THROW(10, "zstd: dictBuilder: %s %s ", fileNamesTable[fileIndex], strerror(errno));
137+
DISPLAYUPDATE(2, "Loading %s... \r", fileNamesTable[fileIndex]);
138+
139+
/* Load the first chunk of data from the file */
140+
fileDataLoaded = targetChunkSize > 0 ?
141+
(size_t)MIN(fileSize, (S64)targetChunkSize) :
142+
(size_t)MIN(fileSize, SAMPLESIZE_MAX );
143+
if (totalDataLoaded + fileDataLoaded > *bufferSizePtr)
144+
break;
145+
if (fread( buff+totalDataLoaded, 1, fileDataLoaded, f ) != fileDataLoaded)
146+
EXM_THROW(11, "Pb reading %s", fileNamesTable[fileIndex]);
147+
sampleSizes[nbSamplesLoaded++] = fileDataLoaded;
148+
totalDataLoaded += fileDataLoaded;
149+
150+
/* If file-chunking is enabled, load the rest of the file as more samples */
151+
if (targetChunkSize > 0) {
152+
while( (S64)fileDataLoaded < fileSize && nbSamplesLoaded < sstSize ) {
153+
size_t const chunkSize = MIN((size_t)(fileSize-fileDataLoaded), targetChunkSize);
154+
if (totalDataLoaded + chunkSize > *bufferSizePtr) /* buffer is full */
134155
break;
135-
}
136-
if (toLoad < targetChunkSize) {
137-
fseek(f, (long)(targetChunkSize - toLoad), SEEK_CUR);
138-
} } }
139-
fclose(f);
156+
157+
if (fread( buff+totalDataLoaded, 1, chunkSize, f ) != chunkSize)
158+
EXM_THROW(11, "Pb reading %s", fileNamesTable[fileIndex]);
159+
sampleSizes[nbSamplesLoaded++] = chunkSize;
160+
totalDataLoaded += chunkSize;
161+
fileDataLoaded += chunkSize;
162+
}
163+
}
164+
fileIndex += 1;
165+
fclose(f); f = NULL;
140166
}
167+
if (f != NULL)
168+
fclose(f);
169+
141170
DISPLAYLEVEL(2, "\r%79s\r", "");
142-
*bufferSizePtr = pos;
143-
DISPLAYLEVEL(4, "loaded : %u KB \n", (unsigned)(pos >> 10))
144-
return nbLoadedChunks;
171+
DISPLAYLEVEL(4, "Loaded %d KB total training data, %d nb samples \n",
172+
(int)(totalDataLoaded / (1 KB)), nbSamplesLoaded );
173+
*bufferSizePtr = totalDataLoaded;
174+
return nbSamplesLoaded;
145175
}
146176

147177
#define DiB_rotl32(x,r) ((x << r) | (x >> (32 - r)))
@@ -223,58 +253,98 @@ static void DiB_saveDict(const char* dictFileName,
223253
if (n!=0) EXM_THROW(5, "%s : flush error", dictFileName) }
224254
}
225255

226-
227256
typedef struct {
228-
U64 totalSizeToLoad;
229-
unsigned oneSampleTooLarge;
230-
unsigned nbSamples;
257+
S64 totalSizeToLoad;
258+
int nbSamples;
259+
int oneSampleTooLarge;
231260
} fileStats;
232261

233262
/*! DiB_fileStats() :
234263
* Given a list of files, and a chunkSize (0 == no chunk, whole files)
235264
* provides the amount of data to be loaded and the resulting nb of samples.
236265
* This is useful primarily for allocation purpose => sample buffer, and sample sizes table.
237266
*/
238-
static fileStats DiB_fileStats(const char** fileNamesTable, unsigned nbFiles, size_t chunkSize, unsigned displayLevel)
267+
static fileStats DiB_fileStats(const char** fileNamesTable, int nbFiles, size_t chunkSize, int displayLevel)
239268
{
240269
fileStats fs;
241-
unsigned n;
270+
int n;
242271
memset(&fs, 0, sizeof(fs));
272+
273+
// We assume that if chunking is requested, the chunk size is < SAMPLESIZE_MAX
274+
assert( chunkSize <= SAMPLESIZE_MAX );
275+
243276
for (n=0; n<nbFiles; n++) {
244-
U64 const fileSize = UTIL_getFileSize(fileNamesTable[n]);
245-
U64 const srcSize = (fileSize == UTIL_FILESIZE_UNKNOWN) ? 0 : fileSize;
246-
U32 const nbSamples = (U32)(chunkSize ? (srcSize + (chunkSize-1)) / chunkSize : 1);
247-
U64 const chunkToLoad = chunkSize ? MIN(chunkSize, srcSize) : srcSize;
248-
size_t const cappedChunkSize = (size_t)MIN(chunkToLoad, SAMPLESIZE_MAX);
249-
fs.totalSizeToLoad += cappedChunkSize * nbSamples;
250-
fs.oneSampleTooLarge |= (chunkSize > 2*SAMPLESIZE_MAX);
251-
fs.nbSamples += nbSamples;
277+
S64 const fileSize = DiB_getFileSize(fileNamesTable[n]);
278+
// TODO: is there a minimum sample size? What if the file is 1-byte?
279+
if (fileSize == 0) {
280+
DISPLAYLEVEL(3, "Sample file '%s' has zero size, skipping...\n", fileNamesTable[n]);
281+
continue;
282+
}
283+
284+
/* the case where we are breaking up files in sample chunks */
285+
if (chunkSize > 0)
286+
{
287+
// TODO: is there a minimum sample size? Can we have a 1-byte sample?
288+
fs.nbSamples += (int)((fileSize + chunkSize-1) / chunkSize);
289+
fs.totalSizeToLoad += fileSize;
290+
}
291+
else {
292+
/* the case where one file is one sample */
293+
if (fileSize > SAMPLESIZE_MAX) {
294+
/* flag excessively large sample files */
295+
fs.oneSampleTooLarge |= (fileSize > 2*SAMPLESIZE_MAX);
296+
297+
/* Limit to the first SAMPLESIZE_MAX (128kB) of the file */
298+
DISPLAYLEVEL(3, "Sample file '%s' is too large, limiting to %d KB",
299+
fileNamesTable[n], SAMPLESIZE_MAX / (1 KB));
300+
}
301+
fs.nbSamples += 1;
302+
fs.totalSizeToLoad += MIN(fileSize, SAMPLESIZE_MAX);
303+
}
252304
}
253-
DISPLAYLEVEL(4, "Preparing to load : %u KB \n", (unsigned)(fs.totalSizeToLoad >> 10));
305+
DISPLAYLEVEL(4, "Found training data %d files, %d KB, %d samples\n", nbFiles, (int)(fs.totalSizeToLoad / (1 KB)), fs.nbSamples);
254306
return fs;
255307
}
256308

257-
258-
int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
259-
const char** fileNamesTable, unsigned nbFiles, size_t chunkSize,
309+
int DiB_trainFromFiles(const char* dictFileName, size_t maxDictSize,
310+
const char** fileNamesTable, int nbFiles, size_t chunkSize,
260311
ZDICT_legacy_params_t* params, ZDICT_cover_params_t* coverParams,
261312
ZDICT_fastCover_params_t* fastCoverParams, int optimize)
262313
{
263-
unsigned const displayLevel = params ? params->zParams.notificationLevel :
264-
coverParams ? coverParams->zParams.notificationLevel :
265-
fastCoverParams ? fastCoverParams->zParams.notificationLevel :
266-
0; /* should never happen */
314+
fileStats fs;
315+
size_t* sampleSizes; /* vector of sample sizes. Each sample can be up to SAMPLESIZE_MAX */
316+
int nbSamplesLoaded; /* nb of samples effectively loaded in srcBuffer */
317+
size_t loadedSize; /* total data loaded in srcBuffer for all samples */
318+
void* srcBuffer /* contiguous buffer with training data/samples */;
267319
void* const dictBuffer = malloc(maxDictSize);
268-
fileStats const fs = DiB_fileStats(fileNamesTable, nbFiles, chunkSize, displayLevel);
269-
size_t* const sampleSizes = (size_t*)malloc(fs.nbSamples * sizeof(size_t));
270-
size_t const memMult = params ? MEMMULT :
271-
coverParams ? COVER_MEMMULT:
272-
FASTCOVER_MEMMULT;
273-
size_t const maxMem = DiB_findMaxMem(fs.totalSizeToLoad * memMult) / memMult;
274-
size_t loadedSize = (size_t) MIN ((unsigned long long)maxMem, fs.totalSizeToLoad);
275-
void* const srcBuffer = malloc(loadedSize+NOISELENGTH);
276320
int result = 0;
277321

322+
int const displayLevel = params ? params->zParams.notificationLevel :
323+
coverParams ? coverParams->zParams.notificationLevel :
324+
fastCoverParams ? fastCoverParams->zParams.notificationLevel : 0;
325+
326+
/* Shuffle input files before we start assessing how much sample data to load.
327+
The purpose of the shuffle is to pick random samples when the sample
328+
set is larger than what we can load in memory. */
329+
DISPLAYLEVEL(3, "Shuffling input files\n");
330+
DiB_shuffle(fileNamesTable, nbFiles);
331+
332+
/* Figure out how much sample data to load with how many samples */
333+
fs = DiB_fileStats(fileNamesTable, nbFiles, chunkSize, displayLevel);
334+
335+
{
336+
int const memMult = params ? MEMMULT :
337+
coverParams ? COVER_MEMMULT:
338+
FASTCOVER_MEMMULT;
339+
size_t const maxMem = DiB_findMaxMem(fs.totalSizeToLoad * memMult) / memMult;
340+
/* Limit the size of the training data to the free memory */
341+
/* Limit the size of the training data to 2GB */
342+
/* TODO: there is opportunity to stop DiB_fileStats() early when the data limit is reached */
343+
loadedSize = (size_t)MIN( MIN((S64)maxMem, fs.totalSizeToLoad), MAX_SAMPLES_SIZE );
344+
srcBuffer = malloc(loadedSize+NOISELENGTH);
345+
sampleSizes = (size_t*)malloc(fs.nbSamples * sizeof(size_t));
346+
}
347+
278348
/* Checks */
279349
if ((!sampleSizes) || (!srcBuffer) || (!dictBuffer))
280350
EXM_THROW(12, "not enough memory for DiB_trainFiles"); /* should not happen */
@@ -289,31 +359,32 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
289359
DISPLAYLEVEL(2, "! Alternatively, split files into fixed-size blocks representative of samples, with -B# \n");
290360
EXM_THROW(14, "nb of samples too low"); /* we now clearly forbid this case */
291361
}
292-
if (fs.totalSizeToLoad < (unsigned long long)maxDictSize * 8) {
362+
if (fs.totalSizeToLoad < (S64)maxDictSize * 8) {
293363
DISPLAYLEVEL(2, "! Warning : data size of samples too small for target dictionary size \n");
294364
DISPLAYLEVEL(2, "! Samples should be about 100x larger than target dictionary size \n");
295365
}
296366

297367
/* init */
298-
if (loadedSize < fs.totalSizeToLoad)
299-
DISPLAYLEVEL(1, "Not enough memory; training on %u MB only...\n", (unsigned)(loadedSize >> 20));
368+
if ((S64)loadedSize < fs.totalSizeToLoad)
369+
DISPLAYLEVEL(1, "Training samples set too large (%u MB); training on %u MB only...\n",
370+
(unsigned)(fs.totalSizeToLoad / (1 MB)),
371+
(unsigned)(loadedSize / (1 MB)));
300372

301373
/* Load input buffer */
302-
DISPLAYLEVEL(3, "Shuffling input files\n");
303-
DiB_shuffle(fileNamesTable, nbFiles);
304-
305-
DiB_loadFiles(srcBuffer, &loadedSize, sampleSizes, fs.nbSamples, fileNamesTable, nbFiles, chunkSize, displayLevel);
374+
nbSamplesLoaded = DiB_loadFiles(
375+
srcBuffer, &loadedSize, sampleSizes, fs.nbSamples, fileNamesTable,
376+
nbFiles, chunkSize, displayLevel);
306377

307378
{ size_t dictSize;
308379
if (params) {
309380
DiB_fillNoise((char*)srcBuffer + loadedSize, NOISELENGTH); /* guard band, for end of buffer condition */
310381
dictSize = ZDICT_trainFromBuffer_legacy(dictBuffer, maxDictSize,
311-
srcBuffer, sampleSizes, fs.nbSamples,
382+
srcBuffer, sampleSizes, nbSamplesLoaded,
312383
*params);
313384
} else if (coverParams) {
314385
if (optimize) {
315386
dictSize = ZDICT_optimizeTrainFromBuffer_cover(dictBuffer, maxDictSize,
316-
srcBuffer, sampleSizes, fs.nbSamples,
387+
srcBuffer, sampleSizes, nbSamplesLoaded,
317388
coverParams);
318389
if (!ZDICT_isError(dictSize)) {
319390
unsigned splitPercentage = (unsigned)(coverParams->splitPoint * 100);
@@ -322,13 +393,13 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
322393
}
323394
} else {
324395
dictSize = ZDICT_trainFromBuffer_cover(dictBuffer, maxDictSize, srcBuffer,
325-
sampleSizes, fs.nbSamples, *coverParams);
396+
sampleSizes, nbSamplesLoaded, *coverParams);
326397
}
327398
} else {
328399
assert(fastCoverParams != NULL);
329400
if (optimize) {
330401
dictSize = ZDICT_optimizeTrainFromBuffer_fastCover(dictBuffer, maxDictSize,
331-
srcBuffer, sampleSizes, fs.nbSamples,
402+
srcBuffer, sampleSizes, nbSamplesLoaded,
332403
fastCoverParams);
333404
if (!ZDICT_isError(dictSize)) {
334405
unsigned splitPercentage = (unsigned)(fastCoverParams->splitPoint * 100);
@@ -338,7 +409,7 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
338409
}
339410
} else {
340411
dictSize = ZDICT_trainFromBuffer_fastCover(dictBuffer, maxDictSize, srcBuffer,
341-
sampleSizes, fs.nbSamples, *fastCoverParams);
412+
sampleSizes, nbSamplesLoaded, *fastCoverParams);
342413
}
343414
}
344415
if (ZDICT_isError(dictSize)) {

programs/dibio.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
`parameters` is optional and can be provided with values set to 0, meaning "default".
3232
@return : 0 == ok. Any other : error.
3333
*/
34-
int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
35-
const char** fileNamesTable, unsigned nbFiles, size_t chunkSize,
34+
int DiB_trainFromFiles(const char* dictFileName, size_t maxDictSize,
35+
const char** fileNamesTable, int nbFiles, size_t chunkSize,
3636
ZDICT_legacy_params_t* params, ZDICT_cover_params_t* coverParams,
3737
ZDICT_fastCover_params_t* fastCoverParams, int optimize);
3838

programs/zstdcli.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,18 +1290,18 @@ int main(int argCount, const char* argv[])
12901290
int const optimize = !coverParams.k || !coverParams.d;
12911291
coverParams.nbThreads = (unsigned)nbWorkers;
12921292
coverParams.zParams = zParams;
1293-
operationResult = DiB_trainFromFiles(outFileName, maxDictSize, filenames->fileNames, (unsigned)filenames->tableSize, blockSize, NULL, &coverParams, NULL, optimize);
1293+
operationResult = DiB_trainFromFiles(outFileName, maxDictSize, filenames->fileNames, (int)filenames->tableSize, blockSize, NULL, &coverParams, NULL, optimize);
12941294
} else if (dict == fastCover) {
12951295
int const optimize = !fastCoverParams.k || !fastCoverParams.d;
12961296
fastCoverParams.nbThreads = (unsigned)nbWorkers;
12971297
fastCoverParams.zParams = zParams;
1298-
operationResult = DiB_trainFromFiles(outFileName, maxDictSize, filenames->fileNames, (unsigned)filenames->tableSize, blockSize, NULL, NULL, &fastCoverParams, optimize);
1298+
operationResult = DiB_trainFromFiles(outFileName, maxDictSize, filenames->fileNames, (int)filenames->tableSize, blockSize, NULL, NULL, &fastCoverParams, optimize);
12991299
} else {
13001300
ZDICT_legacy_params_t dictParams;
13011301
memset(&dictParams, 0, sizeof(dictParams));
13021302
dictParams.selectivityLevel = dictSelect;
13031303
dictParams.zParams = zParams;
1304-
operationResult = DiB_trainFromFiles(outFileName, maxDictSize, filenames->fileNames, (unsigned)filenames->tableSize, blockSize, &dictParams, NULL, NULL, 0);
1304+
operationResult = DiB_trainFromFiles(outFileName, maxDictSize, filenames->fileNames, (int)filenames->tableSize, blockSize, &dictParams, NULL, NULL, 0);
13051305
}
13061306
#else
13071307
(void)dictCLevel; (void)dictSelect; (void)dictID; (void)maxDictSize; /* not used when ZSTD_NODICT set */

0 commit comments

Comments
 (0)