Smile

This module wraps several Smile regression classes so that they can be used as FittedPredictor.Fitter<double[], double[], Double> implementations.

Examples

Random forest

// Each row of data is one sample; createData() is defined elsewhere.
double[][] data = createData();
List<Object> dataList = Arrays.asList(data);
RandomForestPredictorFitter rfpf = new RandomForestPredictorFitter();

// Features are every element of a row except the last one;
// the label is a one-element array holding the last element.
FittedPredictor<double[], double[], Double> predictor = rfpf.fit(
        dataList,
        sample -> {
            double[] row = (double[]) sample;
            return Arrays.copyOf(row, row.length - 1);
        },
        sample -> {
            double[] row = (double[]) sample;
            return new double[] { row[row.length - 1] };
        });

System.out.println(predictor.getError());

// Predict the label for a new feature vector and print its only element.
double[] prediction = predictor.predict(new double[] { 6, 7, 8 });
System.out.println(prediction[0]);

OLS

// Each row of data is one sample; createData() is defined elsewhere.
double[][] data = createData();
List<Object> dataList = Arrays.asList(data);
OLSPredictorFitter olspf = new OLSPredictorFitter();

// Features are every element of a row except the last one;
// the label is a one-element array holding the last element.
FittedPredictor<double[], double[], Double> predictor = olspf.fit(
        dataList,
        sample -> {
            double[] row = (double[]) sample;
            return Arrays.copyOf(row, row.length - 1);
        },
        sample -> {
            double[] row = (double[]) sample;
            return new double[] { row[row.length - 1] };
        });

System.out.println(predictor.getError());

// Predict the label for a new feature vector and print its only element.
double[] prediction = predictor.predict(new double[] { 6, 7, 8 });
System.out.println(prediction[0]);

Composition

// Each row of data is one sample; createData() is defined elsewhere.
double[][] data = createData();
List<Object> dataList = Arrays.asList(data);
OLSPredictorFitter olspf = new OLSPredictorFitter();

// A trivial fitter whose predictor ignores the feature values and puts 0.22
// into the first element of every predicted label.
Fitter<double[], double[], Double> other = new AbstractDoubleFitter() {

    @Override
    protected Function<double[][], double[][]> fit(double[][] features, double[][] labels) {
        // The width of a predicted label is taken from the training labels.
        int labelWidth = labels.length == 0 ? 0 : labels[0].length;
        return input -> {
            // One prediction row per input row. Sizing the result by the
            // training labels would return the wrong number of rows whenever
            // the input and the training set differ in size.
            double[][] result = new double[input.length][];
            for (int i = 0; i < result.length; ++i) {
                result[i] = new double[labelWidth];
                if (labelWidth > 0) {
                    result[i][0] = 0.22;
                }
            }
            return result;
        };
    }

};

// Compose the OLS fitter with the constant fitter defined above.
FittedPredictor.Fitter<double[], double[], Double> composite = olspf.compose(other);

// Features are every element of a row except the last one;
// the label is a one-element array holding the last element.
FittedPredictor<double[], double[], Double> predictor = composite.fit(
        dataList,
        e -> {
            double[] s = (double[]) e;
            double[] f = new double[s.length - 1];
            System.arraycopy(s, 0, f, 0, f.length);
            return f;
        },
        e -> new double[] { ((double[]) e)[((double[]) e).length - 1] });

System.out.println(predictor.getError());

double[] prediction = predictor.predict(new double[] { 6, 7, 8 });
System.out.println(prediction[0]);

The "other" fitter above always predicts 0.22; it is there only for testing the composition.

Recursive (Autoregression)

OLSRecursivePredictorFitter fitter = new OLSRecursivePredictorFitter();

// Each row holds 7 values: the leading 4 are used as features
// and the trailing 3 (see labels below) as labels.
double[][] data = {
    { 1, 2, 3, 4.1, 5, 6, 7 },
    { 2, 3, 4, 4.9, 6, 7, 8 },
    { 3, 4, 5, 6.1, 7, 8, 9 },
    { 4, 5, 6, 6.9, 8, 9, 10 },
    { 5, 6, 7, 8.1, 9, 10, 11 },
    { 6, 7, 8, 9.1, 10, 11, 12 },
    { 7, 8, 9.1, 10, 11, 12, 13 },
    { 8, 9.1, 10, 11, 12, 13, 14 }
};

