/**
 * Creates a spec for the given model schema that starts out with one
 * regression table and the given link function.
 *
 * @param modelSchema     schema handed to the superclass constructor
 * @param regressionTable table registered through {@code addRegressionTable}
 * @param linkFunction    link function stored on this spec
 */
public GeneralRegressionSpec( ModelSchema modelSchema, RegressionTable regressionTable, LinkFunction linkFunction ) {
  super( modelSchema );

  // the link function is assigned before the table is registered, as in the original ordering
  this.linkFunction = linkFunction;
  this.addRegressionTable( regressionTable );
}
/**
 * Creates a parameter with the given name and beta value and registers
 * every predictor in the supplied list.
 *
 * @param name       parameter name
 * @param beta       beta value stored on this parameter
 * @param predictors predictors passed on to {@code addPredictors}
 */
public Parameter( String name, double beta, List<Predictor> predictors ) {
  this.beta = beta;
  this.name = name;

  this.addPredictors( predictors );
}
/**
 * Adds each predictor in the list, in iteration order, by delegating to
 * {@code addPredictor}.
 *
 * @param predictors predictors to add
 */
public void addPredictors( List<Predictor> predictors ) {
  for( Predictor each : predictors ) {
    this.addPredictor( each );
  }
}
// Builds a spec with soft-max normalization and one regression table per
// target category ("versicolor", "virginica", "setosa").
GeneralRegressionSpec regressionSpec = new GeneralRegressionSpec( modelSchema );
regressionSpec.setNormalization( new SoftMaxNormalization() );

// "versicolor" table: intercept plus one covariant parameter per measurement
RegressionTable regressionTable = new RegressionTable( "versicolor" );
regressionTable.addParameter( new Parameter( "intercept", 86.7061379450354d ) );
regressionTable.addParameter( new Parameter( "p0", -11.3336819785783d, new CovariantPredictor( "sepal_length" ) ) );
regressionTable.addParameter( new Parameter( "p1", -40.8601511206805d, new CovariantPredictor( "sepal_width" ) ) );
regressionTable.addParameter( new Parameter( "p2", 38.439099544679d, new CovariantPredictor( "petal_length" ) ) );
regressionTable.addParameter( new Parameter( "p3", -12.2920287460217d, new CovariantPredictor( "petal_width" ) ) );
regressionSpec.addRegressionTable( regressionTable );

// "virginica" table; the variable is reassigned rather than redeclared,
// because a second declaration in the same scope does not compile
regressionTable = new RegressionTable( "virginica" );
regressionTable.addParameter( new Parameter( "intercept", -111.666532867146d ) );
regressionTable.addParameter( new Parameter( "p0", -47.1170644419116d, new CovariantPredictor( "sepal_length" ) ) );
regressionTable.addParameter( new Parameter( "p1", -51.6805606658275d, new CovariantPredictor( "sepal_width" ) ) );
regressionTable.addParameter( new Parameter( "p2", 108.27736751831d, new CovariantPredictor( "petal_length" ) ) );
regressionTable.addParameter( new Parameter( "p3", 54.0277175236148d, new CovariantPredictor( "petal_width" ) ) );
regressionSpec.addRegressionTable( regressionTable );

