1 /* strstr optimized with 512-bit AVX-512 instructions
2 Copyright (C) 2022 Free Software Foundation, Inc.
3 This file is part of the GNU C Library.
5 The GNU C Library is free software; you can redistribute it and/or
6 modify it under the terms of the GNU Lesser General Public
7 License as published by the Free Software Foundation; either
8 version 2.1 of the License, or (at your option) any later version.
10 The GNU C Library is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13 Lesser General Public License for more details.
15 You should have received a copy of the GNU Lesser General Public
16 License along with the GNU C Library; if not, see
17 <https://www.gnu.org/licenses/>. */
#include <immintrin.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
/* Every lane of a __mmask64 selected.  */
#define FULL_MMASK64 0xffffffffffffffff
/* Unsigned 64-bit one, used for the lowest-set-bit mask trick
   X ^ (X - ONE_64BIT).  */
#define ONE_64BIT 0x1ull
/* Width of a ZMM register in bytes.  */
#define ZMM_SIZE_IN_BYTES 64

/* Granularity of the page-crossing checks: an unmasked 64-byte load is only
   issued when it cannot cross into the next 4 KiB page (the x86-64 base page
   size).  Guarded so an existing definition takes precedence.  */
#ifndef PAGESIZE
# define PAGESIZE 4096
#endif

/* Plain 64-bit integer equivalents of the AVX-512 mask operations whose
   names they mirror (_cvtmask64_u64, _kshiftri_mask64, _kand_mask64).
   Every expansion is fully parenthesized.  Without the outer parentheses,
   cvtmask64_u64 would expand to a bare cast-expression, and uses such as
   `sizeof cvtmask64_u64 (x)` would mis-parse.  */
#define cvtmask64_u64(...) ((uint64_t) (__VA_ARGS__))
#define kshiftri_mask64(x, y) ((x) >> (y))
#define kand_mask64(x, y) ((x) & (y))
/*
 Returns the index of the first edge within the needle, i.e. the first
 position IND for which NED[IND] != NED[IND + 1].  Returns 0 if no edge is
 found (every character identical).  Index 0 is also returned when the edge
 itself is at position 0.  In both cases NED[0], NED[1] remains a valid
 adjacent pair to search the haystack for.
 Example: 'ab' is the first edge in 'aaaaaaaaaabaarddg' (index 9 is
 returned).
 An empty needle returns 0 without reading past its terminator.
 */
static inline size_t
find_edge_in_needle (const char *ned)
{
  /* Guard: the loop below inspects ned[1], which lies past the
     terminator of an empty string.  */
  if (ned[0] == '\0')
    return 0;

  for (size_t ind = 0; ned[ind + 1] != '\0'; ++ind)
    {
      if (ned[ind] != ned[ind + 1])
        return ind;
    }
  return 0;
}
/*
 Compare needle with haystack byte by byte at specified location.

 Checks NED[IND], NED[IND + 1], ... against HAY[HAY_INDEX + IND], ...
 starting at offset IND.  Callers pass 0 for a full comparison, or
 ZMM_SIZE_IN_BYTES when the first 64 bytes were already compared by vector
 code.  Returns true when every remaining needle byte matches, and false at
 the first mismatch.  If the haystack ends early, its NUL terminator
 mismatches a non-NUL needle byte, so the haystack is never read past its
 terminator.
 */
