Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 185 additions & 18 deletions cs/cs_parallel/VowpalWabbitThreadedLearning.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,41 @@ namespace VW
/// <summary>
/// VW wrapper supporting multi-core learning by utilizing thread-based allreduce.
/// </summary>
/// <remarks>
/// <para>
/// This class manages multiple internal VW instances coordinated via allreduce.
/// Each instance runs on its own thread with a bounded work queue. Calls to
/// <see cref="Learn(string)"/> enqueue examples and return immediately — learning
/// happens asynchronously on one of the internal instances (chosen by
/// <see cref="VowpalWabbitSettings.ExampleDistribution"/>). This is not the same as
/// thread-safe concurrent learning; callers submit work from a single thread (or
/// coordinate externally) and the class handles parallelism internally.
/// </para>
/// <para>Typical usage:</para>
/// <code>
/// using (var vw = new VowpalWabbitThreadedLearning(settings))
/// {
/// foreach (var example in examples)
/// vw.Learn(example);
///
/// // Option A: save mid-training by flushing
/// var saveTask = vw.SaveModel("checkpoint.model");
/// await vw.Flush();
/// await saveTask;
///
/// // Option B: save after training is complete
/// await vw.Complete();
/// await vw.SaveModel("final.model");
/// }
/// </code>
/// <para>
/// Weight synchronization (allreduce) occurs automatically every
/// <see cref="VowpalWabbitSettings.ExampleCountPerRun"/> examples, or explicitly
/// via <see cref="Flush"/>. Deferred operations like <see cref="SaveModel()"/> and
/// <see cref="PerformanceStatistics"/> execute at these synchronization points, or
/// can be called directly after <see cref="Complete"/>.
/// </para>
/// </remarks>
public class VowpalWabbitThreadedLearning : IDisposable
{
/// <summary>
Expand Down Expand Up @@ -58,6 +93,11 @@ public class VowpalWabbitThreadedLearning : IDisposable
/// </summary>
private Task[] completionTasks;

/// <summary>
/// Combined task tracking whether all completion tasks have finished.
/// </summary>
private Task allCompletedTask;

/// <summary>
/// Number of examples seen so far. Used by round robin example distributor.
/// </summary>
Expand Down Expand Up @@ -147,8 +187,10 @@ public VowpalWabbitThreadedLearning(VowpalWabbitSettings settings)
// perform final AllReduce
vw.EndOfPass();

// execute synchronization actions
foreach (var syncAction in this.syncActions.RemoveAll())
// atomically drain and mark complete — allows sync actions
// (e.g. SaveModel) to be enqueued between Complete() and this
// continuation executing
foreach (var syncAction in this.syncActions.CompleteAndRemoveAll())
{
syncAction(vw);
}
Expand Down Expand Up @@ -229,6 +271,39 @@ private uint CheckEndOfPass()
return exampleCount;
}

/// <summary>
/// Forces an AllReduce synchronization and drains all pending sync actions
/// (e.g. <see cref="SaveModel()"/>, <see cref="PerformanceStatistics"/>)
/// without waiting for <see cref="VowpalWabbitSettings.ExampleCountPerRun"/> to be reached.
/// </summary>
/// <returns>
/// Task that completes once the synchronization and all pending sync actions have executed.
/// The task faults if the end-of-pass call or one of the drained sync actions throws.
/// </returns>
/// <exception cref="InvalidOperationException">
/// Presumably thrown by the sync action list when it has already been marked complete
/// (see the list's <c>Add</c> contract) - TODO confirm against <c>Add</c>.
/// </exception>
/// <remarks>
/// Continuations of the returned task are forced to run asynchronously so that code
/// following <c>await vw.Flush()</c> never executes inline on the worker that is
/// draining the sync actions.
/// NOTE(review): the completion signal is itself a sync action; if another synchronization
/// point drains the list before this flush runs, the task completes at that point - confirm
/// this is acceptable for callers.
/// </remarks>
public Task Flush()
{
    // RunContinuationsAsynchronously: do not let awaiting callers hijack the worker
    // thread that invokes the sync actions.
    var completionSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

    // Signal completion as the last of the currently pending sync actions.
    // TrySetResult keeps this idempotent should the task already be faulted.
    this.syncActions.Add(vw => completionSource.TrySetResult(true));

    this.observers[0].OnNext(vw =>
    {
        try
        {
            // perform AllReduce
            vw.EndOfPass();

            // execute synchronization actions
            foreach (var syncAction in this.syncActions.RemoveAll())
            {
                syncAction(vw);
            }
        }
        catch (Exception ex)
        {
            // Surface the failure to the caller instead of leaving the task pending
            // forever, then let the original error handling see the exception as before.
            completionSource.TrySetException(ex);
            throw;
        }
    });

    for (int i = 1; i < this.observers.Length; i++)
    {
        // perform AllReduce
        this.observers[i].OnNext(vw => vw.EndOfPass());
    }

    return completionSource.Task;
}

