-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Limit train samples #2809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Limit train samples #2809
Changes from 3 commits
d758afe
8b607bf
d21fd2e
474d126
9d17075
338309a
1d9211a
a463506
9eb56a3
9fdd983
4188559
8836197
4f0071a
e5a0a9f
4385973
cedc9e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,21 +49,25 @@ | |
| static const size_t g_maxMemory = (sizeof(size_t) == 4) ? (2 GB - 64 MB) : ((size_t)(512 MB) << sizeof(size_t)); | ||
|
|
||
| #define NOISELENGTH 32 | ||
| #define MAX_SAMPLES_SIZE (2 GB) /* training dataset limited to 2GB */ | ||
|
|
||
|
|
||
| /*-************************************* | ||
| * Console display | ||
| ***************************************/ | ||
| #define DISPLAY_LEVEL_DEFAULT 2 | ||
| static int g_displayLevel = DISPLAY_LEVEL_DEFAULT; | ||
stanjo74 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| #define DISPLAY(...) fprintf(stderr, __VA_ARGS__) | ||
| #define DISPLAYLEVEL(l, ...) if (displayLevel>=l) { DISPLAY(__VA_ARGS__); } | ||
| #define DISPLAYLEVEL(l, ...) if (g_displayLevel>=l) { DISPLAY(__VA_ARGS__); } | ||
|
|
||
| static const U64 g_refreshRate = SEC_TO_MICRO / 6; | ||
| static UTIL_time_t g_displayClock = UTIL_TIME_INITIALIZER; | ||
|
|
||
| #define DISPLAYUPDATE(l, ...) { if (displayLevel>=l) { \ | ||
| if ((UTIL_clockSpanMicro(g_displayClock) > g_refreshRate) || (displayLevel>=4)) \ | ||
| #define DISPLAYUPDATE(l, ...) { if (g_displayLevel>=l) { \ | ||
| if ((UTIL_clockSpanMicro(g_displayClock) > g_refreshRate) || (g_displayLevel>=4)) \ | ||
| { g_displayClock = UTIL_getTime(); DISPLAY(__VA_ARGS__); \ | ||
| if (displayLevel>=4) fflush(stderr); } } } | ||
| if (g_displayLevel>=4) fflush(stderr); } } } | ||
|
|
||
| /*-************************************* | ||
| * Exceptions | ||
|
|
@@ -88,6 +92,20 @@ static UTIL_time_t g_displayClock = UTIL_TIME_INITIALIZER; | |
| #undef MIN | ||
| #define MIN(a,b) ((a) < (b) ? (a) : (b)) | ||
|
|
||
| /** | ||
| Returns the size of a file. | ||
| Returns 0 on error: a zero-size file and an error are treated the same way. | ||
| Emit warning when the file is inaccessible or zero size. | ||
stanjo74 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| */ | ||
| static size_t Dib_getFileSize (const char * fileName) | ||
stanjo74 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| { | ||
| size_t fileSize = UTIL_getFileSize(fileName); | ||
stanjo74 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if (fileSize == UTIL_FILESIZE_UNKNOWN) | ||
| fileSize = 0; | ||
| return fileSize; | ||
| } | ||
|
|
||
|
|
||
|
|
||
| /* ******************************************************** | ||
| * File related operations | ||
|
|
@@ -102,46 +120,74 @@ static UTIL_time_t g_displayClock = UTIL_TIME_INITIALIZER; | |
| * sampleSizes is filled with the size of each sample. | ||
| */ | ||
| static unsigned DiB_loadFiles(void* buffer, size_t* bufferSizePtr, | ||
| size_t* sampleSizes, unsigned sstSize, | ||
| const char** fileNamesTable, unsigned nbFiles, size_t targetChunkSize, | ||
| unsigned displayLevel) | ||
| size_t* sampleSizes, int sstSize, | ||
| const char** fileNamesTable, int nbFiles, | ||
| size_t targetChunkSize ) | ||
| { | ||
| char* const buff = (char*)buffer; | ||
| size_t pos = 0; | ||
| unsigned nbLoadedChunks = 0, fileIndex; | ||
|
|
||
| for (fileIndex=0; fileIndex<nbFiles; fileIndex++) { | ||
| const char* const fileName = fileNamesTable[fileIndex]; | ||
| unsigned long long const fs64 = UTIL_getFileSize(fileName); | ||
| unsigned long long remainingToLoad = (fs64 == UTIL_FILESIZE_UNKNOWN) ? 0 : fs64; | ||
| U32 const nbChunks = targetChunkSize ? (U32)((fs64 + (targetChunkSize-1)) / targetChunkSize) : 1; | ||
| U64 const chunkSize = targetChunkSize ? MIN(targetChunkSize, fs64) : fs64; | ||
| size_t const maxChunkSize = (size_t)MIN(chunkSize, SAMPLESIZE_MAX); | ||
| U32 cnb; | ||
| FILE* const f = fopen(fileName, "rb"); | ||
| if (f==NULL) EXM_THROW(10, "zstd: dictBuilder: %s %s ", fileName, strerror(errno)); | ||
| DISPLAYUPDATE(2, "Loading %s... \r", fileName); | ||
| for (cnb=0; cnb<nbChunks; cnb++) { | ||
| size_t const toLoad = (size_t)MIN(maxChunkSize, remainingToLoad); | ||
| if (toLoad > *bufferSizePtr-pos) break; | ||
| { size_t const readSize = fread(buff+pos, 1, toLoad, f); | ||
| if (readSize != toLoad) EXM_THROW(11, "Pb reading %s", fileName); | ||
| pos += readSize; | ||
| sampleSizes[nbLoadedChunks++] = toLoad; | ||
| remainingToLoad -= targetChunkSize; | ||
| if (nbLoadedChunks == sstSize) { /* no more space left in sampleSizes table */ | ||
| fileIndex = nbFiles; /* stop there */ | ||
| char * buff = (char*)buffer; | ||
stanjo74 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| size_t totalDataLoaded = 0; | ||
| int nbSamplesLoaded = 0; | ||
| int fileIndex = 0; | ||
| FILE * f = NULL; | ||
stanjo74 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| assert(targetChunkSize <= SAMPLESIZE_MAX); | ||
|
|
||
| while ( nbSamplesLoaded < sstSize | ||
| && fileIndex < nbFiles ) | ||
| { | ||
| size_t const fileSize = Dib_getFileSize(fileNamesTable[fileIndex]); | ||
| if (fileSize == 0) | ||
| continue; | ||
|
|
||
| f = fopen( fileNamesTable[fileIndex], "rb"); | ||
| if (f == NULL) | ||
| EXM_THROW(10, "zstd: dictBuilder: %s %s ", fileNamesTable[fileIndex], strerror(errno)); | ||
| DISPLAYUPDATE(2, "Loading %s... \r", fileNamesTable[fileIndex]); | ||
|
|
||
| /* Load the first chunk of data from the file */ | ||
| size_t const chunkSize = targetChunkSize > 0 ? | ||
| MIN(fileSize, targetChunkSize) : | ||
| MIN(fileSize, SAMPLESIZE_MAX ); | ||
| if (totalDataLoaded + chunkSize > *bufferSizePtr) | ||
| break; | ||
|
|
||
| size_t fileDataLoaded = fread( | ||
| buff+totalDataLoaded, 1, chunkSize, f ); | ||
| if (fileDataLoaded != chunkSize) | ||
| EXM_THROW(11, "Pb reading %s", fileNamesTable[fileIndex]); | ||
| sampleSizes[nbSamplesLoaded++] = fileDataLoaded; | ||
| totalDataLoaded += fileDataLoaded; | ||
|
|
||
| /* If file-chunking is enabled, load the rest of the file as more samples */ | ||
| if (targetChunkSize > 0) { | ||
| while( fileDataLoaded < fileSize && nbSamplesLoaded < sstSize ) { | ||
|
|
||
| size_t const chunkSize = MIN(fileSize-fileDataLoaded, targetChunkSize); | ||
| if (chunkSize == 0) /* no more to read */ | ||
| break; | ||
| } | ||
| if (toLoad < targetChunkSize) { | ||
| fseek(f, (long)(targetChunkSize - toLoad), SEEK_CUR); | ||
| } } } | ||
| fclose(f); | ||
| if (totalDataLoaded + chunkSize > *bufferSizePtr) /* buffer is full */ | ||
| break; | ||
|
|
||
| size_t chunkDataLoaded = fread( | ||
| buff+totalDataLoaded, 1, chunkSize, f ); | ||
| if (chunkDataLoaded != chunkSize) | ||
| EXM_THROW(11, "Pb reading %s", fileNamesTable[fileIndex]); | ||
|
|
||
| sampleSizes[nbSamplesLoaded++] = chunkDataLoaded; | ||
| totalDataLoaded += chunkDataLoaded; | ||
| fileDataLoaded += chunkDataLoaded; | ||
| } | ||
| } | ||
| fileIndex += 1; | ||
| fclose(f); f = NULL; | ||
| } | ||
| if (f != NULL) | ||
| fclose(f); | ||
|
|
||
| DISPLAYLEVEL(2, "\r%79s\r", ""); | ||
| *bufferSizePtr = pos; | ||
| DISPLAYLEVEL(4, "loaded : %u KB \n", (unsigned)(pos >> 10)) | ||
| return nbLoadedChunks; | ||
| DISPLAYLEVEL(4, "loaded : %u KB \n", (unsigned)(totalDataLoaded >> 10)) | ||
| *bufferSizePtr = totalDataLoaded; | ||
| return nbSamplesLoaded; | ||
| } | ||
|
|
||
| #define DiB_rotl32(x,r) ((x << r) | (x >> (32 - r))) | ||
|
|
@@ -223,7 +269,6 @@ static void DiB_saveDict(const char* dictFileName, | |
| if (n!=0) EXM_THROW(5, "%s : flush error", dictFileName) } | ||
| } | ||
|
|
||
|
|
||
| typedef struct { | ||
| U64 totalSizeToLoad; | ||
| unsigned oneSampleTooLarge; | ||
|
|
@@ -235,22 +280,47 @@ typedef struct { | |
| * provides the amount of data to be loaded and the resulting nb of samples. | ||
| * This is useful primarily for allocation purpose => sample buffer, and sample sizes table. | ||
| */ | ||
| static fileStats DiB_fileStats(const char** fileNamesTable, unsigned nbFiles, size_t chunkSize, unsigned displayLevel) | ||
| static fileStats DiB_fileStats(const char** fileNamesTable, unsigned nbFiles, size_t chunkSize) | ||
| { | ||
| fileStats fs; | ||
| unsigned n; | ||
| memset(&fs, 0, sizeof(fs)); | ||
|
|
||
| // We assume that if chunking is requested, the chunk size is < SAMPLESIZE_MAX | ||
| assert( chunkSize <= SAMPLESIZE_MAX ); | ||
|
|
||
| for (n=0; n<nbFiles; n++) { | ||
| U64 const fileSize = UTIL_getFileSize(fileNamesTable[n]); | ||
| U64 const srcSize = (fileSize == UTIL_FILESIZE_UNKNOWN) ? 0 : fileSize; | ||
| U32 const nbSamples = (U32)(chunkSize ? (srcSize + (chunkSize-1)) / chunkSize : 1); | ||
| U64 const chunkToLoad = chunkSize ? MIN(chunkSize, srcSize) : srcSize; | ||
| size_t const cappedChunkSize = (size_t)MIN(chunkToLoad, SAMPLESIZE_MAX); | ||
| fs.totalSizeToLoad += cappedChunkSize * nbSamples; | ||
| fs.oneSampleTooLarge |= (chunkSize > 2*SAMPLESIZE_MAX); | ||
| fs.nbSamples += nbSamples; | ||
| U64 fileSize = Dib_getFileSize(fileNamesTable[n]); | ||
| if (fileSize == 0) { | ||
| DISPLAYLEVEL(3, "Sample file '%s' has zero size, skipping...\n", fileNamesTable[n]); | ||
| continue; | ||
| } | ||
|
|
||
| /* the case where we are breaking up files in sample chunks */ | ||
| if (chunkSize > 0) | ||
| { | ||
| size_t nbWholeChunks = fileSize / chunkSize; | ||
| size_t leftoverChunkSize = fileSize % chunkSize; | ||
| fs.nbSamples += nbWholeChunks + (leftoverChunkSize > 0); | ||
| fs.totalSizeToLoad += nbWholeChunks * chunkSize + leftoverChunkSize; | ||
| } | ||
| else { | ||
| /* the case where one file is one sample */ | ||
| if (fileSize > SAMPLESIZE_MAX) { | ||
| /* flag excessively large sample files */ | ||
stanjo74 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| fs.oneSampleTooLarge |= (fileSize > 2*SAMPLESIZE_MAX); | ||
|
|
||
| /* Limit to the first SAMPLESIZE_MAX (128kB) of the file */ | ||
| DISPLAYLEVEL(3, "Sample file '%s' is too large, limiting to %ukB", | ||
| fileNamesTable[n], SAMPLESIZE_MAX >> 10); | ||
| fileSize = SAMPLESIZE_MAX; | ||
| } | ||
| fs.nbSamples += 1; | ||
| fs.totalSizeToLoad += fileSize; | ||
| } | ||
| } | ||
| DISPLAYLEVEL(4, "Preparing to load : %u KB \n", (unsigned)(fs.totalSizeToLoad >> 10)); | ||
| DISPLAYLEVEL(4, "Number of samples %u\n", fs.nbSamples ); | ||
| return fs; | ||
| } | ||
|
|
||
|
|
@@ -260,18 +330,30 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize, | |
| ZDICT_legacy_params_t* params, ZDICT_cover_params_t* coverParams, | ||
| ZDICT_fastCover_params_t* fastCoverParams, int optimize) | ||
| { | ||
| unsigned const displayLevel = params ? params->zParams.notificationLevel : | ||
| g_displayLevel = params ? params->zParams.notificationLevel : | ||
| coverParams ? coverParams->zParams.notificationLevel : | ||
| fastCoverParams ? fastCoverParams->zParams.notificationLevel : | ||
| 0; /* should never happen */ | ||
| void* const dictBuffer = malloc(maxDictSize); | ||
| fileStats const fs = DiB_fileStats(fileNamesTable, nbFiles, chunkSize, displayLevel); | ||
|
|
||
| /* Shuffle input files before we start assessing how much sample data to load. | ||
| The purpose of the shuffle is to pick random samples when the sample | ||
| set is larger than what we can load in memory. */ | ||
|
||
| DISPLAYLEVEL(3, "Shuffling input files\n"); | ||
| DiB_shuffle(fileNamesTable, nbFiles); | ||
|
|
||
| /* Figure out how much sample data to load with how many samples */ | ||
| fileStats const fs = DiB_fileStats(fileNamesTable, nbFiles, chunkSize); | ||
|
|
||
| size_t* const sampleSizes = (size_t*)malloc(fs.nbSamples * sizeof(size_t)); | ||
| size_t const memMult = params ? MEMMULT : | ||
stanjo74 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| coverParams ? COVER_MEMMULT: | ||
| FASTCOVER_MEMMULT; | ||
| size_t const maxMem = DiB_findMaxMem(fs.totalSizeToLoad * memMult) / memMult; | ||
| size_t loadedSize = (size_t) MIN ((unsigned long long)maxMem, fs.totalSizeToLoad); | ||
| /* Limit the size of the training data to the free memory */ | ||
| /* Limit the size of the training data to 2GB */ | ||
| /* TODO: there is an opportunity to stop DiB_fileStats() early when the data limit is reached */ | ||
| size_t loadedSize = MIN( MIN(maxMem, fs.totalSizeToLoad), MAX_SAMPLES_SIZE ); | ||
| void* const srcBuffer = malloc(loadedSize+NOISELENGTH); | ||
| int result = 0; | ||
|
|
||
|
|
@@ -296,13 +378,12 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize, | |
|
|
||
| /* init */ | ||
| if (loadedSize < fs.totalSizeToLoad) | ||
| DISPLAYLEVEL(1, "Not enough memory; training on %u MB only...\n", (unsigned)(loadedSize >> 20)); | ||
| DISPLAYLEVEL(1, "Trainig samples set too large (%u MB); training on %u MB only...\n", | ||
| (unsigned)(fs.totalSizeToLoad >> 20), | ||
| (unsigned)(loadedSize >> 20)); | ||
|
|
||
| /* Load input buffer */ | ||
| DISPLAYLEVEL(3, "Shuffling input files\n"); | ||
| DiB_shuffle(fileNamesTable, nbFiles); | ||
|
|
||
| DiB_loadFiles(srcBuffer, &loadedSize, sampleSizes, fs.nbSamples, fileNamesTable, nbFiles, chunkSize, displayLevel); | ||
| DiB_loadFiles(srcBuffer, &loadedSize, sampleSizes, fs.nbSamples, fileNamesTable, nbFiles, chunkSize); | ||
|
|
||
| { size_t dictSize; | ||
| if (params) { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.