CPUDot.java

  1. package neureka.backend.main.implementations.linear;

  2. import neureka.Tensor;
  3. import neureka.backend.api.ExecutionCall;
  4. import neureka.backend.api.ImplementationFor;
  5. import neureka.backend.main.operations.linear.internal.blas.DOT;
  6. import neureka.devices.host.CPU;

  7. public class CPUDot implements ImplementationFor<CPU> {

  8.     @Override
  9.     public Tensor<?> run(ExecutionCall<CPU> call) {

  10.         if ( !call.validate().all( (t1, t2) -> t1.getNDConf().getLayout().isCompatible(t2.getNDConf().getLayout()) ).isValid() )
  11.             throw new IllegalArgumentException(
  12.                         "Data layout inconsistency between provided tensors encountered. " +
  13.                         "All tensors must be of the same layout."
  14.                     );

  15.         if ( !call.validate().allShare(Tensor::getDataType).isValid() )
  16.             throw new IllegalArgumentException(
  17.                        "Type inconsistency between provided tensors encountered. " +
  18.                        "All tensors must be of the same type."
  19.                     );

  20.         int[] shapeA = call.input( 1 ).getNDConf().shape();
  21.         int[] shapeB = call.input( 2 ).getNDConf().shape();
  22.         int[] shapeC = call.input( 0 ).getNDConf().shape();

  23.         if ( shapeA.length != 1 || shapeB.length != 1 || shapeC.length != 1 )
  24.             throw new IllegalArgumentException("Dot product only works on vectors.");

  25.         if ( shapeA[0] != shapeB[0] )
  26.             throw new IllegalArgumentException("Dot product only works on vectors of the same length.");

  27.         // A * B = C // [N]*[N] = [1]
  28.         int size = shapeA[0];

  29.         Class<?> type = call.input( 0 ).getDataType().getItemTypeClass();
  30.         if ( type == Double.class ) {
  31.             double[] A = call.input(Double.class, 1).mut().getDataAs(double[].class);
  32.             double[] B = call.input(Double.class, 2).mut().getDataAs(double[].class);
  33.             double[] C = call.input(Double.class, 0).mut().getDataForWriting(double[].class);
  34.             execute( A, B, C, size );
  35.         } else if ( type == Float.class ) {
  36.             float[] A = call.input(Float.class, 1).mut().getDataAs(float[].class);
  37.             float[] B = call.input(Float.class, 2).mut().getDataAs(float[].class);
  38.             float[] C = call.input(Float.class, 0).mut().getDataForWriting(float[].class);
  39.             execute( A, B, C, size );
  40.         }
  41.         else if ( type == Long.class ) {
  42.             long[] A = call.input(Long.class, 1).mut().getDataAs(long[].class);
  43.             long[] B = call.input(Long.class, 2).mut().getDataAs(long[].class);
  44.             long[] C = call.input(Long.class, 0).mut().getDataForWriting(long[].class);
  45.             execute( A, B, C, size );
  46.         }
  47.         else if ( type == Integer.class ) {
  48.             int[] A = call.input(Integer.class, 1).mut().getDataAs(int[].class);
  49.             int[] B = call.input(Integer.class, 2).mut().getDataAs(int[].class);
  50.             int[] C = call.input(Integer.class, 0).mut().getDataForWriting(int[].class);
  51.             execute( A, B, C, size );
  52.         }
  53.         else
  54.             throw new IllegalArgumentException(
  55.                         "Data type '"+type.getSimpleName()+"' not yet supported " +
  56.                         "for CPU based dot product!"
  57.                     );

  58.         return call.input( 0 );
  59.     }

  60.     private static void execute( double[] A, double[] B, double[] C, int size ) {
  61.         C[0] = DOT.invoke( A, 0, B, 0, 0, size );
  62.     }

  63.     private static void execute( float[] A, float[] B, float[] C, int size ) {
  64.         C[0] = DOT.invoke( A, 0, B, 0, 0, size );
  65.     }

  66.     private static void execute( long[] A, long[] B, long[] C, int size ) {
  67.         C[0] = DOT.invoke( A, 0, B, 0, 0, size );
  68.     }

  69.     private static void execute( int[] A, int[] B, int[] C, int size ) {
  70.         C[0] = DOT.invoke( A, 0, B, 0, 0, size );
  71.     }

  72. }