Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions cmd/tckglobal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,17 +303,13 @@ void run ()
if (opt.size())
stats.open_stream(opt[0][0]);

auto dwi = header_in.get_image<float>().with_direct_io(3);
ParticleGrid pgrid (dwi);

ParticleGrid pgrid (header_in);
ExternalEnergyComputer* Eext = new ExternalEnergyComputer(stats, header_in, properties);
InternalEnergyComputer* Eint = new InternalEnergyComputer(stats, pgrid);
Eint->setConnPot(cpot);
EnergySumComputer* Esum = new EnergySumComputer(stats, Eint, properties.lam_int, Eext, properties.lam_ext / ( wmscale2 * properties.weight*properties.weight));

MHSampler mhs (dwi, properties, stats, pgrid, Esum, mask); // All EnergyComputers are recursively destroyed upon destruction of mhs, except for the shared data.


MHSampler mhs (header_in, properties, stats, pgrid, Esum, mask); // All EnergyComputers are recursively destroyed upon destruction of mhs, except for the shared data.
INFO("Start MH sampler");

Thread::run (Thread::multi(mhs), "MH sampler");
Expand Down Expand Up @@ -345,7 +341,7 @@ void run ()


// Save fiso, tod and eext
Header header_out (dwi);
Header header_out (header_in);
header_out.datatype() = DataType::Float32;

opt = get_options("fod");
Expand Down
31 changes: 31 additions & 0 deletions core/spinlock.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/* Copyright (c) 2008-2025 the MRtrix3 contributors.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*
* Covered Software is provided under this License on an "as is"
* basis, without warranty of any kind, either expressed, implied, or
* statutory, including, without limitation, warranties that the
* Covered Software is free of defects, merchantable, fit for a
* particular purpose or non-infringing.
* See the Mozilla Public License v. 2.0 for more details.
*
* For more details, see http://www.mrtrix.org/.
*/

#include <atomic>
#include <mutex>

