Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions mpl/_mpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch

def compute_weights(losses, indices, weights, ratio, p):
size = losses.size(0)

# find first nonzero element
pos = 0
while losses[pos]< 1e-5:
pos += 1
n = size - pos
m = int(ratio * n)
if n <= 0 or m <= 0:
raise ValueError
q = p / (p - 1.0)
c = m - n + 1
a = [0.0 , 0.0]
i = pos
nu = 0.0
while i < n and nu < 1e-5:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, the main reason for implement compute_weight in C++ is that python's loops are too slow and this loop will be a too slow for large batch size. Can you provide performance report which show that python code isn't slower than c++ implementation?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. However i haven't compared yet. Thanks for your patience again.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cann't run build.py. So i try the python version , but it doesn't work. Have you succeed?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@01lin currently this implementation is not compatible with pytorch 0.5+

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@belbes With pytorch 0.4.0 and python3, I also cann't run build.py. And the error is " cffi.error.VerificationError: CompileError: command 'gcc' failed with exit status 1". Can you help me?

loss_q = (losses[i] / losses[size - 1]) ** q
a[0] = a[1]
a[1] += loss_q
c += 1
nu = c * loss_q - a[1]

# compute alpha
if nu < 1e-5:
i += 1
c += 1
a[0] = a[1]
alpha = (a[0] / c) ** (1 / q) * losses[size - 1]

# compute_weights
tau = 1.0 / (n ** (1.0 / q)*(m **(1.0 / p)))
k = i
while k < n:
# maybe wrong
weights[indices[k]] = tau
k += 1
if alpha > -1e-5:
k = pos
while k < i:
weights[indices[k]] = tau * (losses[k] / alpha) ** (q - 1)
k += 1
20 changes: 0 additions & 20 deletions mpl/build.py

This file was deleted.

62 changes: 0 additions & 62 deletions mpl/src/lib_mpl.cpp

This file was deleted.

5 changes: 0 additions & 5 deletions mpl/src/lib_mpl.h

This file was deleted.