BackPropTargetCollector.java

package neureka.autograd;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.stream.Collectors;

/**
 * Groups {@link ADAction} agents by the {@link GraphNode} they target, so that each target
 * appears exactly once in {@link #getTargets()} together with every agent registered for it.
 * <p>
 * Targets are distinguished by {@link GraphNode}'s {@code equals}/{@code hashCode}, not by
 * hash code alone, so two distinct nodes that happen to share a hash code are never merged.
 * This class performs no synchronization.
 *
 * @param <V> the value type of the graph nodes being collected.
 */
class BackPropTargetCollector<V> {

    /**
     * Target node to its recorded index and agents. Created lazily on the first {@link #put} call.
     * Insertion-ordered so that hash-code ties in {@link #getTargets()} keep registration order.
     */
    private Map<GraphNode<V>, Value> _targetsToAgents;

    /**
     * Registers {@code agent} for {@code target}.
     * <p>
     * On the first registration of a target, a new entry is created holding {@code index}.
     * On later registrations of the same target, {@code agent} is appended to that entry's
     * agent list and the supplied {@code index} is ignored.
     *
     * @param index  value stored with the target on its first registration
     *               (NOTE(review): presumably the position of the target among the inputs — confirm).
     * @param target graph node which contains tensors requiring errors for accumulation and/or
     *               further targets; must not be {@code null}.
     * @param agent  action used during back-propagation to distribute an error throughout the graph.
     * @throws NullPointerException if {@code target} is {@code null}.
     */
    public void put( int index, GraphNode<V> target, ADAction agent ) {
        Objects.requireNonNull( target, "target" );
        if ( _targetsToAgents == null ) _targetsToAgents = new LinkedHashMap<>();
        // Only the first registration of a target decides its index; later ones just add agents.
        _targetsToAgents.computeIfAbsent( target, t -> new Value( index ) ).agents().add( agent );
    }

    /**
     * Builds one {@link BackPropTargets} per registered target.
     * <p>
     * Each entry carries the index recorded on the target's first registration and all agents
     * registered for it. Entries are ordered by the targets' hash codes, and ties keep registration
     * order. The agent lists are handed to {@link BackPropTargets} without copying.
     *
     * @return the collected targets, or an empty list if {@link #put} has never been called.
     */
    public List<BackPropTargets<V>> getTargets() {
        if ( _targetsToAgents == null ) return Collections.emptyList();
        return _targetsToAgents.entrySet()
                .stream()
                // Stream.sorted is stable on ordered streams; comparingInt avoids subtraction overflow.
                .sorted( Comparator.comparingInt( e -> e.getKey().hashCode() ) )
                .map( e -> new BackPropTargets<>( e.getValue().index(), e.getKey(), e.getValue().agents() ) )
                .collect( Collectors.toList() );
    }


    /** The index recorded on a target's first registration plus every agent registered for it. */
    private static class Value {
        private final int _index;
        private final List<ADAction> _agents = new ArrayList<>();

        private Value( int index ) {
            _index = index;
        }

        public int index() { return _index; }

        public List<ADAction> agents() { return _agents; }
    }

}