// "setosa" table: only a zero intercept
regressionTable = new RegressionTable( "setosa" );
regressionTable.addParameter( new Parameter( "intercept", 0d ) );
regressionSpec.addRegressionTable( regressionTable );
// Builds a spec with a single, unnamed regression table and wraps it in a
// PredictionRegressionFunction. The table holds an intercept, three
// covariant-predictor parameters, and three factor-predictor parameters on
// the "species" field (one per listed category value).
GeneralRegressionSpec regressionSpec = new GeneralRegressionSpec( modelSchema );
RegressionTable regressionTable = new RegressionTable();
regressionTable.addParameter( new Parameter( "intercept", 2.24166872421148d ) );
regressionTable.addParameter( new Parameter( "p1", 0.53448203205212d, new CovariantPredictor( "sepal_width" ) ) );
regressionTable.addParameter( new Parameter( "p2", 0.691035562908626d, new CovariantPredictor( "petal_length" ) ) );
regressionTable.addParameter( new Parameter( "p3", -0.21488157609202d, new CovariantPredictor( "petal_width" ) ) );
regressionTable.addParameter( new Parameter( "p4", 0d, new FactorPredictor( "species", "setosa" ) ) );
regressionTable.addParameter( new Parameter( "p5", -0.43150751368126d, new FactorPredictor( "species", "versicolor" ) ) );
regressionTable.addParameter( new Parameter( "p6", -0.61868924203063d, new FactorPredictor( "species", "virginica" ) ) );
regressionSpec.addRegressionTable( regressionTable );
PredictionRegressionFunction regressionFunction = new PredictionRegressionFunction( regressionSpec );
// Builds a spec that uses the LOGIT link function with a single, unnamed
// regression table, then wraps it in a PredictionRegressionFunction. The
// table holds one parameter without predictors ("p0") and four
// covariant-predictor parameters.
GeneralRegressionSpec regressionSpec = new GeneralRegressionSpec( modelSchema );
regressionSpec.setLinkFunction( LinkFunction.LOGIT );
RegressionTable table = new RegressionTable();
table.addParameter( new Parameter( "p0", -16.9456960387809d ) );
table.addParameter( new Parameter( "p1", 11.7592159418536d, new CovariantPredictor( "sepal_length" ) ) );
table.addParameter( new Parameter( "p2", 7.84157781514097d, new CovariantPredictor( "sepal_width" ) ) );
table.addParameter( new Parameter( "p3", -20.0880078273996d, new CovariantPredictor( "petal_length" ) ) );
table.addParameter( new Parameter( "p4", -21.6076488529538d, new CovariantPredictor( "petal_width" ) ) );
regressionSpec.addRegressionTable( table );
PredictionRegressionFunction regressionFunction = new PredictionRegressionFunction( regressionSpec );
/**
 * Copies the name and beta from the given parameter and creates one invoker
 * per factor predictor and per covariant predictor. Each invoker is given the
 * position of the predictor's field name within {@code argumentsFields}.
 *
 * @param argumentsFields fields used to look up each predictor's position
 * @param parameter       parameter supplying the name, beta and predictors
 */
public ParameterExpression( Fields argumentsFields, Parameter parameter ) {
  this.name = parameter.getName();
  this.beta = parameter.getBeta();

  // factor predictors are bound first
  int factorCount = parameter.getFactors().size();
  this.factorInvokers = new FactorInvoker[ factorCount ];

  for( int f = 0; f < factorCount; f++ ) {
    FactorPredictor factor = parameter.getFactors().get( f );
    this.factorInvokers[ f ] = new FactorInvoker( argumentsFields.getPos( factor.getFieldName() ), factor );
  }

  // then the covariant predictors
  int covariantCount = parameter.getCovariants().size();
  this.covariantInvokers = new CovariantInvoker[ covariantCount ];

  for( int c = 0; c < covariantCount; c++ ) {
    CovariantPredictor covariant = parameter.getCovariants().get( c );
    this.covariantInvokers[ c ] = new CovariantInvoker( argumentsFields.getPos( covariant.getFieldName() ), covariant );
  }
}
/**
 * Evaluates every bound expression against the incoming arguments, applies
 * the spec's link function and normalization, and emits the target category
 * of the expression with the largest normalized value. When the model schema
 * asks for predicted categories to be included, the normalized values are
 * emitted after the category in the same tuple.
 *
 * @param flowProcess  current flow process (not used directly here)
 * @param functionCall call carrying the arguments, context and output collector
 */
