Tensorflow

Scio supports several methods of reading and writing Tensorflow records.

Reading

There are several ways to read Tensorflow files, depending on the input format and on whether you need to provide a schema.

tfRecordFile reads each record in a TFRecord file as an Array[Byte] element, tfRecordExampleFile (or tfRecordExampleFileWithSchema) reads Example instances, and tfRecordSequenceExampleFile (or tfRecordSequenceExampleFileWithSchema) reads SequenceExample instances:

import com.spotify.scio.ScioContext
import com.spotify.scio.values.SCollection
import com.spotify.scio.tensorflow._
import org.tensorflow.proto.example.{Example, SequenceExample}

val sc: ScioContext = ???
val recordBytes: SCollection[Array[Byte]] = sc.tfRecordFile("gs://input-record-path")
val examples: SCollection[Example] = sc.tfRecordExampleFile("gs://input-example-path")
val sequenceExamples: SCollection[SequenceExample] = sc.tfRecordSequenceExampleFile("gs://input-sequence-example-path")
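To make the one-element-per-record behavior of tfRecordFile concrete, the sketch below parses TFRecord framing in plain Scala. Each record is stored as a little-endian uint64 length, a masked CRC32C of those 8 length bytes, the payload, and a masked CRC32C of the payload. This is an illustration under those assumptions, not Scio's reader; maskedCrc and readRecords are hypothetical helpers, and java.util.zip.CRC32C requires Java 9 or later.

```scala
import java.nio.{ByteBuffer, ByteOrder}
import java.util.zip.CRC32C

// Masked CRC32C used by TFRecord: rotate the CRC right by 15 bits, then add
// a fixed constant (Int overflow matches unsigned 32-bit wraparound).
def maskedCrc(bytes: Array[Byte], offset: Int, length: Int): Int = {
  val crc = new CRC32C()
  crc.update(bytes, offset, length)
  Integer.rotateRight(crc.getValue.toInt, 15) + 0xa282ead8L.toInt
}

// Hypothetical helper: splits a TFRecord byte stream into one Array[Byte]
// per record, verifying both CRCs. Each record is framed as:
//   uint64 length | uint32 masked CRC of length | payload | uint32 masked CRC of payload
// Payloads larger than Int.MaxValue bytes are not handled in this sketch.
def readRecords(data: Array[Byte]): Vector[Array[Byte]] = {
  val buf = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN)
  val records = Vector.newBuilder[Array[Byte]]
  while (buf.hasRemaining) {
    val headerStart = buf.position()
    val length = buf.getLong()
    require(buf.getInt() == maskedCrc(data, headerStart, 8), "corrupt length header")
    val payload = new Array[Byte](length.toInt)
    buf.get(payload)
    require(buf.getInt() == maskedCrc(payload, 0, payload.length), "corrupt payload")
    records += payload
  }
  records.result()
}
```

In a pipeline, each payload returned here corresponds to one element of the SCollection[Array[Byte]] produced by tfRecordFile.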

Writing

As with reading, there are several ways to write Tensorflow files, depending on the type of the elements being written. All of the write methods are named saveAsTfRecordFile; which variant is available is determined by the element type of the SCollection.

import com.spotify.scio.values.SCollection
import com.spotify.scio.tensorflow._
import org.tensorflow.proto.example.{Example, SequenceExample}

val recordBytes: SCollection[Array[Byte]] = ???
val examples: SCollection[Example] = ???
val seqExamples: SCollection[Seq[Example]] = ???
val sequenceExamples: SCollection[SequenceExample] = ???

recordBytes.saveAsTfRecordFile("gs://output-record-path")
examples.saveAsTfRecordFile("gs://output-example-path")
seqExamples.saveAsTfRecordFile("gs://output-seq-example-path")
sequenceExamples.saveAsTfRecordFile("gs://output-sequence-example-path")
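For the Array[Byte] variant, each element ends up as one framed record in the output. A minimal sketch of that framing in plain Scala, assuming the TFRecord layout (length, masked CRC of the length, payload, masked CRC of the payload); writeRecords, longLE, intLE and maskedCrc are hypothetical helpers, not part of Scio, and sharding, file naming and compression are ignored:

```scala
import java.io.ByteArrayOutputStream
import java.nio.{ByteBuffer, ByteOrder}
import java.util.zip.CRC32C

// Masked CRC32C as used by TFRecord: rotate right 15 bits, then add a constant.
def maskedCrc(bytes: Array[Byte]): Int = {
  val crc = new CRC32C()
  crc.update(bytes, 0, bytes.length)
  Integer.rotateRight(crc.getValue.toInt, 15) + 0xa282ead8L.toInt
}

// Little-endian encodings for the length and CRC fields.
def longLE(v: Long): Array[Byte] =
  ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(v).array()
def intLE(v: Int): Array[Byte] =
  ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(v).array()

// Hypothetical helper: frames each payload as one TFRecord and concatenates them.
def writeRecords(payloads: Seq[Array[Byte]]): Array[Byte] = {
  val out = new ByteArrayOutputStream()
  payloads.foreach { payload =>
    val lengthBytes = longLE(payload.length.toLong)
    out.write(lengthBytes)
    out.write(intLE(maskedCrc(lengthBytes)))
    out.write(payload)
    out.write(intLE(maskedCrc(payload)))
  }
  out.toByteArray
}
```

Every record therefore carries 16 bytes of overhead (8 for the length, 4 for each CRC) on top of its payload.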

Prediction/inference

Scio supports performing inference with a saved Tensorflow model.

For an SCollection of an arbitrary user type, predictions can be made against the raw model with predict, which takes the names of the ops to fetch explicitly, or through the model’s SignatureDefs with predictWithSigDef:

import com.spotify.scio.values.SCollection
import com.spotify.scio.tensorflow._
import com.spotify.zoltar.tf.TensorFlowModel
import org.tensorflow._
import org.tensorflow.proto.example.Example

case class A()
case class B()

def toTensors(a: A): Map[String, Tensor] = ???
def fromTensors(a: A, tensors: Map[String, Tensor]): B = ???

val elements: SCollection[A] = ???
val options: TensorFlowModel.Options = ???
val fetchOps: Seq[String] = ???

val result: SCollection[B] = elements.predict[B]("gs://model-path", fetchOps, options)(toTensors)(fromTensors)
val b: SCollection[B] = elements.predictWithSigDef[B]("gs://model-path", options)(toTensors)(fromTensors)
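The curried shape of predict and predictWithSigDef separates three steps: turning one element into named inputs, running the model, and combining the original element with the named outputs. The sketch below mimics that flow locally in plain Scala so the contract can be tried without TensorFlow. predictLocally, FakeModel, Features, Scored and the use of Array[Float] in place of Tensor are all hypothetical stand-ins, not Scio API.

```scala
// Hypothetical stand-in for a loaded model: named inputs -> named outputs.
type FakeModel = Map[String, Array[Float]] => Map[String, Array[Float]]

// Mirrors the curried shape of predict: inFn builds the named inputs for one
// element, outFn receives that same element plus the model's named outputs.
def predictLocally[A, B](elements: Seq[A], model: FakeModel)(
    inFn: A => Map[String, Array[Float]]
)(outFn: (A, Map[String, Array[Float]]) => B): Seq[B] =
  elements.map(a => outFn(a, model(inFn(a))))

case class Features(id: String, x: Float)
case class Scored(id: String, score: Float)

// A toy "model" that doubles its "x" input and exposes it as "score".
val doubler: FakeModel = in => Map("score" -> in("x").map(_ * 2))

val scored = predictLocally(Seq(Features("a", 1.5f), Features("b", -2f)), doubler)(
  f => Map("x" -> Array(f.x))
)((f, out) => Scored(f.id, out("score").head))
```

Because outFn gets the original element back, identifiers such as id can be carried through to the output without being fed to the model.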

For an SCollection of Example elements, predictions can be made with predictTfExamples, whose output function receives each input Example together with the output tensors:

import com.spotify.scio.values.SCollection
import com.spotify.scio.tensorflow._
import com.spotify.zoltar.tf.TensorFlowModel
import org.tensorflow._
import org.tensorflow.proto.example.Example

val exampleElements: SCollection[Example] = ???
val options: TensorFlowModel.Options = ???
def toExample(in: Example, tensors: Map[String, Tensor]): Example = ???

val c: SCollection[Example] = exampleElements.predictTfExamples[Example]("gs://model-path", options) {
  case (a, tensors) => toExample(a, tensors)
}
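A common shape for a function like toExample is to copy the model outputs back into the example's features. The sketch below shows that idea with plain Scala maps standing in for the Example feature map and for the output tensors; the FeatureMap alias, the prediction_ key prefix and withPredictions are illustrative assumptions, not Scio or TensorFlow API.

```scala
// Hypothetical stand-ins: a feature map instead of an Example,
// and float vectors instead of output Tensors.
type FeatureMap = Map[String, Vector[Float]]

// Adds each model output as a new feature; the prefix keeps predictions
// apart from input features that share the same name.
def withPredictions(example: FeatureMap, outputs: Map[String, Vector[Float]]): FeatureMap =
  example ++ outputs.map { case (name, values) => s"prediction_$name" -> values }

val input: FeatureMap = Map("age" -> Vector(42f), "score" -> Vector(0.1f))
val enriched = withPredictions(input, Map("score" -> Vector(0.9f)))
```

Keeping the original features alongside the predictions lets the enriched examples be written back out with saveAsTfRecordFile.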