ParsedCLImplementation.java
package neureka.backend.main.implementations;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.ImplementationFor;
import neureka.devices.opencl.KernelCode;
import neureka.devices.opencl.OpenCLDevice;
import neureka.dtype.DataType;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
/**
 * A {@link CLImplementation} whose OpenCL kernel is either supplied verbatim or generated
 * from a kernel template by splicing in an activation and a differentiation snippet.
 * <p>
 * Template expansion works as follows:
 * <ol>
 *     <li>The name of the first {@code __kernel} function must contain {@code "template"}.</li>
 *     <li>Every occurrence of {@code "template"} is removed from the source, which is then split
 *         at each {@code //-=<OPERATION>=-//} marker.</li>
 *     <li>Segments 0, 2 and 4 are kept; segments 1 and 3 are replaced by the activation and the
 *         differentiation snippet respectively, after their aliases (e.g. {@code input},
 *         {@code output}, {@code src1}) have been rewritten into indexed buffer accesses.
 *         Any further segments are discarded.</li>
 *     <li>The kernel name prefix (template name without {@code "template"}) gets the postfix appended.</li>
 * </ol>
 * The resulting {@link KernelCode} is handed to a data type adapter, which produces the final kernels.
 */
public class ParsedCLImplementation extends CLImplementation
{
    /** Marker splitting a template into the segments around the two operation sections. */
    private static final String _OPERATION_MARKER = "//-=<OPERATION>=-//";

    private final KernelCode[] _kernels;

    /**
     * Creates the implementation and prepares its kernels eagerly.
     *
     * @param lambda                forwarded to the {@link CLImplementation} super constructor
     * @param arity                 forwarded to the {@link CLImplementation} super constructor
     * @param kernelSource          complete kernel code when both snippets are {@code null},
     *                              otherwise a kernel template
     * @param activationSource      snippet spliced into the first operation section of the template
     * @param differentiationSource snippet spliced into the second operation section of the template
     * @param postfix               the kernel name when no template is used; otherwise appended to the
     *                              template's name prefix
     * @param dataTypeAdapter       turns the parsed kernel into the kernels stored by this implementation
     * @throws IllegalStateException    if the first {@code __kernel} name does not contain {@code "template"}
     * @throws IllegalArgumentException if a template is required but {@code kernelSource} has no
     *                                  {@code __kernel}, if only one of the two snippets is given, or if
     *                                  the template lacks the two operation sections
     */
    public ParsedCLImplementation(
            ImplementationFor<OpenCLDevice> lambda,
            int arity,
            String kernelSource,
            String activationSource,
            String differentiationSource,
            String postfix,
            Function<KernelCode, KernelCode[]> dataTypeAdapter
    ) {
        super( lambda, arity );
        KernelCode parsed;
        if ( activationSource == null && differentiationSource == null )
            parsed = new KernelCode( postfix, kernelSource ); // No snippets: the source is used as-is.
        else
            parsed = _parseTemplate( kernelSource, activationSource, differentiationSource, postfix );

        _kernels = dataTypeAdapter.apply( parsed );
    }

    /**
     * Expands a kernel template with the given activation and differentiation snippets.
     *
     * @return the renamed kernel together with its expanded source code
     */
    private static KernelCode _parseTemplate(
            String kernelSource,
            String activationSource,
            String differentiationSource,
            String postfix
    ) {
        // Previously this case silently produced a KernelCode with null name and code.
        if ( kernelSource == null || !kernelSource.contains("__kernel") )
            throw new IllegalArgumentException(
                "Kernel source must contain a '__kernel' declaration when an activation " +
                "or differentiation source is provided!"
            );

        String templateName = _findTemplateName( kernelSource );
        if ( templateName == null )
            throw new IllegalStateException("Invalid source code passed to AbstractCLExecution!");

        if ( activationSource == null || differentiationSource == null )
            throw new IllegalArgumentException(
                "Both an activation and a differentiation source are required to expand a kernel template!"
            );

        String preName = templateName.replace("template", "");
        String[] segments = kernelSource.replace("template", "").split( _OPERATION_MARKER );
        if ( segments.length < 5 )
            throw new IllegalArgumentException(
                "Kernel template must contain two operation sections delimited by '" + _OPERATION_MARKER + "'!"
            );

        String code =
                segments[ 0 ].replace( preName, preName + postfix ) +
                _swapAliases( activationSource ) +
                segments[ 2 ] +
                _swapAliases( differentiationSource ) +
                segments[ 4 ];

        // Only the differentiation snippet decides whether the advanced variant is generated.
        if ( _isAdvanced( differentiationSource ) )
            code = _asAdvanced( code );

        return new KernelCode( preName + postfix, code );
    }