@Override
public void operate( FlowProcess flowProcess, FunctionCall<Context<BaseRegressionFunction.ExpressionContext>> functionCall ) {
  Context<BaseRegressionFunction.ExpressionContext> context = functionCall.getContext();
  TupleEntry arguments = functionCall.getArguments();

  ExpressionEvaluator[] expressions = context.payload.expressions;
  double[] results = context.payload.results;

  for( int e = 0; e < expressions.length; e++ ) {
    results[ e ] = expressions[ e ].calculate( arguments );
  }

  LOG.debug( "raw regression: {}", results );

  for( int e = 0; e < expressions.length; e++ ) {
    results[ e ] = getSpec().getLinkFunction().calculate( results[ e ] );
  }

  LOG.debug( "link regression: {}", results );

  // from here on 'results' refers to whatever array normalize returns
  results = getSpec().getNormalization().normalize( results );

  LOG.debug( "probabilities: {}", results );

  // the first position holding the maximum value selects the category
  int index = Doubles.indexOf( results, Doubles.max( results ) );
  String category = expressions[ index ].getTargetCategory();

  LOG.debug( "category: {}", category );

  if( getSpec().getModelSchema().isIncludePredictedCategories() ) {
    Tuple result = context.tuple;

    result.set( 0, category );

    for( int r = 0; r < results.length; r++ ) {
      result.set( r + 1, results[ r ] );
    }

    functionCall.getOutputCollector().add( result );
  } else {
    functionCall.getOutputCollector().add( context.result( category ) );
  }
}
} // closes the enclosing class, whose header is outside this excerpt
/**
 * Creates an evaluator for this table's target category. When the table is a
 * no-op the evaluator is created without expressions; otherwise one
 * expression is created per parameter, in the iteration order of
 * {@code parameters.values()}.
 *
 * @param argumentFields fields handed to each parameter's {@code createExpression}
 * @return the evaluator for this table
 */
ExpressionEvaluator bind( Fields argumentFields ) {
  if( !isNoOp() ) {
    ParameterExpression[] bound = new ParameterExpression[ parameters.size() ];
    int next = 0;

    for( Parameter each : parameters.values() ) {
      bound[ next ] = each.createExpression( argumentFields );
      next++;
    }

    return new ExpressionEvaluator( targetCategory, bound );
  }

  return new ExpressionEvaluator( targetCategory );
}
/**
 * Runs the superclass preparation, then installs a fresh
 * {@code ExpressionContext} as the context payload and fills its expressions
 * from the spec's regression table evaluators for the call's argument fields.
 *
 * @param flowProcess   current flow process
 * @param operationCall call whose context receives the payload
 */
@Override
public void prepare( FlowProcess flowProcess, OperationCall<Context<ExpressionContext>> operationCall ) {
  super.prepare( flowProcess, operationCall );

  Fields argumentFields = operationCall.getArgumentFields();

  // the payload is installed on the context before its expressions are resolved
  ExpressionContext payload = new ExpressionContext();
  operationCall.getContext().payload = payload;

  operationCall.getContext().payload.expressions = getSpec().getRegressionTableEvaluators( argumentFields );
}
} // closes the enclosing class, whose header is outside this excerpt
/**
 * Creates the function for the given spec and validates it: the spec must
 * have a normalization, the first predicted field must be categorical, and
 * that field must declare as many categories as the spec has regression
 * tables.
 *
 * @param regressionSpec spec to evaluate
 * @throws IllegalArgumentException if any of the checks above fails
 */
public CategoricalRegressionFunction( GeneralRegressionSpec regressionSpec ) {
  super( regressionSpec );

  if( regressionSpec.getNormalization() == null ) {
    throw new IllegalArgumentException( "normalization may not be null" );
  }

  ModelSchema schema = regressionSpec.getModelSchema();
  DataField predicted = schema.getPredictedField( schema.getPredictedFieldNames().get( 0 ) );

  boolean categorical = predicted instanceof CategoricalDataField;

  if( !categorical ) {
    throw new IllegalArgumentException( "predicted field must be categorical" );
  }

  int categoryCount = ( (CategoricalDataField) predicted ).getCategories().size();
  int tableCount = regressionSpec.getRegressionTables().size();

  if( categoryCount != tableCount ) {
    throw new IllegalArgumentException( "predicted field categories must be same size as the number of regression tables" );
  }
}
/**
 * Evaluates the single bound expression against the incoming arguments,
 * applies the spec's link function to the value, and emits the outcome as the
 * result.
 *
 * @param flowProcess  current flow process (not used directly here)
 * @param functionCall call carrying the arguments, context and output collector
 */
