Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
317 changes: 112 additions & 205 deletions lib/compress/zstd_lazy.c
Original file line number Diff line number Diff line change
Expand Up @@ -865,204 +865,32 @@ FORCE_INLINE_TEMPLATE size_t ZSTD_HcFindBestMatch_extDict_selectMLS (
* (SIMD) Row-based matchfinder
***********************************/
/* Constants for row-based hash */
#define ZSTD_ROW_HASH_TAG_OFFSET 1 /* byte offset of hashes in the match state's tagTable from the beginning of a row */
#define ZSTD_ROW_HASH_TAG_OFFSET 16 /* byte offset of hashes in the match state's tagTable from the beginning of a row */
#define ZSTD_ROW_HASH_TAG_BITS 8 /* nb bits to use for the tag */
#define ZSTD_ROW_HASH_TAG_MASK ((1u << ZSTD_ROW_HASH_TAG_BITS) - 1)

#define ZSTD_ROW_HASH_CACHE_MASK (ZSTD_ROW_HASH_CACHE_SIZE - 1)

typedef U32 ZSTD_VecMask; /* Clarifies when we are interacting with a U32 representing a mask of matches */

#if !defined(ZSTD_NO_INTRINSICS) && (defined(__SSE2__) || defined(_M_AMD64)) /* SIMD SSE version*/

#include <emmintrin.h>
typedef __m128i ZSTD_Vec128;

/* Returns a 128-bit container with 128-bits from src */
static ZSTD_Vec128 ZSTD_Vec128_read(const void* const src) {
return _mm_loadu_si128((ZSTD_Vec128 const*)src);
}

/* Returns a ZSTD_Vec128 with the byte "val" packed 16 times */
static ZSTD_Vec128 ZSTD_Vec128_set8(BYTE val) {
return _mm_set1_epi8((char)val);
}

/* Do byte-by-byte comparison result of x and y. Then collapse 128-bit resultant mask
* into a 32-bit mask that is the MSB of each byte.
* */
static ZSTD_VecMask ZSTD_Vec128_cmpMask8(ZSTD_Vec128 x, ZSTD_Vec128 y) {
return (ZSTD_VecMask)_mm_movemask_epi8(_mm_cmpeq_epi8(x, y));
}

typedef struct {
__m128i fst;
__m128i snd;
} ZSTD_Vec256;

static ZSTD_Vec256 ZSTD_Vec256_read(const void* const ptr) {
ZSTD_Vec256 v;
v.fst = ZSTD_Vec128_read(ptr);
v.snd = ZSTD_Vec128_read((ZSTD_Vec128 const*)ptr + 1);
return v;
}

static ZSTD_Vec256 ZSTD_Vec256_set8(BYTE val) {
ZSTD_Vec256 v;
v.fst = ZSTD_Vec128_set8(val);
v.snd = ZSTD_Vec128_set8(val);
return v;
}

static ZSTD_VecMask ZSTD_Vec256_cmpMask8(ZSTD_Vec256 x, ZSTD_Vec256 y) {
ZSTD_VecMask fstMask;
ZSTD_VecMask sndMask;
fstMask = ZSTD_Vec128_cmpMask8(x.fst, y.fst);
sndMask = ZSTD_Vec128_cmpMask8(x.snd, y.snd);
return fstMask | (sndMask << 16);
}

#elif !defined(ZSTD_NO_INTRINSICS) && defined(__ARM_NEON) /* SIMD ARM NEON Version */

#include <arm_neon.h>
typedef uint8x16_t ZSTD_Vec128;

static ZSTD_Vec128 ZSTD_Vec128_read(const void* const src) {
return vld1q_u8((const BYTE* const)src);
}

static ZSTD_Vec128 ZSTD_Vec128_set8(BYTE val) {
return vdupq_n_u8(val);
}

/* Mimics '_mm_movemask_epi8()' from SSE */
static U32 ZSTD_vmovmaskq_u8(ZSTD_Vec128 val) {
/* Shift out everything but the MSB bits in each byte */
uint16x8_t highBits = vreinterpretq_u16_u8(vshrq_n_u8(val, 7));
/* Merge the even lanes together with vsra (right shift and add) */
uint32x4_t paired16 = vreinterpretq_u32_u16(vsraq_n_u16(highBits, highBits, 7));
uint64x2_t paired32 = vreinterpretq_u64_u32(vsraq_n_u32(paired16, paired16, 14));
uint8x16_t paired64 = vreinterpretq_u8_u64(vsraq_n_u64(paired32, paired32, 28));
/* Extract the low 8 bits from each lane, merge */
return vgetq_lane_u8(paired64, 0) | ((U32)vgetq_lane_u8(paired64, 8) << 8);
}

static ZSTD_VecMask ZSTD_Vec128_cmpMask8(ZSTD_Vec128 x, ZSTD_Vec128 y) {
return (ZSTD_VecMask)ZSTD_vmovmaskq_u8(vceqq_u8(x, y));
}

typedef struct {
uint8x16_t fst;
uint8x16_t snd;
} ZSTD_Vec256;

static ZSTD_Vec256 ZSTD_Vec256_read(const void* const ptr) {
ZSTD_Vec256 v;
v.fst = ZSTD_Vec128_read(ptr);
v.snd = ZSTD_Vec128_read((ZSTD_Vec128 const*)ptr + 1);
return v;
}

static ZSTD_Vec256 ZSTD_Vec256_set8(BYTE val) {
ZSTD_Vec256 v;
v.fst = ZSTD_Vec128_set8(val);
v.snd = ZSTD_Vec128_set8(val);
return v;
}

static ZSTD_VecMask ZSTD_Vec256_cmpMask8(ZSTD_Vec256 x, ZSTD_Vec256 y) {
ZSTD_VecMask fstMask;
ZSTD_VecMask sndMask;
fstMask = ZSTD_Vec128_cmpMask8(x.fst, y.fst);
sndMask = ZSTD_Vec128_cmpMask8(x.snd, y.snd);
return fstMask | (sndMask << 16);
}

#else /* Scalar fallback version */

#define VEC128_NB_SIZE_T (16 / sizeof(size_t))
typedef struct {
size_t vec[VEC128_NB_SIZE_T];
} ZSTD_Vec128;

static ZSTD_Vec128 ZSTD_Vec128_read(const void* const src) {
ZSTD_Vec128 ret;
ZSTD_memcpy(ret.vec, src, VEC128_NB_SIZE_T*sizeof(size_t));
return ret;
}

static ZSTD_Vec128 ZSTD_Vec128_set8(BYTE val) {
ZSTD_Vec128 ret = { {0} };
int startBit = sizeof(size_t) * 8 - 8;
for (;startBit >= 0; startBit -= 8) {
unsigned j = 0;
for (;j < VEC128_NB_SIZE_T; ++j) {
ret.vec[j] |= ((size_t)val << startBit);
}
}
return ret;
}

/* Compare x to y, byte by byte, generating a "matches" bitfield */
static ZSTD_VecMask ZSTD_Vec128_cmpMask8(ZSTD_Vec128 x, ZSTD_Vec128 y) {
ZSTD_VecMask res = 0;
unsigned i = 0;
unsigned l = 0;
for (; i < VEC128_NB_SIZE_T; ++i) {
const size_t cmp1 = x.vec[i];
const size_t cmp2 = y.vec[i];
unsigned j = 0;
for (; j < sizeof(size_t); ++j, ++l) {
if (((cmp1 >> j*8) & 0xFF) == ((cmp2 >> j*8) & 0xFF)) {
res |= ((U32)1 << (j+i*sizeof(size_t)));
}
}
}
return res;
}

#define VEC256_NB_SIZE_T 2*VEC128_NB_SIZE_T
typedef struct {
size_t vec[VEC256_NB_SIZE_T];
} ZSTD_Vec256;

static ZSTD_Vec256 ZSTD_Vec256_read(const void* const src) {
ZSTD_Vec256 ret;
ZSTD_memcpy(ret.vec, src, VEC256_NB_SIZE_T*sizeof(size_t));
return ret;
}

static ZSTD_Vec256 ZSTD_Vec256_set8(BYTE val) {
ZSTD_Vec256 ret = { {0} };
int startBit = sizeof(size_t) * 8 - 8;
for (;startBit >= 0; startBit -= 8) {
unsigned j = 0;
for (;j < VEC256_NB_SIZE_T; ++j) {
ret.vec[j] |= ((size_t)val << startBit);
}
}
return ret;
}

/* Compare x to y, byte by byte, generating a "matches" bitfield */
static ZSTD_VecMask ZSTD_Vec256_cmpMask8(ZSTD_Vec256 x, ZSTD_Vec256 y) {
ZSTD_VecMask res = 0;
unsigned i = 0;
unsigned l = 0;
for (; i < VEC256_NB_SIZE_T; ++i) {
const size_t cmp1 = x.vec[i];
const size_t cmp2 = y.vec[i];
unsigned j = 0;
for (; j < sizeof(size_t); ++j, ++l) {
if (((cmp1 >> j*8) & 0xFF) == ((cmp2 >> j*8) & 0xFF)) {
res |= ((U32)1 << (j+i*sizeof(size_t)));
}
}
}
return res;
}
#if !defined(ZSTD_NO_INTRINSICS)
# if defined(_M_AMD64) || (defined (_M_IX86) && defined(_M_IX86_FP) && (_M_IX86_FP >= 2))
# define ZSTD_ARCH_X86_SSE2
# endif
# if defined(__SSE2__)
# define ZSTD_ARCH_X86_SSE2
# endif
# if defined(__ARM_NEON)
# define ZSTD_ARCH_ARM_NEON
# endif
#
#
# if defined(ZSTD_ARCH_X86_SSE2)
# include <emmintrin.h>
# elif defined(ZSTD_ARCH_ARM_NEON)
# include <arm_neon.h>
# endif
#endif

#endif /* !defined(ZSTD_NO_INTRINSICS) && defined(__SSE2__) */
typedef U32 ZSTD_VecMask; /* Clarifies when we are interacting with a U32 representing a mask of matches */

/* ZSTD_VecMask_next():
* Starting from the LSB, returns the idx of the next non-zero bit.
Expand Down Expand Up @@ -1226,24 +1054,103 @@ void ZSTD_row_update(ZSTD_matchState_t* const ms, const BYTE* ip) {

/* Returns a ZSTD_VecMask (U32) that has the nth bit set to 1 if the newly-computed "tag" matches
* the hash at the nth position in a row of the tagTable.
*/
* Each row is a circular buffer beginning at the value of "head". So we must rotate the "matches" bitfield
* to match up with the actual layout of the entries within the hashTable */
FORCE_INLINE_TEMPLATE
ZSTD_VecMask ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, const U32 rowEntries) {
ZSTD_VecMask matches = 0;
const BYTE* const src = tagRow + ZSTD_ROW_HASH_TAG_OFFSET;
#if defined(ZSTD_ARCH_X86_SSE2)
assert(ZSTD_isAligned(src, 16));
if (rowEntries == 16) {
ZSTD_Vec128 hashes = ZSTD_Vec128_read(tagRow + ZSTD_ROW_HASH_TAG_OFFSET);
ZSTD_Vec128 expandedTags = ZSTD_Vec128_set8(tag);
matches = ZSTD_Vec128_cmpMask8(hashes, expandedTags);
} else if (rowEntries == 32) {
ZSTD_Vec256 hashes = ZSTD_Vec256_read(tagRow + ZSTD_ROW_HASH_TAG_OFFSET);
ZSTD_Vec256 expandedTags = ZSTD_Vec256_set8(tag);
matches = ZSTD_Vec256_cmpMask8(hashes, expandedTags);
const __m128i chunk = _mm_load_si128((const __m128i*)(const void*)src);
const __m128i equalMask = _mm_cmpeq_epi8(chunk, _mm_set1_epi8(tag));
const U32 matches = (U32)_mm_movemask_epi8(equalMask);
return ZSTD_VecMask_rotateRight(matches, head, 16);
} else {
assert(0);
const __m128i chunk0 = _mm_load_si128((const __m128i*)(const void*)&src[0]);
const __m128i chunk1 = _mm_load_si128((const __m128i*)(const void*)&src[16]);
const __m128i equalMask0 = _mm_cmpeq_epi8(chunk0, _mm_set1_epi8(tag));
const __m128i equalMask1 = _mm_cmpeq_epi8(chunk1, _mm_set1_epi8(tag));
const U32 lo = (U32)_mm_movemask_epi8(equalMask0);
const U32 hi = (U32)_mm_movemask_epi8(equalMask1);
assert(rowEntries == 32);
return ZSTD_VecMask_rotateRight((hi << 16) | lo, head, 32);
}
/* Each row is a circular buffer beginning at the value of "head". So we must rotate the "matches" bitfield
to match up with the actual layout of the entries within the hashTable */
return ZSTD_VecMask_rotateRight(matches, head, rowEntries);
#else
# if defined(ZSTD_ARCH_ARM_NEON)
if (MEM_isLittleEndian()) {
assert(ZSTD_isAligned(src, 16));
if (rowEntries == 16) {
const uint8x16_t chunk = vld1q_u8(src);
const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag)));
const uint16x8_t t0 = vshlq_n_u16(equalMask, 7);
const uint32x4_t t1 = vreinterpretq_u32_u16(vsriq_n_u16(t0, t0, 14));
const uint64x2_t t2 = vreinterpretq_u64_u32(vshrq_n_u32(t1, 14));
const uint8x16_t t3 = vreinterpretq_u8_u64(vsraq_n_u64(t2, t2, 28));
const U16 hi = (U16)vgetq_lane_u8(t3, 8);
const U16 lo = (U16)vgetq_lane_u8(t3, 0);
return ZSTD_VecMask_rotateRight((hi << 8) | lo, head, 16);
} else {
const uint16x8x2_t chunk = vld2q_u16((const U16*)(const void*)src);
const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]);
const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]);
const uint8x16_t equalMask0 = vceqq_u8(chunk0, vdupq_n_u8(tag));
const uint8x16_t equalMask1 = vceqq_u8(chunk1, vdupq_n_u8(tag));
const int8x8_t pack0 = vqmovn_s16(vreinterpretq_s16_u8(equalMask0));
const int8x8_t pack1 = vqmovn_s16(vreinterpretq_s16_u8(equalMask1));
const uint8x8_t t0 = vreinterpret_u8_s8(pack0);
const uint8x8_t t1 = vreinterpret_u8_s8(pack1);
const uint8x8_t t2 = vsri_n_u8(t1, t0, 2);
const uint8x8x2_t t3 = vuzp_u8(t2, t0);
const uint8x8_t t4 = vsri_n_u8(t3.val[1], t3.val[0], 4);
const U32 matches = vget_lane_u32(vreinterpret_u32_u8(t4), 0);
assert(rowEntries == 32);
return ZSTD_VecMask_rotateRight(matches, head, 32);
}
}
# endif
{ /* SWAR */
const size_t chunkSize = sizeof(size_t);
const size_t shiftAmount = ((chunkSize * 8) - chunkSize);
const size_t xFF = ~((size_t)0);
const size_t x01 = xFF / 0xFF;
const size_t x80 = x01 << 7;
const size_t splatChar = tag * x01;
size_t matches = 0;
int i = rowEntries - chunkSize;
assert((sizeof(size_t) == 8) || (sizeof(size_t) == 4));
assert((rowEntries == 32) || (rowEntries == 16));
if (MEM_isLittleEndian()) { /* runtime check so have two loops */
const size_t extractMagic = (xFF / 0x7F) >> chunkSize;
do {
size_t chunk = MEM_readST(&src[i]);
chunk ^= splatChar;
chunk = (((chunk | x80) - x01) | chunk) & x80;
matches <<= chunkSize;
matches |= (chunk * extractMagic) >> shiftAmount;
i -= chunkSize;
} while (i >= 0);
} else { /* big endian: reverse bits during extraction */
const size_t msb = xFF ^ (xFF >> 1);
const size_t extractMagic = (msb / 0x1FF) | msb;
do {
size_t chunk = MEM_readST(&src[i]);
chunk ^= splatChar;
chunk = (((chunk | x80) - x01) | chunk) & x80;
matches <<= chunkSize;
matches |= ((chunk >> 7) * extractMagic) >> shiftAmount;
i -= chunkSize;
} while (i >= 0);
}
matches = ~matches;
if (rowEntries == 16) {
return ZSTD_VecMask_rotateRight((U16)matches, head, 16);
} else {
assert(rowEntries == 32);
return ZSTD_VecMask_rotateRight((U32)matches, head, 32);
}
}
#endif
}

/* The high-level approach of the SIMD row based match finder is as follows:
Expand Down