// Number of trailing elements of each row that form the label.
int labels = 3;

// Same list construction as in the other examples; no assertion-library helper is needed.
List<Object> dataList = Arrays.asList(data);
FittedPredictor<double[], double[], Double> predictor = fitter.fit(
        dataList,
        e -> {
            // Features: everything before the trailing label elements.
            double[] s = (double[]) e;
            double[] f = new double[s.length - labels];
            System.arraycopy(s, 0, f, 0, f.length);
            return f;
        },
        e -> {
            // Label: the trailing elements of the row.
            double[] s = (double[]) e;
            double[] l = new double[labels];
            System.arraycopy(s, s.length - labels, l, 0, l.length);
            return l;
        });

System.out.println(predictor.getError());

// Only the first element of the predicted label is printed.
double[] prediction = predictor.predict(new double[] { 6, 7, 8, 9 });
System.out.println(prediction[0]);

Adapting features and labels

// Age by name; used below to turn a name into a feature vector.
Map<String,Double> ageMap = Map.of(
        "Alice", 25.0,
        "Bob", 30.0,
        "Eve", 35.0,
        "Mallory", 40.0);

// Weight by name; used below as the label source.
Map<String,Double> weightMap = Map.of(
        "Alice", 100.0,
        "Bob", 120.0,
        "Eve", 140.0,
        "Mallory", 0.0);

OLSPredictorFitter olspf = new OLSPredictorFitter();

// Feature adapter: a name becomes a one-element feature vector holding the age.
Fitter<String, double[], Double> featureAdapted = olspf.adaptFeature(
        person -> {
            double age = ageMap.get(person);
            return new double[] { age };
        });

// Label adapter: a Double is wrapped into a one-element array and unwrapped back.
Fitter<String, Double, Double> labelAdapted = featureAdapted.adaptLabel(
        value -> new double[] { value },
        wrapped -> wrapped[0]);

// Fit on three names; the name itself is the feature, the weight is the label.
FittedPredictor<String, Double, Double> predictor = labelAdapted.fit(
        List.of("Alice", "Bob", "Eve"),
        Function.identity(),
        weightMap::get);

System.out.println(predictor.getError());

// Predict for a name that was not part of the fitted list.
double prediction = predictor.predict("Mallory");
System.out.println(prediction);

Regression Tree

// Each row of data is one sample; createData() is defined elsewhere.
double[][] data = createData();
List<Object> dataList = Arrays.asList(data);
RegressionTreePredictorFitter rtpf = new RegressionTreePredictorFitter();

// Features are every element of a row except the last one;
// the label is a one-element array holding the last element.
FittedPredictor<double[], double[], Double> predictor = rtpf.fit(
        dataList,
        sample -> {
            double[] row = (double[]) sample;
            return Arrays.copyOf(row, row.length - 1);
        },
        sample -> {
            double[] row = (double[]) sample;
            return new double[] { row[row.length - 1] };
        });

// Predict the label for a new feature vector and print its only element.
double[] prediction = predictor.predict(new double[] { 6, 7, 8 });
System.out.println(prediction[0]);

MLP

// Each row of data is one sample; createData() is defined elsewhere.
double[][] data = createData();
List<Object> dataList = Arrays.asList(data);
MLPPredictorFitter mlppf = new MLPPredictorFitter();

// Features are every element of a row except the last one;
// the label is a one-element array holding the last element.
FittedPredictor<double[], double[], Double> predictor = mlppf.fit(
        dataList,
        sample -> {
            double[] row = (double[]) sample;
            return Arrays.copyOf(row, row.length - 1);
        },
        sample -> {
            double[] row = (double[]) sample;
            return new double[] { row[row.length - 1] };
        });

System.out.println(predictor.getError());

// Predict the label for a new feature vector and print its only element.
double[] prediction = predictor.predict(new double[] { 6, 7, 8 });
System.out.println(prediction[0]);

The above example uses regression MLPs that each have a single output, so there is one MLP per label element. An implementation using a single MLP with multiple outputs will be provided in the future.