static inline bool
verify_string_match (const char *hay, const size_t hay_index, const char *ned,
                     size_t ind)
{
  for (; ned[ind] != '\0'; ++ind)
    {
      if (ned[ind] != hay[hay_index + ind])
        return false;
    }
  return true;
}
68 Compare needle with haystack at specified location. The first 64 bytes are
69 compared using a ZMM register.
72 verify_string_match_avx512 (const char *hay
, const size_t hay_index
,
73 const char *ned
, const __mmask64 ned_mask
,
74 const __m512i ned_zmm
)
76 /* check first 64 bytes using zmm and then scalar */
77 __m512i hay_zmm
= _mm512_loadu_si512 (hay
+ hay_index
); // safe to do so
78 __mmask64 match
= _mm512_mask_cmpneq_epi8_mask (ned_mask
, hay_zmm
, ned_zmm
);
79 if (match
!= 0x0) // failed the first few chars
81 else if (ned_mask
== FULL_MMASK64
)
82 return verify_string_match (hay
, hay_index
, ned
, ZMM_SIZE_IN_BYTES
);
87 __strstr_avx512 (const char *haystack
, const char *ned
)
91 return (char *)haystack
;
93 return (char *)strchr (haystack
, ned
[0]);
95 size_t edge
= find_edge_in_needle (ned
);
97 /* ensure haystack is as long as the pos of edge in needle */
98 for (int ii
= 0; ii
< edge
; ++ii
)
100 if (haystack
[ii
] == '\0')
105 Load 64 bytes of the needle and save it to a zmm register
106 Read one cache line at a time to avoid loading across a page boundary
108 __mmask64 ned_load_mask
= _bzhi_u64 (
109 FULL_MMASK64
, 64 - ((uintptr_t) (ned
) & 63));
110 __m512i ned_zmm
= _mm512_maskz_loadu_epi8 (ned_load_mask
, ned
);
111 __mmask64 ned_nullmask
112 = _mm512_mask_testn_epi8_mask (ned_load_mask
, ned_zmm
, ned_zmm
);
114 if (__glibc_unlikely (ned_nullmask
== 0x0))
116 ned_zmm
= _mm512_loadu_si512 (ned
);
117 ned_nullmask
= _mm512_testn_epi8_mask (ned_zmm
, ned_zmm
);
118 ned_load_mask
= ned_nullmask
^ (ned_nullmask
- ONE_64BIT
);
119 if (ned_nullmask
!= 0x0)
120 ned_load_mask
= ned_load_mask
>> 1;
124 ned_load_mask
= ned_nullmask
^ (ned_nullmask
- ONE_64BIT
);
125 ned_load_mask
= ned_load_mask
>> 1;
127 const __m512i ned0
= _mm512_set1_epi8 (ned
[edge
]);
128 const __m512i ned1
= _mm512_set1_epi8 (ned
[edge
+ 1]);
131 Read the bytes of haystack in the current cache line
133 size_t hay_index
= edge
;
134 __mmask64 loadmask
= _bzhi_u64 (
135 FULL_MMASK64
, 64 - ((uintptr_t) (haystack
+ hay_index
) & 63));
136 /* First load is a partial cache line */
137 __m512i hay0
= _mm512_maskz_loadu_epi8 (loadmask
, haystack
+ hay_index
);
138 /* Search for NULL and compare only till null char */
140 = cvtmask64_u64 (_mm512_mask_testn_epi8_mask (loadmask
, hay0
, hay0
));
141 uint64_t cmpmask
= nullmask
^ (nullmask
- ONE_64BIT
);
142 cmpmask
= cmpmask
& cvtmask64_u64 (loadmask
);
143 /* Search for the 2 charaters of needle */
144 __mmask64 k0
= _mm512_cmpeq_epi8_mask (hay0
, ned0
);
145 __mmask64 k1
= _mm512_cmpeq_epi8_mask (hay0
, ned1
);
146 k1
= kshiftri_mask64 (k1
, 1);
147 /* k2 masks tell us if both chars from needle match */
148 uint64_t k2
= cvtmask64_u64 (kand_mask64 (k0
, k1
)) & cmpmask
;
149 /* For every match, search for the entire needle for a full match */
152 uint64_t bitcount
= _tzcnt_u64 (k2
);
154 size_t match_pos
= hay_index
+ bitcount
- edge
;
155 if (((uintptr_t) (haystack
+ match_pos
) & (PAGESIZE
- 1))
156 < PAGESIZE
- 1 - ZMM_SIZE_IN_BYTES
)
159 * Use vector compare as long as you are not crossing a page
161 if (verify_string_match_avx512 (haystack
, match_pos
, ned
,
162 ned_load_mask
, ned_zmm
))
163 return (char *)haystack
+ match_pos
;
167 if (verify_string_match (haystack
, match_pos
, ned
, 0))
168 return (char *)haystack
+ match_pos
;
171 /* We haven't checked for potential match at the last char yet */
172 haystack
= (const char *)(((uintptr_t) (haystack
+ hay_index
) | 63));
176 Loop over one cache line at a time to prevent reading over page
180 while (nullmask
== 0)
182 hay0
= _mm512_loadu_si512 (haystack
+ hay_index
);
183 hay1
= _mm512_load_si512 (haystack
+ hay_index
184 + 1); // Always 64 byte aligned
185 nullmask
= cvtmask64_u64 (_mm512_testn_epi8_mask (hay1
, hay1
));
186 /* Compare only till null char */
187 cmpmask
= nullmask
^ (nullmask
- ONE_64BIT
);
188 k0
= _mm512_cmpeq_epi8_mask (hay0
, ned0
);
189 k1
= _mm512_cmpeq_epi8_mask (hay1
, ned1
);
190 /* k2 masks tell us if both chars from needle match */
191 k2
= cvtmask64_u64 (kand_mask64 (k0
, k1
)) & cmpmask
;
192 /* For every match, compare full strings for potential match */
195 uint64_t bitcount
= _tzcnt_u64 (k2
);
197 size_t match_pos
= hay_index
+ bitcount
- edge
;
198 if (((uintptr_t) (haystack
+ match_pos
) & (PAGESIZE
- 1))
199 < PAGESIZE
- 1 - ZMM_SIZE_IN_BYTES
)
202 * Use vector compare as long as you are not crossing a page
204 if (verify_string_match_avx512 (haystack
, match_pos
, ned
,
205 ned_load_mask
, ned_zmm
))
206 return (char *)haystack
+ match_pos
;
210 /* Compare byte by byte */
211 if (verify_string_match (haystack
, match_pos
, ned
, 0))
212 return (char *)haystack
+ match_pos
;
215 hay_index
+= ZMM_SIZE_IN_BYTES
;