/* Cap on memory the dictionary builder may use for loading samples:
 * 32-bit platforms: 2 GB - 64 MB; 64-bit platforms: a larger cap scaled
 * by pointer size ((512 MB) << sizeof(size_t)). */
static const size_t g_maxMemory = (sizeof(size_t) == 4) ? (2 GB - 64 MB) : ((size_t)(512 MB) << sizeof(size_t));

#define NOISELENGTH 32                 /* size of the random guard band appended after the sample buffer */
#define MAX_SAMPLES_SIZE (2 GB)        /* training dataset limited to 2GB */
5253
5354
5455/*-*************************************
@@ -88,6 +89,15 @@ static UTIL_time_t g_displayClock = UTIL_TIME_INITIALIZER;
8889#undef MIN
8990#define MIN (a ,b ) ((a) < (b) ? (a) : (b))
9091
92+ /**
93+ Returns the size of a file.
94+ If error returns -1.
95+ */
96+ static S64 DiB_getFileSize (const char * fileName )
97+ {
98+ U64 const fileSize = UTIL_getFileSize (fileName );
99+ return (fileSize == UTIL_FILESIZE_UNKNOWN ) ? -1 : (S64 )fileSize ;
100+ }
91101
92102/* ********************************************************
93103* File related operations
@@ -101,47 +111,67 @@ static UTIL_time_t g_displayClock = UTIL_TIME_INITIALIZER;
101111 * *bufferSizePtr is modified, it provides the amount data loaded within buffer.
102112 * sampleSizes is filled with the size of each sample.
103113 */
104- static unsigned DiB_loadFiles (void * buffer , size_t * bufferSizePtr ,
105- size_t * sampleSizes , unsigned sstSize ,
106- const char * * fileNamesTable , unsigned nbFiles , size_t targetChunkSize ,
107- unsigned displayLevel )
114+ static int DiB_loadFiles (
115+ void * buffer , size_t * bufferSizePtr ,
116+ size_t * sampleSizes , int sstSize ,
117+ const char * * fileNamesTable , int nbFiles ,
118+ size_t targetChunkSize , int displayLevel )
108119{
109120 char * const buff = (char * )buffer ;
110- size_t pos = 0 ;
111- unsigned nbLoadedChunks = 0 , fileIndex ;
112-
113- for (fileIndex = 0 ; fileIndex < nbFiles ; fileIndex ++ ) {
114- const char * const fileName = fileNamesTable [fileIndex ];
115- unsigned long long const fs64 = UTIL_getFileSize (fileName );
116- unsigned long long remainingToLoad = (fs64 == UTIL_FILESIZE_UNKNOWN ) ? 0 : fs64 ;
117- U32 const nbChunks = targetChunkSize ? (U32 )((fs64 + (targetChunkSize - 1 )) / targetChunkSize ) : 1 ;
118- U64 const chunkSize = targetChunkSize ? MIN (targetChunkSize , fs64 ) : fs64 ;
119- size_t const maxChunkSize = (size_t )MIN (chunkSize , SAMPLESIZE_MAX );
120- U32 cnb ;
121- FILE * const f = fopen (fileName , "rb" );
122- if (f == NULL ) EXM_THROW (10 , "zstd: dictBuilder: %s %s " , fileName , strerror (errno ));
123- DISPLAYUPDATE (2 , "Loading %s... \r" , fileName );
124- for (cnb = 0 ; cnb < nbChunks ; cnb ++ ) {
125- size_t const toLoad = (size_t )MIN (maxChunkSize , remainingToLoad );
126- if (toLoad > * bufferSizePtr - pos ) break ;
127- { size_t const readSize = fread (buff + pos , 1 , toLoad , f );
128- if (readSize != toLoad ) EXM_THROW (11 , "Pb reading %s" , fileName );
129- pos += readSize ;
130- sampleSizes [nbLoadedChunks ++ ] = toLoad ;
131- remainingToLoad -= targetChunkSize ;
132- if (nbLoadedChunks == sstSize ) { /* no more space left in sampleSizes table */
133- fileIndex = nbFiles ; /* stop there */
121+ size_t totalDataLoaded = 0 ;
122+ int nbSamplesLoaded = 0 ;
123+ int fileIndex = 0 ;
124+ FILE * f = NULL ;
125+
126+ assert (targetChunkSize <= SAMPLESIZE_MAX );
127+
128+ while ( nbSamplesLoaded < sstSize && fileIndex < nbFiles ) {
129+ size_t fileDataLoaded ;
130+ S64 const fileSize = DiB_getFileSize (fileNamesTable [fileIndex ]);
131+ if (fileSize <= 0 ) /* skip if zero-size or file error */
132+ continue ;
133+
134+ f = fopen ( fileNamesTable [fileIndex ], "rb" );
135+ if (f == NULL )
136+ EXM_THROW (10 , "zstd: dictBuilder: %s %s " , fileNamesTable [fileIndex ], strerror (errno ));
137+ DISPLAYUPDATE (2 , "Loading %s... \r" , fileNamesTable [fileIndex ]);
138+
139+ /* Load the first chunk of data from the file */
140+ fileDataLoaded = targetChunkSize > 0 ?
141+ (size_t )MIN (fileSize , (S64 )targetChunkSize ) :
142+ (size_t )MIN (fileSize , SAMPLESIZE_MAX );
143+ if (totalDataLoaded + fileDataLoaded > * bufferSizePtr )
144+ break ;
145+ if (fread ( buff + totalDataLoaded , 1 , fileDataLoaded , f ) != fileDataLoaded )
146+ EXM_THROW (11 , "Pb reading %s" , fileNamesTable [fileIndex ]);
147+ sampleSizes [nbSamplesLoaded ++ ] = fileDataLoaded ;
148+ totalDataLoaded += fileDataLoaded ;
149+
150+ /* If file-chunking is enabled, load the rest of the file as more samples */
151+ if (targetChunkSize > 0 ) {
152+ while ( (S64 )fileDataLoaded < fileSize && nbSamplesLoaded < sstSize ) {
153+ size_t const chunkSize = MIN ((size_t )(fileSize - fileDataLoaded ), targetChunkSize );
154+ if (totalDataLoaded + chunkSize > * bufferSizePtr ) /* buffer is full */
134155 break ;
135- }
136- if (toLoad < targetChunkSize ) {
137- fseek (f , (long )(targetChunkSize - toLoad ), SEEK_CUR );
138- } } }
139- fclose (f );
156+
157+ if (fread ( buff + totalDataLoaded , 1 , chunkSize , f ) != chunkSize )
158+ EXM_THROW (11 , "Pb reading %s" , fileNamesTable [fileIndex ]);
159+ sampleSizes [nbSamplesLoaded ++ ] = chunkSize ;
160+ totalDataLoaded += chunkSize ;
161+ fileDataLoaded += chunkSize ;
162+ }
163+ }
164+ fileIndex += 1 ;
165+ fclose (f ); f = NULL ;
140166 }
167+ if (f != NULL )
168+ fclose (f );
169+
141170 DISPLAYLEVEL (2 , "\r%79s\r" , "" );
142- * bufferSizePtr = pos ;
143- DISPLAYLEVEL (4 , "loaded : %u KB \n" , (unsigned )(pos >> 10 ))
144- return nbLoadedChunks ;
171+ DISPLAYLEVEL (4 , "Loaded %d KB total training data, %d nb samples \n" ,
172+ (int )(totalDataLoaded / (1 KB )), nbSamplesLoaded );
173+ * bufferSizePtr = totalDataLoaded ;
174+ return nbSamplesLoaded ;
145175}
146176
147177#define DiB_rotl32 (x ,r ) ((x << r) | (x >> (32 - r)))
@@ -223,58 +253,98 @@ static void DiB_saveDict(const char* dictFileName,
223253 if (n != 0 ) EXM_THROW (5 , "%s : flush error" , dictFileName ) }
224254}
225255
typedef struct {
    S64 totalSizeToLoad;    /* total amount of sample data to be loaded, in bytes */
    int nbSamples;          /* number of samples (whole files, or chunks when chunking is enabled) */
    int oneSampleTooLarge;  /* set when at least one sample exceeds 2*SAMPLESIZE_MAX */
} fileStats;
232261
233262/*! DiB_fileStats() :
234263 * Given a list of files, and a chunkSize (0 == no chunk, whole files)
235264 * provides the amount of data to be loaded and the resulting nb of samples.
236265 * This is useful primarily for allocation purpose => sample buffer, and sample sizes table.
237266 */
238- static fileStats DiB_fileStats (const char * * fileNamesTable , unsigned nbFiles , size_t chunkSize , unsigned displayLevel )
267+ static fileStats DiB_fileStats (const char * * fileNamesTable , int nbFiles , size_t chunkSize , int displayLevel )
239268{
240269 fileStats fs ;
241- unsigned n ;
270+ int n ;
242271 memset (& fs , 0 , sizeof (fs ));
272+
273+ // We assume that if chunking is requsted, the chunk size is < SAMPLESIZE_MAX
274+ assert ( chunkSize <= SAMPLESIZE_MAX );
275+
243276 for (n = 0 ; n < nbFiles ; n ++ ) {
244- U64 const fileSize = UTIL_getFileSize (fileNamesTable [n ]);
245- U64 const srcSize = (fileSize == UTIL_FILESIZE_UNKNOWN ) ? 0 : fileSize ;
246- U32 const nbSamples = (U32 )(chunkSize ? (srcSize + (chunkSize - 1 )) / chunkSize : 1 );
247- U64 const chunkToLoad = chunkSize ? MIN (chunkSize , srcSize ) : srcSize ;
248- size_t const cappedChunkSize = (size_t )MIN (chunkToLoad , SAMPLESIZE_MAX );
249- fs .totalSizeToLoad += cappedChunkSize * nbSamples ;
250- fs .oneSampleTooLarge |= (chunkSize > 2 * SAMPLESIZE_MAX );
251- fs .nbSamples += nbSamples ;
277+ S64 const fileSize = DiB_getFileSize (fileNamesTable [n ]);
278+ // TODO: is there a minimum sample size? What if the file is 1-byte?
279+ if (fileSize == 0 ) {
280+ DISPLAYLEVEL (3 , "Sample file '%s' has zero size, skipping...\n" , fileNamesTable [n ]);
281+ continue ;
282+ }
283+
284+ /* the case where we are breaking up files in sample chunks */
285+ if (chunkSize > 0 )
286+ {
287+ // TODO: is there a minimum sample size? Can we have a 1-byte sample?
288+ fs .nbSamples += (int )((fileSize + chunkSize - 1 ) / chunkSize );
289+ fs .totalSizeToLoad += fileSize ;
290+ }
291+ else {
292+ /* the case where one file is one sample */
293+ if (fileSize > SAMPLESIZE_MAX ) {
294+ /* flag excessively large sample files */
295+ fs .oneSampleTooLarge |= (fileSize > 2 * SAMPLESIZE_MAX );
296+
297+ /* Limit to the first SAMPLESIZE_MAX (128kB) of the file */
298+ DISPLAYLEVEL (3 , "Sample file '%s' is too large, limiting to %d KB" ,
299+ fileNamesTable [n ], SAMPLESIZE_MAX / (1 KB ));
300+ }
301+ fs .nbSamples += 1 ;
302+ fs .totalSizeToLoad += MIN (fileSize , SAMPLESIZE_MAX );
303+ }
252304 }
253- DISPLAYLEVEL (4 , "Preparing to load : %u KB \n" , ( unsigned )(fs .totalSizeToLoad >> 10 ) );
305+ DISPLAYLEVEL (4 , "Found training data %d files, %d KB, %d samples \n" , nbFiles , ( int )(fs .totalSizeToLoad / ( 1 KB )), fs . nbSamples );
254306 return fs ;
255307}
256308
257-
258- int DiB_trainFromFiles (const char * dictFileName , unsigned maxDictSize ,
259- const char * * fileNamesTable , unsigned nbFiles , size_t chunkSize ,
309+ int DiB_trainFromFiles (const char * dictFileName , size_t maxDictSize ,
310+ const char * * fileNamesTable , int nbFiles , size_t chunkSize ,
260311 ZDICT_legacy_params_t * params , ZDICT_cover_params_t * coverParams ,
261312 ZDICT_fastCover_params_t * fastCoverParams , int optimize )
262313{
263- unsigned const displayLevel = params ? params -> zParams .notificationLevel :
264- coverParams ? coverParams -> zParams .notificationLevel :
265- fastCoverParams ? fastCoverParams -> zParams .notificationLevel :
266- 0 ; /* should never happen */
314+ fileStats fs ;
315+ size_t * sampleSizes ; /* vector of sample sizes. Each sample can be up to SAMPLESIZE_MAX */
316+ int nbSamplesLoaded ; /* nb of samples effectively loaded in srcBuffer */
317+ size_t loadedSize ; /* total data loaded in srcBuffer for all samples */
318+ void * srcBuffer /* contiguous buffer with training data/samples */ ;
267319 void * const dictBuffer = malloc (maxDictSize );
268- fileStats const fs = DiB_fileStats (fileNamesTable , nbFiles , chunkSize , displayLevel );
269- size_t * const sampleSizes = (size_t * )malloc (fs .nbSamples * sizeof (size_t ));
270- size_t const memMult = params ? MEMMULT :
271- coverParams ? COVER_MEMMULT :
272- FASTCOVER_MEMMULT ;
273- size_t const maxMem = DiB_findMaxMem (fs .totalSizeToLoad * memMult ) / memMult ;
274- size_t loadedSize = (size_t ) MIN ((unsigned long long )maxMem , fs .totalSizeToLoad );
275- void * const srcBuffer = malloc (loadedSize + NOISELENGTH );
276320 int result = 0 ;
277321
322+ int const displayLevel = params ? params -> zParams .notificationLevel :
323+ coverParams ? coverParams -> zParams .notificationLevel :
324+ fastCoverParams ? fastCoverParams -> zParams .notificationLevel : 0 ;
325+
326+ /* Shuffle input files before we start assessing how much sample datA to load.
327+ The purpose of the shuffle is to pick random samples when the sample
328+ set is larger than what we can load in memory. */
329+ DISPLAYLEVEL (3 , "Shuffling input files\n" );
330+ DiB_shuffle (fileNamesTable , nbFiles );
331+
332+ /* Figure out how much sample data to load with how many samples */
333+ fs = DiB_fileStats (fileNamesTable , nbFiles , chunkSize , displayLevel );
334+
335+ {
336+ int const memMult = params ? MEMMULT :
337+ coverParams ? COVER_MEMMULT :
338+ FASTCOVER_MEMMULT ;
339+ size_t const maxMem = DiB_findMaxMem (fs .totalSizeToLoad * memMult ) / memMult ;
340+ /* Limit the size of the training data to the free memory */
341+ /* Limit the size of the training data to 2GB */
342+ /* TODO: there is oportunity to stop DiB_fileStats() early when the data limit is reached */
343+ loadedSize = (size_t )MIN ( MIN ((S64 )maxMem , fs .totalSizeToLoad ), MAX_SAMPLES_SIZE );
344+ srcBuffer = malloc (loadedSize + NOISELENGTH );
345+ sampleSizes = (size_t * )malloc (fs .nbSamples * sizeof (size_t ));
346+ }
347+
278348 /* Checks */
279349 if ((!sampleSizes ) || (!srcBuffer ) || (!dictBuffer ))
280350 EXM_THROW (12 , "not enough memory for DiB_trainFiles" ); /* should not happen */
@@ -289,31 +359,32 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
289359 DISPLAYLEVEL (2 , "! Alternatively, split files into fixed-size blocks representative of samples, with -B# \n" );
290360 EXM_THROW (14 , "nb of samples too low" ); /* we now clearly forbid this case */
291361 }
292- if (fs .totalSizeToLoad < (unsigned long long )maxDictSize * 8 ) {
362+ if (fs .totalSizeToLoad < (S64 )maxDictSize * 8 ) {
293363 DISPLAYLEVEL (2 , "! Warning : data size of samples too small for target dictionary size \n" );
294364 DISPLAYLEVEL (2 , "! Samples should be about 100x larger than target dictionary size \n" );
295365 }
296366
297367 /* init */
298- if (loadedSize < fs .totalSizeToLoad )
299- DISPLAYLEVEL (1 , "Not enough memory; training on %u MB only...\n" , (unsigned )(loadedSize >> 20 ));
368+ if ((S64 )loadedSize < fs .totalSizeToLoad )
369+ DISPLAYLEVEL (1 , "Training samples set too large (%u MB); training on %u MB only...\n" ,
370+ (unsigned )(fs .totalSizeToLoad / (1 MB )),
371+ (unsigned )(loadedSize / (1 MB )));
300372
301373 /* Load input buffer */
302- DISPLAYLEVEL (3 , "Shuffling input files\n" );
303- DiB_shuffle (fileNamesTable , nbFiles );
304-
305- DiB_loadFiles (srcBuffer , & loadedSize , sampleSizes , fs .nbSamples , fileNamesTable , nbFiles , chunkSize , displayLevel );
374+ nbSamplesLoaded = DiB_loadFiles (
375+ srcBuffer , & loadedSize , sampleSizes , fs .nbSamples , fileNamesTable ,
376+ nbFiles , chunkSize , displayLevel );
306377
307378 { size_t dictSize ;
308379 if (params ) {
309380 DiB_fillNoise ((char * )srcBuffer + loadedSize , NOISELENGTH ); /* guard band, for end of buffer condition */
310381 dictSize = ZDICT_trainFromBuffer_legacy (dictBuffer , maxDictSize ,
311- srcBuffer , sampleSizes , fs . nbSamples ,
382+ srcBuffer , sampleSizes , nbSamplesLoaded ,
312383 * params );
313384 } else if (coverParams ) {
314385 if (optimize ) {
315386 dictSize = ZDICT_optimizeTrainFromBuffer_cover (dictBuffer , maxDictSize ,
316- srcBuffer , sampleSizes , fs . nbSamples ,
387+ srcBuffer , sampleSizes , nbSamplesLoaded ,
317388 coverParams );
318389 if (!ZDICT_isError (dictSize )) {
319390 unsigned splitPercentage = (unsigned )(coverParams -> splitPoint * 100 );
@@ -322,13 +393,13 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
322393 }
323394 } else {
324395 dictSize = ZDICT_trainFromBuffer_cover (dictBuffer , maxDictSize , srcBuffer ,
325- sampleSizes , fs . nbSamples , * coverParams );
396+ sampleSizes , nbSamplesLoaded , * coverParams );
326397 }
327398 } else {
328399 assert (fastCoverParams != NULL );
329400 if (optimize ) {
330401 dictSize = ZDICT_optimizeTrainFromBuffer_fastCover (dictBuffer , maxDictSize ,
331- srcBuffer , sampleSizes , fs . nbSamples ,
402+ srcBuffer , sampleSizes , nbSamplesLoaded ,
332403 fastCoverParams );
333404 if (!ZDICT_isError (dictSize )) {
334405 unsigned splitPercentage = (unsigned )(fastCoverParams -> splitPoint * 100 );
@@ -338,7 +409,7 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
338409 }
339410 } else {
340411 dictSize = ZDICT_trainFromBuffer_fastCover (dictBuffer , maxDictSize , srcBuffer ,
341- sampleSizes , fs . nbSamples , * fastCoverParams );
412+ sampleSizes , nbSamplesLoaded , * fastCoverParams );
342413 }
343414 }
344415 if (ZDICT_isError (dictSize )) {
0 commit comments