/// <summary>
/// Enqueues an action to be executed on one of vw instances.
/// </summary>
Expand Down Expand Up @@ -279,6 +354,12 @@ internal Task<T> Post<T>(Func<VowpalWabbit, T> func)
/// Learns from the given example.
/// </summary>
/// <param name="line">The example to learn.</param>
/// <remarks>
/// This method enqueues the example for asynchronous learning on one of the
/// internal VW instances and returns immediately. The example string is captured
/// by the work item and must not be mutated after this call. To ensure all
/// enqueued learning is complete, call <see cref="Complete"/> or <see cref="Flush"/>.
/// </remarks>
public void Learn(string line)
{
Debug.Assert(line != null);
Expand All @@ -287,9 +368,15 @@ public void Learn(string line)
}

/// <summary>
/// Learns from the given example.
/// Learns from the given multi-line example.
/// </summary>
/// <param name="lines">The multi-line example to learn.</param>
/// <remarks>
/// This method enqueues the example for asynchronous learning on one of the
/// internal VW instances and returns immediately. The lines are captured
/// by the work item and must not be mutated after this call. To ensure all
/// enqueued learning is complete, call <see cref="Complete"/> or <see cref="Flush"/>.
/// </remarks>
public void Learn(IEnumerable<string> lines)
{
Debug.Assert(lines != null);
Expand All @@ -300,69 +387,116 @@ public void Learn(IEnumerable<string> lines)
/// <summary>
/// Synchronized performance statistics.
/// </summary>
/// <remarks>The task is only completed after synchronization of all instances, triggered <see cref="VowpalWabbitSettings.ExampleCountPerRun"/> example.</remarks>
/// <remarks>
/// Can be accessed before or after <see cref="Complete"/>. If accessed after completion,
/// returns statistics directly from the root VW instance.
/// </remarks>
public Task<VowpalWabbitPerformanceStatistics> PerformanceStatistics
{
get
{
if (this.allCompletedTask != null && this.allCompletedTask.IsCompleted)
{
return Task.FromResult(this.vws[0].PerformanceStatistics);
}

var completionSource = new TaskCompletionSource<VowpalWabbitPerformanceStatistics>();

this.syncActions.Add(vw => completionSource.SetResult(vw.PerformanceStatistics));
if (!this.syncActions.TryAdd(vw => completionSource.SetResult(vw.PerformanceStatistics)))
{
return Task.FromResult(this.vws[0].PerformanceStatistics);
}

return completionSource.Task;
}
}

/// <summary>
/// Signal that no more examples are send.
/// Signals that no more examples will be submitted.
/// </summary>
/// <returns>Task completes once the learning and cleanup is done.</returns>
/// <returns>Task that completes once all enqueued examples have been learned
/// and a final allreduce synchronization has been performed.</returns>
/// <remarks>
/// After awaiting this task, the model is fully trained and methods like
/// <see cref="SaveModel()"/> and <see cref="PerformanceStatistics"/> can be
/// called synchronously (they execute immediately on the root VW instance
/// rather than being deferred to a sync point).
/// </remarks>
public Task Complete()
{
// make sure no more sync actions are added, which might otherwise never been called
this.syncActions.CompleteAdding();

foreach (var actionBlock in this.actionBlocks)
{
actionBlock.Complete();
}

return Task.WhenAll(this.completionTasks);

this.allCompletedTask = Task.WhenAll(this.completionTasks);
return this.allCompletedTask;
}

/// <summary>
/// Saves a model as part of the synchronization.
/// </summary>
/// <returns>Task compeletes once the model is saved.</returns>
/// <remarks>
/// Can be called before or after <see cref="Complete"/>. If called after completion,
/// the model is saved directly on the root VW instance. If called before, the save is
/// deferred until the next synchronization point or completion.
/// </remarks>
/// <returns>Task that completes once the model is saved.</returns>
public Task SaveModel()
{
if (this.allCompletedTask != null && this.allCompletedTask.IsCompleted)
{
this.vws[0].SaveModel();
return Task.FromResult(true);
}

var completionSource = new TaskCompletionSource<bool>();

this.syncActions.Add(vw =>
if (!this.syncActions.TryAdd(vw =>
{
vw.SaveModel();
completionSource.SetResult(true);
});
}))
{
// sync actions were already drained and marked complete
this.vws[0].SaveModel();
return Task.FromResult(true);
}

return completionSource.Task;
}

