/**
 * Builds a {@link SplittableParDo} from an {@link AppliedPTransform} that is compatible with
 * {@link ParDo}.
 *
 * <p>The applied transform may have been deserialized and therefore need not be an actual {@link
 * ParDo} instance, so every field is read through {@link ParDoTranslation} rather than from the
 * transform object directly.
 *
 * @param parDo the applied transform to convert; must not be null
 * @return a {@link SplittableParDo} carrying the DoFn, side inputs, output tags and output coders
 *     of {@code parDo}
 * @throws RuntimeException wrapping any {@link IOException} raised while extracting the fields
 */
@SuppressWarnings({"unchecked", "rawtypes"})
public static <InputT, OutputT> SplittableParDo<InputT, OutputT, ?> forAppliedParDo(
    AppliedPTransform<PCollection<InputT>, PCollectionTuple, ?> parDo) {
  checkArgument(parDo != null, "parDo must not be null");

  // Record the coder of every output, keyed by its tag.
  Map<TupleTag<?>, Coder<?>> codersByOutputTag = Maps.newHashMap();
  for (Map.Entry<TupleTag<?>, PValue> output : parDo.getOutputs().entrySet()) {
    PCollection outputCollection = (PCollection) output.getValue();
    codersByOutputTag.put(output.getKey(), outputCollection.getCoder());
  }

  try {
    return new SplittableParDo(
        ParDoTranslation.getDoFn(parDo),
        ParDoTranslation.getSideInputs(parDo),
        ParDoTranslation.getMainOutputTag(parDo),
        ParDoTranslation.getAdditionalOutputTags(parDo),
        codersByOutputTag);
  } catch (IOException translationFailure) {
    throw new RuntimeException(translationFailure);
  }
}
/**
 * Translates every state declaration of {@code signature} into its {@link RunnerApi.StateSpec}
 * proto form.
 *
 * @param components the components passed to each state spec translation
 * @return a new map from state id to the translated state spec
 * @throws IOException if translating a state spec fails
 */
@Override
public Map<String, RunnerApi.StateSpec> translateStateSpecs(SdkComponents components)
    throws IOException {
  Map<String, RunnerApi.StateSpec> translatedById = new HashMap<>();
  for (Map.Entry<String, DoFnSignature.StateDeclaration> declaration :
      signature.stateDeclarations().entrySet()) {
    String stateId = declaration.getKey();
    translatedById.put(
        stateId,
        ParDoTranslation.translateStateSpec(
            getStateSpecOrThrow(declaration.getValue(), doFn), components));
  }
  return translatedById;
}
/**
 * Matches applications whose transform carries the ParDo URN and which use state or timers.
 *
 * @param application the applied transform to inspect
 * @return {@code true} only for a ParDo-URN transform for which {@link
 *     ParDoTranslation#usesStateOrTimers} returns {@code true}
 * @throws RuntimeException if the ParDo cannot be translated while checking for state or timers
 */
@Override
public boolean matches(AppliedPTransform<?, ?, ?> application) {
  String urn = PTransformTranslation.urnForTransformOrNull(application.getTransform());
  // Anything that is not a ParDo (including transforms with no URN) never matches.
  if (!PTransformTranslation.PAR_DO_TRANSFORM_URN.equals(urn)) {
    return false;
  }

  try {
    return ParDoTranslation.usesStateOrTimers(application);
  } catch (IOException cause) {
    String message =
        String.format(
            "Transform with URN %s could not be translated",
            PTransformTranslation.PAR_DO_TRANSFORM_URN);
    throw new RuntimeException(message, cause);
  }
}
/**
 * Returns the main output tag of the given applied ParDo.
 *
 * <p>A live {@link ParDo.MultiOutput} is asked directly; any other transform is resolved through
 * its {@code ParDoPayload}.
 *
 * @param application the applied transform to read the tag from
 * @return the main output tag
 * @throws IOException if the payload-based lookup fails
 */
public static TupleTag<?> getMainOutputTag(AppliedPTransform<?, ?, ?> application)
    throws IOException {
  PTransform<?, ?> transform = application.getTransform();
  if (!(transform instanceof ParDo.MultiOutput)) {
    return getMainOutputTag(getParDoPayload(application));
  }
  ParDo.MultiOutput<?, ?> multiOutput = (ParDo.MultiOutput<?, ?>) transform;
  return multiOutput.getMainOutputTag();
}
/**
 * Returns the {@link DoFn} of the given applied ParDo.
 *
 * <p>A live {@link ParDo.MultiOutput} is asked directly; any other transform is resolved through
 * its {@code ParDoPayload}.
 *
 * @param application the applied transform to read the function from
 * @return the transform's {@link DoFn}
 * @throws IOException if the payload-based lookup fails
 */
