JITProp.java
package neureka.autograd;
import neureka.Tensor;
import neureka.common.composition.Component;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
/**
* This class keeps track of graph nodes which require
* back-propagation in order
* to be able to continue the process at a later point in time
* (based on some configurable conditions).
*
* @param <V> The type parameter of the involved tensors.
*/
public final class JITProp<V> implements Component<Tensor<V>>
{
private Set<GraphNode<V>> _finished;
private Set<GraphNode<V>> _pending;
public JITProp( Set<GraphNode<V>> pending ) {
_pending = new HashSet<>();
_pending.addAll( pending ); // Every JITProp component has their own Set.
//... otherwise this would lead to finished JIT-Propagations where in fact traversals are still pending...
}
/**
* @param pending A set of GraphNode<V> instance which are saved for future backprop continuation.
*/
public void addPending( Set<GraphNode<V>> pending ) {
if ( pending.isEmpty() )
throw new IllegalStateException("Trying to add empty pending errors set to JITProp.");
if ( !isDone() )
throw new IllegalStateException("Trying to add pending errors to JITProp which is done.");
_pending.addAll( pending );
}
/**
*
* @param finishedJITProps The reference to a GraphNote which has finished (JITed) backpropation.
*/
public void noteFinished( GraphNode<V> finishedJITProps ) {
if ( _finished == null ) _finished = new HashSet<>();
_finished.add( finishedJITProps );
if ( _pending != null ) {
Set<GraphNode<V>> intersection = _finished.stream().filter(_pending::contains).collect(Collectors.toSet());
_finished.removeAll( intersection );
_pending.removeAll( intersection );
if ( _finished.isEmpty() ) _finished = null;
if ( _pending.isEmpty() ) _pending = null;
}
}
public int finishedCount() { return _finished == null ? 0 : _finished.size(); }
public int pendingCount() { return _pending == null ? 0 : _pending.size(); }
/**
* This method triggers the continuation of the back-propagation which
* has been put on hold by saving the pending graph nodes inside this class. <br>
* The execution request happens when gradients are immediately required by a tensor,
* which is the case when the tensor is about to apply its gradients. <br>
* However because the gradient has not yet been fully calculated this method
* will be called first (assuming the tensor has a JITProp component stored).
*/
public void execute() {
if ( _pending == null ) return;
_pending.forEach( n -> {
if ( _finished == null || !_finished.contains( n ) ) {
PendingError<V> pe = n.getAndRemovePendingError();
if ( !pe.isFullyAccumulated() )
throw new IllegalStateException("Pending error has not received expected accumulation.");
n.backwardJIT( pe.getAccumulatedError() ); // Continue back-prop recursively!
}
});
if ( pendingCount() > 0 )
throw new IllegalStateException("Pending error has not received expected accumulation.");
_pending = null;
}
/**
* @return The truth value determining if the back-propagation has been completed.
*/
public boolean isDone() { return ( _finished == null && _pending == null ); }
@Override
public String toString() {
int finished = ( _finished == null ? 0 : _finished.size() );
int pending = ( _pending == null ? 0 : _pending.size() );
return this.getClass().getSimpleName()+"@"+Integer.toHexString(hashCode())+"[finished="+finished+",pending="+pending+",isDone="+isDone()+"]";
}
}