@Override
public void operate( FlowProcess flowProcess, FunctionCall<Context<BaseRegressionFunction.ExpressionContext>> functionCall ) {
  // only the first evaluator is used by this function
  ExpressionEvaluator evaluator = functionCall.getContext().payload.expressions[ 0 ];

  // read the link function through the accessor rather than the spec's field,
  // matching how the categorical operate reads it
  LinkFunction linkFunction = getSpec().getLinkFunction();

  double result = evaluator.calculate( functionCall.getArguments() );
  double linkResult = linkFunction.calculate( result );

  LOG.debug( "result: {}", linkResult );

  functionCall.getOutputCollector().add( functionCall.getContext().result( linkResult ) );
}
} // closes the enclosing class, whose header is outside this excerpt
public PredictionRegressionFunction( GeneralRegressionSpec param ) { super( param ); if( getSpec().getRegressionTables().size() != 1 ) throw new IllegalArgumentException( "regression function only supports a single table, got: " + getSpec().getRegressionTables().size() ); }
public ExpressionEvaluator[] getRegressionTableEvaluators( Fields argumentFields ) { List<RegressionTable> tables = new ArrayList<RegressionTable>( regressionTables ); final DataField predictedField = getModelSchema().getPredictedField( getModelSchema().getPredictedFieldNames().get( 0 ) ); // order tables in category order as this is the declared field name order if( predictedField instanceof CategoricalDataField ) { Ordering<RegressionTable> ordering = Ordering.natural().onResultOf( new Function<RegressionTable, Comparable>() { private List<String> categories = ( (CategoricalDataField) predictedField ).getCategories(); @Override public Comparable apply( RegressionTable regressionTable ) { return categories.indexOf( regressionTable.getTargetCategory() ); } } ); Collections.sort( tables, ordering ); } ExpressionEvaluator[] evaluators = new ExpressionEvaluator[ tables.size() ]; for( int i = 0; i < tables.size(); i++ ) evaluators[ i ] = tables.get( i ).bind( argumentFields ); return evaluators; }
// Returns the result of categories.indexOf(...) for the table's target
// category. The trailing "} );" closes constructs that were opened outside
// this fragment.
@Override public Comparable apply( RegressionTable regressionTable ) { return categories.indexOf( regressionTable.getTargetCategory() ); } } );
/**
 * Reports whether this table does nothing: true when every parameter reports
 * itself as a no-op, which includes the case of no parameters at all.
 *
 * @return false as soon as one parameter is not a no-op, true otherwise
 */
public boolean isNoOp() {
  for( Parameter candidate : parameters.values() ) {
    boolean contributes = !candidate.isNoOp();

    if( contributes ) {
      return false;
    }
  }

  return true;
}
/**
 * Returns the LinkFunction whose function name is exactly equal to the given
 * name.
 * <p>
 * The name is compared literally. It is not interpreted as a regular
 * expression, so names containing regex metacharacters neither match
 * unrelated functions nor raise a pattern syntax error.
 *
 * @param functionName name to look up; a null or unknown name yields {@link #NONE}
 * @return the matching LinkFunction, or {@link #NONE} when nothing matches
 */
public static LinkFunction getFunction( String functionName ) {
  for( LinkFunction lf : values() ) {
    // String.matches would treat functionName as a regex pattern; use equality instead
    if( lf.function.equals( functionName ) ) {
      return lf;
    }
  }

  return LinkFunction.NONE;
}
@Override public void prepare( FlowProcess flowProcess, OperationCall<Context<ExpressionContext>> operationCall ) { super.prepare( flowProcess, operationCall ); // cache the result array operationCall.getContext().payload.results = new double[ operationCall.getContext().payload.expressions.length ]; }