FunctionParser.java

package neureka.math.parsing;

import neureka.backend.api.BackendContext;
import neureka.backend.api.Operation;
import neureka.math.Function;
import neureka.math.implementations.FunctionConstant;
import neureka.math.implementations.FunctionInput;
import neureka.math.implementations.FunctionNode;
import neureka.math.implementations.FunctionVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 *  The {@link FunctionParser} takes a {@link BackendContext} instance based on which
 *  it builds {@link Function} implementation instances, usually by parsing {@link String}s.
 *  The information needed for parsing (operator symbols, identifiers and arities) is read
 *  from the {@link Operation}s registered in that {@link BackendContext}.
 */
public class FunctionParser
{
    private static final Logger _LOG = LoggerFactory.getLogger(FunctionParser.class);

    // Leaf patterns, tried in the order constant -> input -> variable by _buildFunction:
    private static final Pattern _variablePattern = Pattern.compile("^(-?[iI]{1}[g]?\\[?[ ]*[g]?[jJ]+[ ]*\\]?)");
    private static final Pattern _inputPattern = Pattern.compile("^(-?[iI]{1}[g]?\\[?[ ]*[g]?[0-9]+[ ]*\\]?)");
    private static final Pattern _constantPattern = Pattern.compile("^((-?[0-9]*|[0-9]*)[.]?[0-9]*((e|E)[-]?[0-9]+)?)");

    // Used by _buildOperators to detect the "[..]:(..)" bracket syntax:
    private static final Pattern _permutePattern = Pattern.compile("^(\\[{1}(.,)*(.)+[,]?\\]{1}:?((\\({1}[.]*\\){1})|(.+)))");
    // NOTE(review): not referenced anywhere in this class - kept in case it is still needed.
    private static final Pattern _nodePattern = Pattern.compile("^([\\(]{1}.+[\\)]{1})");

    private final BackendContext _context;

    /**
     * @param context The {@link BackendContext} which will be used as a basis to parse new {@link Function}
     *                implementation instances from provided {@link String} expressions.
     */
    public FunctionParser( BackendContext context ) { _context = context; }

    /**
     *  Builds the canonical expression for the supplied {@link Operation}
     *  (for example {@code "id(I[0], I[1])"}, or {@code "id( I[j] )"} for indexer operations)
     *  and delegates to {@link #parse(String, boolean)}.
     *
     * @param operation The {@link Operation} based on which the {@link Function} ought to be created.
     * @param numberOfArgs The number of arguments the produced {@link Function} ought to have
     *                     (ignored for indexer operations).
     * @param doAD The flag determining if the {@link Function} built by this method should perform autograd or not.
     * @return A {@link Function} implementation instance which satisfies the supplied parameters.
     */
    public Function parse( Operation operation, int numberOfArgs, boolean doAD )
    {
        if ( operation.isIndexer() )
            return parse( operation.getIdentifier() + "( I[j] )", doAD );

        String args = IntStream.iterate( 0, n -> n + 1 )
                                .limit( numberOfArgs )
                                .mapToObj( i -> "I[" + i + "]" )
                                .collect( Collectors.joining( ", " ) );

        // A function always has to be parsable:
        return parse( operation.getIdentifier() + "(" + args + ")", doAD );
    }

    /**
     *  Parses the supplied expression into a {@link Function}.
     *  A non-empty expression which is not already enclosed in parentheses gets wrapped in them,
     *  then the function cache of the context is consulted before any parsing happens.
     *  A successfully built function is put into that cache; a failed build is logged as an error.
     *
     * @param expression contains the function as String provided by the user
     * @param doAD       is used to turn autograd on or off for this function
     * @return the function which has been built from the expression,
     *         or {@code null} if no function could be built
     */
    public Function parse( String expression, boolean doAD )
    {
        if (
            expression.length() > 0 &&
            (expression.charAt( 0 ) != '(' || expression.charAt( expression.length() - 1 ) != ')')
        )
            expression = ("(" + expression + ")");

        if ( _context.getFunctionCache().has( expression, doAD ) )
            return _context.getFunctionCache().get( expression, doAD );

        expression = ParseUtil.unpackAndCorrect( expression );
        Function built = _parse( expression, doAD );
        if ( built != null )
            _context.getFunctionCache().put( built );
        else
            _LOG.error( "Failed to parse function based on expression '{}' and autograd flag '{}'.", expression, doAD );
        return built;
    }

