/**
 * Finds the position of the first mining field whose usage type is
 * {@link MiningField.UsageType#PREDICTED}.
 *
 * @param miningSchema {@link MiningSchema} from a model
 * @return zero-based index of the first {@link MiningField.UsageType#PREDICTED} field within
 *  {@code miningSchema.getMiningFields()}, or {@code null} if no field has that usage type
 */
public static Integer findTargetIndex(MiningSchema miningSchema) {
  int position = 0;
  for (MiningField field : miningSchema.getMiningFields()) {
    if (MiningField.UsageType.PREDICTED == field.getUsageType()) {
      return position;
    }
    position++;
  }
  // No predicted field present
  return null;
}
/**
 * Checks the size, field names, usage types and selected op types of the mining schema
 * built from the test schema.
 */
@Test
public void testBuildMiningSchema() {
  List<MiningField> fields = AppPMMLUtils.buildMiningSchema(buildTestSchema()).getMiningFields();
  assertEquals(4, fields.size());

  // Names, in schema order
  String[] expectedNames = { "foo", "bar", "baz", "bing" };
  int position = 0;
  for (String expectedName : expectedNames) {
    assertEquals(expectedName, fields.get(position).getName().getValue());
    position++;
  }

  // Usage types, in schema order
  MiningField.UsageType[] expectedUsageTypes = {
      MiningField.UsageType.SUPPLEMENTARY,
      MiningField.UsageType.PREDICTED,
      MiningField.UsageType.SUPPLEMENTARY,
      MiningField.UsageType.ACTIVE,
  };
  for (int i = 0; i < expectedUsageTypes.length; i++) {
    assertEquals(expectedUsageTypes[i], fields.get(i).getUsageType());
  }

  // Op types are only checked for the predicted and the active field
  assertEquals(OpType.CATEGORICAL, fields.get(1).getOpType());
  assertEquals(OpType.CONTINUOUS, fields.get(3).getOpType());
}
if (field.getUsageType() == MiningField.UsageType.ACTIVE && importances != null) { int predictorIndex = schema.featureToPredictorIndex(featureIndex); field.setImportance(importances[predictorIndex]);
assertEquals("Wrong usage type for feature " + featureName, MiningField.UsageType.PREDICTED, miningField.getUsageType()); } else { assertEquals("Wrong usage type for feature " + featureName, MiningField.UsageType.ACTIVE, miningField.getUsageType()); assertRange(miningField.getImportance(), 0.0, 1.0); assertEquals("Wrong usage type for feature " + featureName, MiningField.UsageType.SUPPLEMENTARY, miningField.getUsageType());
/**
 * Finds the position of the first mining field whose usage type is
 * {@link MiningField.UsageType#PREDICTED}.
 *
 * @param miningSchema {@link MiningSchema} from a model
 * @return zero-based index of the first {@link MiningField.UsageType#PREDICTED} field within
 *  {@code miningSchema.getMiningFields()}, or {@code null} if no field has that usage type
 */
public static Integer findTargetIndex(MiningSchema miningSchema) {
  int position = 0;
  for (MiningField field : miningSchema.getMiningFields()) {
    if (MiningField.UsageType.PREDICTED == field.getUsageType()) {
      return position;
    }
    position++;
  }
  // No predicted field present
  return null;
}
/**
 * Tells whether the given name belongs to an ACTIVE field of the mining list.
 *
 * @param root the field name to look for
 * @param miningList the mining fields to search
 * @return {@code true} if some field with usage type {@code FieldUsageType.ACTIVE} has a
 *         name equal to {@code root}; {@code false} otherwise
 */
private boolean isRootInMiningList(FieldName root, List<MiningField> miningList) {
    for(MiningField candidate: miningList) {
        // Non-active fields are ignored; the name is only read for active ones
        if(candidate.getUsageType() == FieldUsageType.ACTIVE && root.equals(candidate.getName())) {
            return true;
        }
    }
    return false;
}
/**
 * Counts the mining fields whose usage type equals {@link FieldUsageType#TARGET}.
 *
 * @param miningSchema schema whose mining fields are inspected
 * @return number of mining fields with the TARGET usage type
 */
public static Integer getNumTargetMiningFields(MiningSchema miningSchema) {
    int total = 0;
    for(MiningField field: miningSchema.getMiningFields()) {
        FieldUsageType usage = field.getUsageType();
        if(usage.equals(FieldUsageType.TARGET)) {
            total++;
        }
    }
    return total;
}
/**
 * Counts the mining fields whose usage type equals {@link FieldUsageType#ACTIVE}.
 *
 * @param miningSchema schema whose mining fields are inspected
 * @return number of mining fields with the ACTIVE usage type
 */
public static Integer getNumActiveMiningFields(MiningSchema miningSchema) {
    int total = 0;
    for(MiningField field: miningSchema.getMiningFields()) {
        FieldUsageType usage = field.getUsageType();
        if(usage.equals(FieldUsageType.ACTIVE)) {
            total++;
        }
    }
    return total;
}
/**
 * Counts the mining fields whose usage type equals {@link FieldUsageType#TARGET}.
 *
 * @param miningSchema schema whose mining fields are inspected
 * @return number of mining fields with the TARGET usage type
 */
public static Integer getNumTargetMiningFields(MiningSchema miningSchema) {
    int total = 0;
    for(MiningField field: miningSchema.getMiningFields()) {
        FieldUsageType usage = field.getUsageType();
        if(usage.equals(FieldUsageType.TARGET)) {
            total++;
        }
    }
    return total;
}
/**
 * Collects the names of all mining fields that have the given usage type.
 *
 * @param schema
 *            the mining schema to scan
 * @param type
 *            the usage type to match (compared by identity)
 * @return names of the matching fields, in mining-schema order
 */
private static List<String> getSchemaFieldViaUsageType(final MiningSchema schema, final FieldUsageType type) {
    List<String> matchingNames = new ArrayList<String>();
    for(MiningField field: schema.getMiningFields()) {
        if(field.getUsageType() != type) {
            continue;
        }
        matchingNames.add(field.getName().getValue());
    }
    return matchingNames;
}
/**
 * Counts the mining fields whose usage type equals {@link FieldUsageType#ACTIVE}.
 *
 * @param miningSchema schema whose mining fields are inspected
 * @return number of mining fields with the ACTIVE usage type
 */
public static Integer getNumActiveMiningFields(MiningSchema miningSchema) {
    int total = 0;
    for(MiningField field: miningSchema.getMiningFields()) {
        FieldUsageType usage = field.getUsageType();
        if(usage.equals(FieldUsageType.ACTIVE)) {
            total++;
        }
    }
    return total;
}
/**
 * Returns the names of every mining field whose usage type is either
 * {@code FieldUsageType.TARGET} or {@code FieldUsageType.ACTIVE}.
 *
 * @param schema
 *            the mining schema to scan
 * @return names of the target and active fields, in mining-schema order
 */
public static List<String> getSchemaSelectedFields(final MiningSchema schema) {
    List<String> selectedNames = new ArrayList<String>();
    for(MiningField field: schema.getMiningFields()) {
        FieldUsageType usage = field.getUsageType();
        boolean selected = (usage == FieldUsageType.TARGET) || (usage == FieldUsageType.ACTIVE);
        if(!selected) {
            continue;
        }
        selectedNames.add(field.getName().getValue());
    }
    return selectedNames;
}
MiningField.UsageType usageType = miningField.getUsageType(); switch(usageType){ case ACTIVE:
/**
 * Finds the single mining field of the model whose usage type is TARGET or PREDICTED.
 *
 * @param model the model whose mining schema is scanned
 * @return the only TARGET/PREDICTED mining field, or {@code null} if there is none
 * @throws UnsupportedElementException if more than one such field is present
 */
static private MiningField getTargetField(Model model){
    MiningSchema miningSchema = model.getMiningSchema();

    MiningField targetField = null;

    for(MiningField candidate : miningSchema.getMiningFields()){

        // Skip everything that is not a target-like field
        switch(candidate.getUsageType()){
            case PREDICTED:
            case TARGET:
                break;
            default:
                continue;
        }

        // A second target-like field is rejected
        if(targetField != null){
            throw new UnsupportedElementException(miningSchema);
        }

        targetField = candidate;
    }

    return targetField;
}
}
/**
 * Adds the model's TARGET and PREDICTED fields to the set returned by
 * {@code getTargetFields()}.
 *
 * Nothing is added when the model has no mining schema, the schema has no mining
 * fields, or none of them is TARGET/PREDICTED.
 *
 * @param model the model to inspect
 */
private void processModel(Model model){
    Set<Field<?>> targetFields = getTargetFields();

    MiningSchema miningSchema = model.getMiningSchema();
    if(miningSchema == null || !miningSchema.hasMiningFields()){
        return;
    }

    // Insertion-ordered, so names keep mining-schema order
    Set<FieldName> names = new LinkedHashSet<>();

    for(MiningField miningField : miningSchema.getMiningFields()){

        switch(miningField.getUsageType()){
            case PREDICTED:
            case TARGET:
                names.add(miningField.getName());
                break;
            default:
                break;
        }
    }

    if(names.isEmpty()){
        return;
    }

    Set<Field<?>> modelFields = getFields(model);

    targetFields.addAll(FieldUtil.selectAll(modelFields, names));
}
/**
 * Adds the model's TARGET and PREDICTED fields to the set returned by
 * {@code getTargetFields()}.
 *
 * Nothing is added when the model has no mining schema, the schema has no mining
 * fields, or none of them is TARGET/PREDICTED.
 *
 * @param model the model to inspect
 */
private void processModel(Model model){
    Set<Field<?>> targetFields = getTargetFields();

    MiningSchema miningSchema = model.getMiningSchema();
    if(miningSchema == null || !miningSchema.hasMiningFields()){
        return;
    }

    // Insertion-ordered, so names keep mining-schema order
    Set<FieldName> names = new LinkedHashSet<>();

    for(MiningField miningField : miningSchema.getMiningFields()){

        switch(miningField.getUsageType()){
            case PREDICTED:
            case TARGET:
                names.add(miningField.getName());
                break;
            default:
                break;
        }
    }

    if(names.isEmpty()){
        return;
    }

    Set<Field<?>> modelFields = getFields(model);

    targetFields.addAll(FieldUtil.selectAll(modelFields, names));
}
/**
 * Based on the usage type, get the data-dictionary positions of the corresponding
 * mining fields.
 *
 * Positions are zero-based and follow the order of
 * {@code pmml.getDataDictionary().getDataFields()}. Only the mining schema of the
 * first model ({@code pmml.getModels().get(0)}) is consulted, and the result keeps
 * mining-schema order.
 *
 * @param pmml
 *            the pmml model
 * @param type
 *            the usage type to select (compared by identity)
 * @return data-dictionary positions of the mining fields having the given usage type
 * @throws NullPointerException
 *             if a selected mining field has no same-named field in the data
 *             dictionary
 */
public static int[] getDicFieldIDViaType(PMML pmml, FieldUsageType type) {
    // Map each data-dictionary field name to its position
    HashMap<String, Integer> positionByName = new HashMap<String, Integer>();
    int position = 0;
    for(DataField dataField: pmml.getDataDictionary().getDataFields()) {
        positionByName.put(dataField.getName().getValue(), position++);
    }

    List<Integer> positions = new ArrayList<Integer>();
    for(MiningField miningField: pmml.getModels().get(0).getMiningSchema().getMiningFields()) {
        if(miningField.getUsageType() != type) {
            continue;
        }
        String fieldName = miningField.getName().getValue();
        Integer fieldPosition = positionByName.get(fieldName);
        if(fieldPosition == null) {
            // Previously the null was stored and only surfaced later as a message-less
            // NullPointerException from Ints.toArray; fail here with the field name instead.
            throw new NullPointerException("Mining field '" + fieldName + "' is not declared in the data dictionary");
        }
        positions.add(fieldPosition);
    }
    return Ints.toArray(positions);
}
/**
 * Builds the input fields for all mining fields of the model that have the given usage type.
 *
 * Each matching mining field is paired with the field returned by {@code getDataField(name)};
 * when that lookup yields {@code null}, a new {@code VariableField} of the same name is used.
 *
 * @param usageType the usage type to select
 * @return an immutable list of the matching input fields, in mining-schema order
 */
protected List<InputField> createInputFields(MiningField.UsageType usageType){
    M model = getModel();

    MiningSchema miningSchema = model.getMiningSchema();

    List<InputField> result = new ArrayList<>();

    if(miningSchema.hasMiningFields()){

        for(MiningField miningField : miningSchema.getMiningFields()){
            MiningField.UsageType fieldUsageType = miningField.getUsageType();

            if(fieldUsageType.equals(usageType)){
                FieldName name = miningField.getName();

                Field<?> field = getDataField(name);
                if(field == null){
                    field = new VariableField(name);
                }

                result.add(new InputField(field, miningField));
            }
        }
    }

    return ImmutableList.copyOf(result);
}
/**
 * Computes the names of the active fields of a mining model.
 *
 * Starts from {@code DeepFieldResolverUtil.getActiveFields(this, miningModel)}, adds the
 * mining model's own fields whose names are ACTIVE in any segment model's mining schema,
 * expands derived fields via {@code expandDerivedFields}, and returns the resulting names.
 *
 * @param miningModel the mining model to process
 * @return the names of the collected active fields
 */
private Set<FieldName> processMiningModel(MiningModel miningModel){
    Set<Field<?>> activeFields = DeepFieldResolverUtil.getActiveFields(this, miningModel);

    // Names that are ACTIVE in at least one segment model
    Set<FieldName> segmentActiveNames = new HashSet<>();

    for(Segment segment : miningModel.getSegmentation().getSegments()){
        Model segmentModel = segment.getModel();

        // Segments without a model contribute nothing
        if(segmentModel != null){

            for(MiningField miningField : segmentModel.getMiningSchema().getMiningFields()){

                switch(miningField.getUsageType()){
                    case ACTIVE:
                        segmentActiveNames.add(miningField.getName());
                        break;
                    default:
                        break;
                }
            }
        }
    }

    Set<Field<?>> modelFields = getFields(miningModel);

    activeFields.addAll(FieldUtil.selectAll(modelFields, segmentActiveNames, true));

    expandDerivedFields(miningModel, activeFields);

    return FieldUtil.nameSet(activeFields);
}
/**
 * Prepares a raw value for the named field according to the field's usage type.
 *
 * ACTIVE, GROUP and ORDER fields are handled by {@code InputFieldUtil.prepareInputValue};
 * PREDICTED and TARGET fields by {@code InputFieldUtil.prepareResidualInputValue}.
 *
 * @param name the field name
 * @param value the raw value
 * @return the prepared field value
 * @throws MissingFieldException if the model evaluator has no data field of that name
 * @throws InvisibleFieldException if the model evaluator has no mining field of that name
 * @throws UnsupportedAttributeException if the mining field has any other usage type
 */
@Override
protected FieldValue prepare(FieldName name, Object value){
    ModelEvaluator<?> evaluator = getModelEvaluator();

    // The data field is checked first, then the mining field
    DataField dataField = evaluator.getDataField(name);
    if(dataField == null){
        throw new MissingFieldException(name);
    }

    MiningField miningField = evaluator.getMiningField(name);
    if(miningField == null){
        throw new InvisibleFieldException(name);
    }

    MiningField.UsageType usageType = miningField.getUsageType();
    switch(usageType){
        case ORDER:
        case GROUP:
        case ACTIVE:
            return InputFieldUtil.prepareInputValue(dataField, miningField, value);
        case TARGET:
        case PREDICTED:
            return InputFieldUtil.prepareResidualInputValue(dataField, miningField, value);
        default:
            throw new UnsupportedAttributeException(miningField, usageType);
    }
}