public static DoFn<?, ?> getDoFn(AppliedPTransform<?, ?, ?> application) throws IOException {
  PTransform<?, ?> transform = application.getTransform();
  if (!(transform instanceof ParDo.MultiOutput)) {
    return getDoFn(getParDoPayload(application));
  }
  ParDo.MultiOutput<?, ?> multiOutput = (ParDo.MultiOutput<?, ?>) transform;
  return multiOutput.getFn();
}
DoFn<InputT, OutputT> doFn; try { doFn = (DoFn<InputT, OutputT>) ParDoTranslation.getDoFn(context.getCurrentTransform()); } catch (IOException e) { throw new RuntimeException(e); mainOutputTag = ParDoTranslation.getMainOutputTag(context.getCurrentTransform()); } catch (IOException e) { throw new RuntimeException(e); sideInputs = ParDoTranslation.getSideInputs(context.getCurrentTransform()); } catch (IOException e) { throw new RuntimeException(e); boolean usesStateOrTimers; try { usesStateOrTimers = ParDoTranslation.usesStateOrTimers(context.getCurrentTransform()); } catch (IOException e) { throw new RuntimeException(e);
try { final AppliedPTransform pTransform = beamNode.toAppliedPTransform(ctx.getPipeline()); final DoFn doFn = ParDoTranslation.getDoFn(pTransform); final TupleTag mainOutputTag = ParDoTranslation.getMainOutputTag(pTransform); final TupleTagList additionalOutputTags = ParDoTranslation.getAdditionalOutputTags(pTransform);
/**
 * Translates {@code parDo} to a {@link ParDoPayload} and checks that the DoFn, the main output
 * tag and every side input can be read back from the payload.
 */
@Test
public void testToProto() throws Exception {
  SdkComponents sdkComponents = SdkComponents.create();
  sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java"));

  ParDoPayload translated = ParDoTranslation.translateParDo(parDo, p, sdkComponents);

  assertThat(ParDoTranslation.getDoFn(translated), equalTo(parDo.getFn()));
  assertThat(ParDoTranslation.getMainOutputTag(translated), equalTo(parDo.getMainOutputTag()));
  for (PCollectionView<?> sideInput : parDo.getSideInputs()) {
    String sideInputId = sideInput.getTagInternal().getId();
    // No assertion needed: the OrThrow accessor is expected to fail if the id is absent.
    translated.getSideInputsOrThrow(sideInputId);
  }
}
/**
 * Returns the additional (non-main) output tags of the given applied ParDo.
 *
 * <p>A live {@link ParDo.MultiOutput} is asked directly. Otherwise the application is translated
 * to its proto form and every output key other than the main output tag's id becomes a tag.
 *
 * @param application the applied transform to read the tags from
 * @return the additional output tags
 * @throws IOException if translating the application or parsing its payload fails
 */
public static TupleTagList getAdditionalOutputTags(AppliedPTransform<?, ?, ?> application)
    throws IOException {
  PTransform<?, ?> transform = application.getTransform();
  if (transform instanceof ParDo.MultiOutput) {
    ParDo.MultiOutput<?, ?> multiOutput = (ParDo.MultiOutput<?, ?>) transform;
    return multiOutput.getAdditionalOutputTags();
  }

  SdkComponents sdkComponents = SdkComponents.create(application.getPipeline().getOptions());
  RunnerApi.PTransform proto = PTransformTranslation.toProto(application, sdkComponents);
  ParDoPayload payload = ParDoPayload.parseFrom(proto.getSpec().getPayload());
  String mainOutputId = getMainOutputTag(payload).getId();

  // Every output key except the main output's id names an additional output.
  ArrayList<TupleTag<?>> additionalTags = new ArrayList<>();
  for (String outputId :
      Sets.difference(proto.getOutputsMap().keySet(), Collections.singleton(mainOutputId))) {
    additionalTags.add(new TupleTag<>(outputId));
  }
  return TupleTagList.of(additionalTags);
}
@Test public void testStateSpecToFromProto() throws Exception { // Encode SdkComponents sdkComponents = SdkComponents.create(); sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java")); RunnerApi.StateSpec stateSpecProto = ParDoTranslation.translateStateSpec(stateSpec, sdkComponents); // Decode RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(sdkComponents.toComponents()); StateSpec<?> deserializedStateSpec = ParDoTranslation.fromProto(stateSpecProto, rehydratedComponents); assertThat(stateSpec, equalTo(deserializedStateSpec)); } }
/**
 * Translates {@code fn}, paired with the main output tag of {@code pke}, by delegating to {@link
 * ParDoTranslation#translateDoFn}.
 *
 * @param newComponents the components passed to the translation
 * @return the translated function spec
 */
@Override
public SdkFunctionSpec translateDoFn(SdkComponents newComponents) {
  SdkFunctionSpec translatedFn =
      ParDoTranslation.translateDoFn(fn, pke.getMainOutputTag(), newComponents);
  return translatedFn;
}
/**
 * Looks up the main input {@link RunnerApi.PCollection} of a ParDo ptransform.
 *
 * @param ptransform the ptransform proto; its spec URN must equal {@code PAR_DO_TRANSFORM_URN}
 * @param components the components in which the main input's PCollection id is resolved
 * @return the PCollection registered under the id of the ptransform's main input
 * @throws IOException if the main input name cannot be determined
 */
public static RunnerApi.PCollection getMainInput(
    RunnerApi.PTransform ptransform, Components components) throws IOException {
  String urn = ptransform.getSpec().getUrn();
  checkArgument(urn.equals(PAR_DO_TRANSFORM_URN), "Unexpected payload type %s", urn);

  String mainInputId = ptransform.getInputsOrThrow(getMainInputName(ptransform));
  return components.getPcollectionsOrThrow(mainInputId);
}
/**
 * Returns the main output tag obtained by decoding the payload's DoFn via {@code
 * doFnAndMainOutputTagFromProto}.
 *
 * @param payload the ParDo payload to read from
 * @return the main output tag
 * @throws InvalidProtocolBufferException if the DoFn proto cannot be decoded
 */
public static TupleTag<?> getMainOutputTag(ParDoPayload payload)
    throws InvalidProtocolBufferException {
  TupleTag<?> mainOutputTag = doFnAndMainOutputTagFromProto(payload.getDoFn()).getMainOutputTag();
  return mainOutputTag;
}
ParDoSingle.class.getSimpleName())); return ParDoTranslation.payloadForParDoLike( new ParDoTranslation.ParDoLike() { @Override
ParDoTranslation.getMainInput(protoTransform, components), equalTo(components.getPcollectionsOrThrow(mainInputId))); assertThat(ParDoTranslation.getMainInputName(protoTransform), equalTo("mainInputName"));
DoFn<InputT, OutputT> doFn; try { doFn = (DoFn<InputT, OutputT>) ParDoTranslation.getDoFn(context.getCurrentTransform()); } catch (IOException e) { throw new RuntimeException(e); mainOutputTag = ParDoTranslation.getMainOutputTag(context.getCurrentTransform()); } catch (IOException e) { throw new RuntimeException(e); sideInputs = ParDoTranslation.getSideInputs(context.getCurrentTransform()); } catch (IOException e) { throw new RuntimeException(e); boolean usesStateOrTimers; try { usesStateOrTimers = ParDoTranslation.usesStateOrTimers(context.getCurrentTransform()); } catch (IOException e) { throw new RuntimeException(e);
/**
 * Translates the function of {@code parDo}, paired with its main output tag, by delegating to
 * {@link ParDoTranslation#translateDoFn}.
 *
 * @param newComponents the components passed to the translation
 * @return the translated function spec
 */
@Override
public RunnerApi.SdkFunctionSpec translateDoFn(SdkComponents newComponents) {
  RunnerApi.SdkFunctionSpec translatedFn =
      ParDoTranslation.translateDoFn(parDo.getFn(), parDo.getMainOutputTag(), newComponents);
  return translatedFn;
}
/**
 * Returns the name of the main input of the given ParDo ptransform.
 *
 * @param ptransform the ptransform proto (or builder); its spec URN must equal {@code
 *     PAR_DO_TRANSFORM_URN}
 * @return the main input name, as resolved from the ptransform and its parsed {@link
 *     ParDoPayload}
 * @throws IOException if the payload cannot be parsed
 */
public static String getMainInputName(RunnerApi.PTransformOrBuilder ptransform)
    throws IOException {
  String urn = ptransform.getSpec().getUrn();
  checkArgument(urn.equals(PAR_DO_TRANSFORM_URN), "Unexpected payload type %s", urn);

  return getMainInputName(
      ptransform, ParDoPayload.parseFrom(ptransform.getSpec().getPayload()));
}
/**
 * Returns the {@link DoFn} obtained by decoding the payload's DoFn via {@code
 * doFnAndMainOutputTagFromProto}.
 *
 * @param payload the ParDo payload to read from
 * @return the decoded {@link DoFn}
 * @throws InvalidProtocolBufferException if the DoFn proto cannot be decoded
 */
public static DoFn<?, ?> getDoFn(ParDoPayload payload) throws InvalidProtocolBufferException {
  DoFn<?, ?> doFn = doFnAndMainOutputTagFromProto(payload.getDoFn()).getDoFn();
  return doFn;
}
return payloadForParDoLike( new ParDoLike() { @Override