/// <summary>
/// Saves a model as part of the synchronization.
/// </summary>
/// <returns>Task compeletes once the model is saved.</returns>
/// <remarks>
/// Can be called before or after <see cref="Complete"/>. If called after completion,
/// the model is saved directly on the root VW instance. If called before, the save is
/// deferred until the next synchronization point or completion.
/// </remarks>
/// <returns>Task that completes once the model is saved.</returns>
public Task SaveModel(string filename)
{
Debug.Assert(!string.IsNullOrEmpty(filename));

if (this.allCompletedTask != null && this.allCompletedTask.IsCompleted)
{
this.vws[0].SaveModel(filename);
return Task.FromResult(true);
}

var completionSource = new TaskCompletionSource<bool>();

this.syncActions.Add(vw =>
if (!this.syncActions.TryAdd(vw =>
{
vw.SaveModel(filename);
completionSource.SetResult(true);
});
}))
{
// sync actions were already drained and marked complete
this.vws[0].SaveModel(filename);
return Task.FromResult(true);
}

return completionSource.Task;
}
Expand Down Expand Up @@ -444,6 +578,23 @@ public void Add(T item)
}
}

/// <summary>
/// Appends an item to the list unless the list has already been marked complete.
/// </summary>
/// <param name="item">The object to append.</param>
/// <returns>
/// <c>true</c> when the item was appended; <c>false</c> when the list was already
/// marked complete and the item was therefore rejected.
/// </returns>
public bool TryAdd(T item)
{
    lock (this.lockObject)
    {
        // Decide and append under the same lock so the completed flag cannot
        // change between the check and the insertion.
        bool accepted = !this.completed;
        if (accepted)
        {
            this.items.Add(item);
        }

        return accepted;
    }
}

/// <summary>
/// Marks this list as complete. Any subsequent calls to <see cref="Add"/> will trigger an <see cref="InvalidOperationException"/>.
/// </summary>
Expand All @@ -455,6 +606,22 @@ public void CompleteAdding()
}
}

/// <summary>
/// Marks the list as complete and empties it in a single locked step.
/// </summary>
/// <returns>A snapshot of the elements that were in the list when it was emptied.</returns>
public T[] CompleteAndRemoveAll()
{
    lock (this.lockObject)
    {
        // Snapshot first, then empty the backing list; the flag is flipped inside
        // the same critical section so no item can be added in between.
        T[] drained = this.items.ToArray();
        this.items.Clear();
        this.completed = true;

        return drained;
    }
}

/// <summary>
/// Removes all elements from the list.
/// </summary>
Expand Down
Loading
Loading