Commit e736d05
Fix UMAP outlier issue by checking for outliers and shuffling (#7131)
Closing #6454
The main difference between our simplicial set embedding and CPU UMAP was in negative sampling.
The negative sampling stage should use the updated values (the values after the gradients have been added), not the stale ones.
The optimization is dispatched to two kernels (three usage paths) based on `n_components`. Each is handled as follows:
- `optimize_batch_kernel_reg` (`n_components=2`): update the `current_reg` register value (used later in the negative sampling stage) along with `grads`
- `optimize_batch_kernel` (with shared memory): distinguish `current_buffer` (which used to hold JUST the gradient) from `grad_buffer`. Now `current_buffer` and `grad_buffer` correspond to the `current_reg` and `grads` registers in the register-approach kernel.
- `optimize_batch_kernel` (without shared memory): untouched because the grads are applied directly to global memory. This updated value in global memory is read directly for negative sampling later on.
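The effect of the fix can be sketched on the CPU for a single head vertex. This is a minimal illustration, not the kernel code from this PR: the function names, the `use_updated` flag, and the default `a`, `b`, `gamma` values are assumptions made for the example, and the gradient formulas follow the usual UMAP attractive/repulsive form with clipping to [-4, 4]. Here `pos` plays the role of `current_reg` / `current_buffer`, and `grads` plays the role of `grads` / `grad_buffer`.

```python
def clip(v, lo=-4.0, hi=4.0):
    """Clamp a gradient component, as UMAP does."""
    return max(lo, min(hi, v))


def sq_dist(p, q):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(p, q))


def attractive_grad(current, other, a, b):
    """Gradient pulling `current` towards its neighbor `other`."""
    d = sq_dist(current, other)
    if d <= 0.0:
        return [0.0] * len(current)
    coeff = (-2.0 * a * b * d ** (b - 1.0)) / (a * d ** b + 1.0)
    return [clip(coeff * (c - o)) for c, o in zip(current, other)]


def repulsive_grad(current, neg, a, b, gamma):
    """Gradient pushing `current` away from a negative sample `neg`."""
    d = sq_dist(current, neg)
    if d <= 0.0:
        return [0.0] * len(current)
    coeff = (2.0 * gamma * b) / ((0.001 + d) * (a * d ** b + 1.0))
    return [clip(coeff * (c - n)) for c, n in zip(current, neg)]


def sgd_step(current, other, negatives, alpha,
             a=1.577, b=0.895, gamma=1.0, use_updated=True):
    """One update of a head vertex.

    use_updated=False mimics the old behavior: negative sampling reads the
    position from before the attractive gradient was applied.
    use_updated=True mimics the fixed behavior: the working position is
    updated together with the accumulated gradient.
    """
    pos = list(current)            # working copy (role of current_reg)
    grads = [0.0] * len(current)   # accumulated update (role of grads)

    g = attractive_grad(pos, other, a, b)
    for k in range(len(pos)):
        grads[k] += g[k] * alpha
        if use_updated:
            pos[k] += g[k] * alpha

    for neg in negatives:
        g = repulsive_grad(pos, neg, a, b, gamma)
        for k in range(len(pos)):
            grads[k] += g[k] * alpha
            if use_updated:
                pos[k] += g[k] * alpha

    # Final write-back of the accumulated update.
    return [c + g for c, g in zip(current, grads)]


if __name__ == "__main__":
    head, tail, negs = [0.0, 0.0], [1.0, 0.0], [[0.5, 0.1]]
    print("stale:  ", sgd_step(head, tail, negs, alpha=1.0, use_updated=False))
    print("updated:", sgd_step(head, tail, negs, alpha=1.0, use_updated=True))
```

In this toy case the attractive step moves the head past the negative sample, so the repulsive push changes direction depending on whether the stale or the updated position is read. That kind of mismatch is what the fix removes.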
## Visualizations 2D
50K samples randomly selected for plotting.
From the left:
- CPU KNN + CPU UMAP
- GPU KNN + CPU UMAP
- GPU KNN + GPU UMAP Before fix
- GPU KNN + GPU UMAP After fix in this PR
Using dataset 639K x 384
<img width="2400" height="600" alt="unique_embeddings_Beauty_comparison" src="https://github.com/user-attachments/assets/2b687c82-4a2d-4288-bcaa-d95d54a1b8ae" />
Using dataset 1.8M x 384
<img width="2400" height="600" alt="unique_embeddings_Appliances_comparison" src="https://github.com/user-attachments/assets/66e94360-6a55-4d37-8851-69c00e485685" />
## Visualizations 3D
50K samples randomly selected for plotting.
Plotting the same datasets with `n_components=3` (which uses the second kernel).
From the left:
- GPU KNN + CPU UMAP
- GPU KNN + GPU UMAP Before fix
- GPU KNN + GPU UMAP After fix in this PR
Using dataset 639K x 384 (already had no visible outliers before the fix, and still looks good after it)
<img width="1905" height="666" alt="Screenshot 2025-08-25 at 1 16 37 PM" src="https://github.com/user-attachments/assets/edbfec64-ae9a-45f6-84b4-cc7e3c431884" />
Using dataset 1.8M x 384 (had outliers before the fix)
<img width="1768" height="716" alt="Screenshot 2025-08-25 at 1 22 41 PM" src="https://github.com/user-attachments/assets/cfcffc8c-0ee3-4ad8-81f3-692483fec70e" />
Authors:
- Jinsol Park (https://github.com/jinsolp)
- Dante Gama Dessavre (https://github.com/dantegd)
- Simon Adorf (https://github.com/csadorf)
Approvers:
- Victor Lafargue (https://github.com/viclafargue)
- Divye Gala (https://github.com/divyegala)
- Simon Adorf (https://github.com/csadorf)
URL: #7131
1 parent e5adc43 commit e736d05
3 files changed
Lines changed: 185 additions & 21 deletions