    /**
     *  The actual parsing routine. It works in three phases:
     *  <ol>
     *      <li>The expression is scanned left to right and split into components and
     *          the operators found between them.</li>
     *      <li>For every operation of the context (visited by ascending index) whose operator
     *          occurs in the expression, neighbouring components joined by that operator are
     *          merged into a single component via {@code ParseUtil.groupBy}.
     *          NOTE(review): presumably the operation index therefore encodes binding strength - confirm.</li>
     *      <li>A single remaining component is handed to {@link #_buildFunction}, anything else
     *          to {@link #_buildOperators}.</li>
     *  </ol>
     *
     * @param expression is a blueprint String for the function builder
     * @param doAD       enables or disables autograd for this function
     * @return a function which has been built by the given expression (may be {@code null})
     */
    private Function _parse( String expression, boolean doAD )
    {
        // TODO: Remove this! It's error prone! (Operations should define parsing to some extent)
        expression = expression
                .replace("<<", "" + ((char) 171))
                .replace(">>", "" + ((char) 187));
        expression = expression
                .replace("<-", "<")
                .replace("->", ">");

        if ( expression.isEmpty() )
            return new FunctionConstant("0");

        expression = ParseUtil.unpackAndCorrect(expression);
        List<String> foundOperations = new ArrayList<>();
        List<String> foundComponents = new ArrayList<>();

        // Phase 1: split the expression into components and operators.
        for ( int ei = 0; ei < expression.length(); ) {
            final String newComponent = ParseUtil.findComponentIn( expression, ei );
            if ( newComponent == null ) {
                ++ei; // Parsing failed for this index so let's try the next one!
                continue;
            }
            if ( newComponent.trim().isEmpty() ) {
                // Blank strings are not components and will be skipped.
                // Always advance by at least one character, otherwise a zero-length
                // result would keep the scan on the same index forever:
                ei += Math.max( 1, newComponent.length() );
                continue;
            }
            // String has content so lets add it, but only if a component is actually expected here:
            if ( foundComponents.size() <= foundOperations.size() )
                foundComponents.add( newComponent );

            ei += newComponent.length(); // And now we continue parsing where the string ends...
            // After a component however, we expect an operator:
            final String newOperation = ParseUtil.parsedOperation( expression, ei );
            if ( newOperation != null && !newOperation.isEmpty() ) {
                ei += newOperation.length();
                foundOperations.add( newOperation );
            }
        }
        //---

        // Phase 2: group components, operation by operation.
        // 'counter' ends up one past the highest operation index whose operator was found (0 if none):
        int counter = _context.size();
        while ( counter > 0 && !foundOperations.contains( _context.getOperation( counter - 1 ).getOperator() ) )
            --counter;

        for ( int operationID = 0; operationID < counter; operationID++ ) {
            final String operator = _context.getOperation( operationID ).getOperator();
            if ( !foundOperations.contains( operator ) )
                continue;
            // Grouping only makes sense with enough operations. Otherwise: I[j]**4 goes nuts!
            if ( ParseUtil.numberOfOperationsWithin( foundOperations ) <= 1 )
                continue;

            final List<String> newJunctors = new ArrayList<>();
            final List<String> newComponents = new ArrayList<>();
            String currentChain = null; // The components already merged by the current operator...
            boolean groupingOccurred = false;

            for ( int ci = 0; ci < foundComponents.size(); ci++ ) {
                final String currentComponent = foundComponents.get( ci );
                final String currentOperation = ( ci < foundOperations.size() ) ? foundOperations.get( ci ) : null;

                if ( currentOperation == null ) {
                    // No operator follows this component, so it closes a pending chain (if there is one):
                    if ( currentChain == null )
                        newComponents.add( currentComponent );
                    else {
                        newComponents.add( currentChain + currentComponent );
                        groupingOccurred = true;
                    }
                    currentChain = null;
                }
                else if ( currentOperation.equals( operator ) ) {
                    // The component is followed by the operator we group by: extend the chain!
                    final String newChain = ParseUtil.groupBy( operator, currentChain, currentComponent, currentOperation );
                    if ( newChain != null ) currentChain = newChain;
                    groupingOccurred = true;
                }
                else {
                    // A different operator follows: close the chain and keep the operator as junctor.
                    if ( currentChain == null ) newComponents.add( currentComponent );
                    else newComponents.add( currentChain + currentComponent );
                    newJunctors.add( currentOperation );
                    groupingOccurred = true;
                    currentChain = null;
                }
            }
            if ( groupingOccurred ) {
                foundOperations = newJunctors;
                foundComponents = newComponents;
            }
        }

        // Phase 3: building sources and function:
        if ( foundComponents.size() == 1 )
            return _buildFunction( foundComponents.get(0), doAD );

        // It's not a function but operators:
        return _buildOperators( foundComponents, foundOperations, doAD );
    }