    /**
     * Extracts the name of the first {@code __kernel} function, i.e. the last space separated
     * token between {@code __kernel} and the first {@code '('}.
     *
     * @return the name if it contains {@code "template"}, otherwise {@code null}
     */
    private static String _findTemplateName( String kernelSource ) {
        String[] afterKernel = kernelSource.split("__kernel");
        if ( afterKernel.length < 2 ) return null;
        String[] beforeParenthesis = afterKernel[ 1 ].split("\\(");
        if ( beforeParenthesis.length == 0 ) return null;
        String[] headerTokens = beforeParenthesis[ 0 ].split(" ");
        if ( headerTokens.length == 0 ) return null;
        String candidate = headerTokens[ headerTokens.length - 1 ];
        return candidate.contains("template") ? candidate : null;
    }

    /**
     * Rewrites the aliases used in operation snippets into indexed buffer accesses and wraps
     * the result in {@code //-=<PARSED>=-//} markers.
     * <p>
     * The replacement order matters: {@code "input1"}/{@code "input2"} must be handled before
     * {@code "input"}, and {@code "src1"}/{@code "src2"} before the aliases that expand into them.
     * NOTE(review): because {@code "input"} is replaced here, an {@code "input3"} token in a snippet
     * becomes {@code "src1[...]3"} before {@link #_asAdvanced} could map it — confirm whether
     * {@code input3} is meant to be supported inside snippets.
     */
    private static String _swapAliases( String snippet ) {
        return "//-=<PARSED>=-//\n" +
                snippet.replace("src1", "src1[_i_of_idx_on_tln(prv_src1_cfg, rank)]")
                       .replace("src2", "src2[_i_of_idx_on_tln(prv_src2_cfg, rank)]")
                       .replace("input1", "src1[_i_of_i(i, prv_src1_cfg, rank)]")
                       .replace("input2", "src2[_i_of_i(i, prv_src2_cfg, rank)]")
                       .replace("input", "src1[_i_of_i(i, prv_src1_cfg, rank)]")
                       .replace("output", "drn[_i_of_i(i, prv_drn_cfg, rank)]")
                       .replace("handle", "src1[_i_of_idx_on_tln(prv_src1_cfg, rank)]")
                       .replace("drain", "src2[_i_of_idx_on_tln(prv_src2_cfg, rank)]")
                       .replace("origin", "drn[di]")
                       .replace("target", "frn[_i_of_idx_on_tln(prv_frn_cfg, rank)]") +
                "\n//-=<PARSED>=-//";
    }

    /**
     * Converts expanded kernel code into its advanced variant: maps the remaining {@code target}
     * and {@code input3} tokens onto the {@code frn} buffer and removes the
     * {@code //-=<ARGUMENT>=-//} and {@code //-=<CONFIGURATION>=-//} markers.
     */
    private static String _asAdvanced( String code ) {
        return code.replace("target", "frn[_i_of_idx_on_tln(prv_frn2_cfg, rank)]")
                   .replace("input3", "frn[_i_of_idx_on_tln(prv_frn2_cfg, rank)]")
                   .replace("//-=<ARGUMENT>=-//", "")
                   .replace("//-=<CONFIGURATION>=-//", "");
    }

    /**
     * @return {@code true} if the snippet uses either all of {@code target}, {@code drain} and
     *         {@code handle}, or all of {@code input1}, {@code input2} and {@code input3}
     */
    private static boolean _isAdvanced( String snippet ) {
        boolean usesTargetDrainHandle =
                snippet.contains("target") && snippet.contains("drain") && snippet.contains("handle");
        boolean usesThreeInputs =
                snippet.contains("input1") && snippet.contains("input2") && snippet.contains("input3");
        return usesTargetDrainHandle || usesThreeInputs;
    }

    /**
     * Picks the kernel whose data type equals the data type of the call's first input,
     * falling back to the first kernel if none matches.
     */
    @Override
    public KernelCode getKernelFor( ExecutionCall<OpenCLDevice> call ) {
        DataType<?> callType = call.input( 0 ).getDataType();
        return Arrays.stream( _kernels )
                .filter( k -> Objects.equals( k.getDataType(), callType ) )
                .findFirst()
                .orElse( _kernels[ 0 ] );
    }

    /**
     * @return the kernels produced by the data type adapter (the internal array, not a copy)
     */
    @Override
    public KernelCode[] getKernelCode() {
        return _kernels;
    }
}