namespace MR {

class SpinLock
{ NOMEMALIGN
public:
void lock() noexcept { while (m_.test_and_set(std::memory_order_acquire)); }
void unlock() noexcept { m_.clear(std::memory_order_release); }
protected:
std::atomic_flag m_ = ATOMIC_FLAG_INIT;
};

}
2 changes: 1 addition & 1 deletion src/dwi/tractography/GT/externalenergy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace MR {
namespace Tractography {
namespace GT {

ExternalEnergyComputer::ExternalEnergyComputer(Stats& stat, Header& dwiheader, const Properties& props)
ExternalEnergyComputer::ExternalEnergyComputer(Stats& stat, Header &dwiheader, const Properties& props)
: EnergyComputer(stat),
dwi(dwiheader.get_image<float>().with_direct_io(3)),
T(Transform(dwiheader).scanner2voxel),
Expand Down
5 changes: 4 additions & 1 deletion src/dwi/tractography/GT/gt.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ namespace MR {
}

double getTint() const {
std::lock_guard<std::mutex> lock (mutex);
return Tint;
}

Expand All @@ -118,10 +119,12 @@ namespace MR {


double getEextTotal() const {
std::lock_guard<std::mutex> lock (mutex);
return EextTot;
}

double getEintTotal() const {
std::lock_guard<std::mutex> lock (mutex);
return EintTot;
}

Expand Down Expand Up @@ -197,7 +200,7 @@ namespace MR {


protected:
std::mutex mutex;
mutable std::mutex mutex;
double Text, Tint;
double EextTot, EintTot;
double alpha;
Expand Down
39 changes: 21 additions & 18 deletions src/dwi/tractography/GT/internalenergy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ namespace MR {
namespace DWI {
namespace Tractography {
namespace GT {


double InternalEnergyComputer::stageConnect(const ParticleEnd& pe1, ParticleEnd &pe2)
{
// new
Expand All @@ -37,31 +37,34 @@ namespace MR {
}
return dEint / stats.getTint();
}



void InternalEnergyComputer::scanNeighbourhood(const Particle* p, const int alpha0, const double currTemp)
{
neighbourhood.resize(1);
normalization = 1.0;

Point_t ep = p->getEndPoint(alpha0);
if (pGrid.isoutofbounds(ep))
return;
size_t x, y, z;
pGrid.pos2xyz(ep, x, y, z);

float tolerance2 = Particle::L * Particle::L; // distance threshold (particle length), hard coded
float costheta = Math::sqrt1_2; // angular threshold (45 degrees), hard coded
ParticleEnd pe;
float d1, d2, d, ct;

for (int i = -1; i <= 1; i++) {
for (int j = -1; j <= 1; j++) {
for (int k = -1; k <= 1; k++) {
const ParticleGrid::ParticleVectorType* pvec = pGrid.at(x+i, y+j, z+k);
if (pvec == NULL)
const ParticleGrid::ParticleContainer* pvec = pGrid.at(x+i, y+j, z+k);
if (!pvec)
continue;

for (ParticleGrid::ParticleVectorType::const_iterator it = pvec->begin(); it != pvec->end(); ++it)

std::lock_guard<std::mutex> lock (pvec->mutex);
for (ParticleGrid::ParticleContainer::const_iterator it = pvec->begin(); it != pvec->end(); ++it)
{
pe.par = *it;
if (pe.par == p)
Expand All @@ -81,14 +84,14 @@ namespace MR {
neighbourhood.push_back(pe);
}
}

}
}
}

}


ParticleEnd InternalEnergyComputer::pickNeighbour()
{
double sum = 0.0;
Expand All @@ -103,8 +106,8 @@ namespace MR {
}
return pe;
}


}
}
}
Expand Down
10 changes: 5 additions & 5 deletions src/dwi/tractography/GT/mhsampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ namespace MR {

void MHSampler::birth()
{
//TRACE;
//std::cerr << 'b';
stats.incN('b');

Point_t pos;
Expand All @@ -85,7 +85,7 @@ namespace MR {

void MHSampler::death()
{
//TRACE;
//std::cerr << 'd';
stats.incN('d');

Particle* par;
Expand All @@ -111,7 +111,7 @@ namespace MR {

void MHSampler::randshift()
{
//TRACE;
//std::cerr << 'r';
stats.incN('r');

Particle* par;
Expand Down Expand Up @@ -143,7 +143,7 @@ namespace MR {

void MHSampler::optshift()
{
//TRACE;
//std::cerr << 'o';
stats.incN('o');

Particle* par;
Expand Down Expand Up @@ -176,7 +176,7 @@ namespace MR {

void MHSampler::connect() // TODO Current implementation does not prevent loops.
{
//TRACE;
//std::cerr << 'c';
stats.incN('c');

Particle* par;
Expand Down
53 changes: 27 additions & 26 deletions src/dwi/tractography/GT/mhsampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef __gt_mhsampler_h__
#define __gt_mhsampler_h__

#include "header.h"
#include "image.h"
#include "transform.h"

Expand All @@ -40,77 +41,77 @@ namespace MR {
class MHSampler
{ MEMALIGN(MHSampler)
public:
MHSampler(const Image<float>& dwi, Properties &p, Stats &s, ParticleGrid &pgrid,
MHSampler(const Header& dwiheader, Properties &p, Stats &s, ParticleGrid &pgrid,
EnergyComputer* e, Image<bool>& m)
: props(p), stats(s), pGrid(pgrid), E(e), T(dwi),
dims{size_t(dwi.size(0)), size_t(dwi.size(1)), size_t(dwi.size(2))},
mask(m), lock(make_shared<SpatialLock<float>>(5*Particle::L)),
: props(p), stats(s), pGrid(pgrid), E(e), T(dwiheader),
dims{size_t(dwiheader.size(0)), size_t(dwiheader.size(1)), size_t(dwiheader.size(2))},
mask(m), lock(make_shared<SpatialLock<float>>(std::max(5*Particle::L, float(2*pGrid.spacing())))),
sigpos(Particle::L / 8.), sigdir(0.2)
{
DEBUG("Initialise Metropolis Hastings sampler.");
}

MHSampler(const MHSampler& other)
: props(other.props), stats(other.stats), pGrid(other.pGrid), E(other.E->clone()),
: props(other.props), stats(other.stats), pGrid(other.pGrid), E(other.E->clone()),
T(other.T), dims(other.dims), mask(other.mask), lock(other.lock), rng_uniform(), rng_normal(), sigpos(other.sigpos), sigdir(other.sigdir)
{
DEBUG("Copy Metropolis Hastings sampler.");
}

~MHSampler() { delete E; }

void execute();

void next();

void birth();
void death();
void randshift();
void optshift();
void connect();


protected:

Properties& props;
Stats& stats;
ParticleGrid& pGrid;
EnergyComputer* E; // Polymorphic copy requires call to EnergyComputer::clone(), hence references or smart pointers won't do.

Transform T;
vector<size_t> dims;
Image<bool> mask;

std::shared_ptr< SpatialLock<float> > lock;
Math::RNG::Uniform<float> rng_uniform;
Math::RNG::Normal<float> rng_normal;
float sigpos, sigdir;


Point_t getRandPosInMask();

bool inMask(const Point_t p);

Point_t getRandDir();

void moveRandom(const Particle* par, Point_t& pos, Point_t& dir);

bool moveOptimal(const Particle* par, Point_t& pos, Point_t& dir) const;

inline double calcShiftProb(const Particle* par, const Point_t& pos, const Point_t& dir) const
{
Point_t Dpos = par->getPosition() - pos;
Point_t Ddir = par->getDirection() - dir;
return gaussian_pdf(Dpos, sigpos) * gaussian_pdf(Ddir, sigdir);
}

inline double gaussian_pdf(const Point_t& x, double sigma) const {
return std::exp( -x.squaredNorm() / (2*sigma) ) / std::sqrt( 2*Math::pi * sigma*sigma);
}


};


}
}
Expand Down
Loading
Loading