    /**
     *  Builds a {@link Function} from a single component. The following is tried in order:
     *  <ol>
     *      <li>A call of an operation known to the context (matched by identifier, ignoring case),
     *          whose parameters are parsed recursively.</li>
     *      <li>A constant, an input or a variable leaf (see the patterns of this class).</li>
     *      <li>A leading minus sign, which is rewritten to {@code "-1 * ..."} and parsed again.</li>
     *      <li>A cleaned up version of the component, which is parsed again.</li>
     *  </ol>
     *
     * @param foundComponent The component which ought to be turned into a {@link Function}.
     * @param doAD The autograd flag handed to every {@link Function} created here.
     * @return The {@link Function} built from the component.
     * @throws IllegalStateException If re-parsing the cleaned up component fails with an exception.
     */
    private Function _buildFunction( String foundComponent, boolean doAD ) {
        String possibleFunction = ParseUtil.parsedOperation(
                foundComponent,
                0
        );
        if ( possibleFunction != null && possibleFunction.length() > 1 ) {
            for ( int oi = 0; oi < _context.size(); oi++ ) {
                if ( _context.getOperation( oi ).getIdentifier().equalsIgnoreCase( possibleFunction ) ) {
                    List<String> parameters = ParseUtil.findParametersIn(
                                                                        foundComponent,
                                                                        possibleFunction.length()
                                                                );
                    assert parameters != null;
                    ArrayList<Function> sources = new ArrayList<>();
                    for ( String p : parameters ) sources.add( parse( p, doAD ) );
                    return new FunctionNode( _context.getOperation( oi ), sources, doAD );
                }
            }
        }
        //---
        String component = ParseUtil.unpackAndCorrect( foundComponent );

        if ( _constantPattern.matcher( component ).matches()   ) return new FunctionConstant( component );
        else if ( _inputPattern.matcher( component ).find()    ) return FunctionInput.of( component, doAD );
        else if ( _variablePattern.matcher( component ).find() ) return new FunctionVariable( component );
        else if ( component.startsWith("-") ) {
            component = "-1 * "+component.substring(1);
            return _parse(component, doAD);
        }
        // If the component did not trigger constant/input/variable creation: -> Cleaning!
        String cleaned = ParseUtil.cleanedHeadAndTail( component );
        String raw = component.replace( cleaned, "" );
        String assumed = ParseUtil.assumptionBasedOn( raw );
        if ( assumed.trim().equals("") ) component = cleaned;
        else component = assumed + cleaned;
        // Let's try again:
        try {
            return parse( component, doAD );
        } catch ( Exception e ) {
            // Report the underlying cause if there is one, otherwise the exception itself,
            // and always chain the caught exception so that its stack trace is not lost:
            Throwable reported = ( e.getCause() != null ) ? e.getCause() : e;
            throw new IllegalStateException( "Failed to parse expression '"+component+"'! Cause: "+reported, e );
        }
    }

    /**
     *  Builds a {@link FunctionNode} from multiple components which are connected by operators.
     *  The operation of the node is the one whose operator equals the first found operator
     *  (the operation at index 0 is used if there is no such operation).
     *  Every component is parsed recursively into a source of the node.
     *
     * @param foundComponents The components of the expression, this list may be modified by this method.
     * @param foundOperators The operators found between the components.
     * @param doAD The autograd flag handed to every {@link Function} created here.
     * @return The only source if there is just one, {@code null} if there is none,
     *         otherwise a {@link FunctionNode} wrapping all non-null sources.
     */
    private Function _buildOperators(
            List<String> foundComponents,
            List<String> foundOperators,
            boolean doAD
    ) {
        // identifying operator id:
        int operationIndex = 0;
        if ( !foundOperators.isEmpty() ) {
            final String firstOperator = foundOperators.get( 0 );
            for ( int currentIndex = 0; currentIndex < _context.size(); ++currentIndex ) {
                // Deliberately no break: if several operations share the operator, the last one wins.
                if ( _context.getOperation( currentIndex ).getOperator().equals( firstOperator ) )
                    operationIndex = currentIndex;
            }
        }
        final Operation operation = _context.getOperation( operationIndex );
        final String asString = String.join( operation.getOperator(), foundComponents );

        // More than one component left:
        if ( operation.getArity() > 1 ) {
            foundComponents = _groupAccordingToArity(
                                        operation.getArity(),
                                        foundComponents,
                                        operation.getOperator()
                                );
        } else if ( _permutePattern.matcher( asString ).matches() ) {
            // Drop the first character of the first component (the pattern requires it to be '['):
            foundComponents.set( 0, foundComponents.get( 0 ).substring( 1 ) );
            final int lastIndex = foundComponents.size() - 1;
            final String last = foundComponents.get( lastIndex );
            if ( last.contains("]") ) {
                final boolean hasColon = last.contains("]:");
                final int offset = hasColon ? 2 : 1; // The length of the separator which is cut out...
                final String[] splitted = hasColon ? last.split("]:") : last.split("]");
                if ( splitted.length > 1 ) {
                    // Replace the last component by what is left and right of the FIRST separator:
                    foundComponents.remove( lastIndex );
                    foundComponents.add( splitted[ 0 ] );
                    foundComponents.add( last.substring( splitted[ 0 ].length() + offset ) );
                }
            }
        }
        ArrayList<Function> sources = new ArrayList<>();
        for ( String component : foundComponents )
            sources.add(
                    parse( component, doAD ) // a dangerous recursion lives here!
            );

        if ( sources.size() == 1 ) return sources.get( 0 );
        if ( sources.isEmpty() ) return null;
        sources.removeIf( source -> source == null ); // Components which failed to parse are dropped.
        return new FunctionNode( operation, sources, doAD );
    }

    /**
     *  Repeatedly replaces the first {@code arity} components by a single parenthesized component
     *  in which they are joined by the operator, until no more than {@code arity} components are left.
     *  For arity 2 and operator {@code "^"} the components {@code [a, b, c]} become {@code [(a^b), c]}.
     *  Nothing happens if the arity is smaller than 2.
     *
     * @param arity The number of components which are grouped together.
     * @param components The components which ought to be grouped, this list is modified in place.
     * @param operator The operator used to join the components of a group.
     * @return The supplied (and possibly modified) list of components.
     */
    private List<String> _groupAccordingToArity(int arity, List<String> components, String operator) {
        // Every round shrinks the list by (arity - 1) >= 1 elements, so this terminates:
        while ( components.size() > arity && arity > 1 ) {
            final List<String> head = components.subList( 0, arity );
            final String newComponent = "(" + String.join( operator, head ) + ")";
            head.clear(); // Removes the grouped components from the backing list.
            components.add( 0, newComponent );
        }
        return components;
    }

}