TensorFlow
Scio supports several methods of reading and writing TensorFlow records.
Reading
Depending on your input format and whether you need to provide a schema, there are several ways to read TensorFlow files.
tfRecordFile reads TFRecord files into byte array elements in the pipeline; tfRecordExampleFile (or tfRecordExampleFileWithSchema) reads Example instances; and tfRecordSequenceExampleFile (or tfRecordSequenceExampleFileWithSchema) reads SequenceExample instances:
import com.spotify.scio.ScioContext
import com.spotify.scio.values.SCollection
import com.spotify.scio.tensorflow._
import org.tensorflow.proto.example.{Example, SequenceExample}
val sc: ScioContext = ???
val recordBytes: SCollection[Array[Byte]] = sc.tfRecordFile("gs://input-record-path")
val examples: SCollection[Example] = sc.tfRecordExampleFile("gs://input-example-path")
val sequenceExamples: SCollection[SequenceExample] = sc.tfRecordSequenceExampleFile("gs://input-sequence-example-path")
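When records are read with tfRecordFile, the elements are raw bytes. If those bytes happen to be serialized Example messages (an assumption about the data, not something the reader verifies), they can be decoded with the protobuf-generated parser. The following is a minimal sketch; the helper names decodeExample and decodeAll are hypothetical and not part of Scio.

```scala
import com.spotify.scio.values.SCollection
import org.tensorflow.proto.example.Example

// Hypothetical helper: decode one record payload, assuming it holds a serialized Example.
// Example.parseFrom is the standard protobuf-generated parser; it throws
// InvalidProtocolBufferException if the bytes are not a valid message.
def decodeExample(bytes: Array[Byte]): Example = Example.parseFrom(bytes)

// Hypothetical helper: apply the decoder to every element of a byte-array collection.
def decodeAll(recordBytes: SCollection[Array[Byte]]): SCollection[Example] =
  recordBytes.map(decodeExample)
```

If the input is known to contain Example records, reading with tfRecordExampleFile avoids this manual step.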
Writing
Similar to reading, there are multiple ways to write TensorFlow files, depending on the format of the elements to be output. Each of these write methods is called saveAsTfRecordFile, but only one variant of the method is available for a given element type.
- For SCollection[T] where T is a subclass of Example: saveAsTfRecordFile
- For SCollection[Seq[T]] where T is a subclass of Example: saveAsTfRecordFile
- For SCollection[T] where T is a subclass of SequenceExample: saveAsTfRecordFile
- For SCollection[Array[Byte]], where it is recommended that the bytes are a serialized Example: saveAsTfRecordFile
import com.spotify.scio.values.SCollection
import com.spotify.scio.tensorflow._
import org.tensorflow.proto.example.{Example, SequenceExample}
val recordBytes: SCollection[Array[Byte]] = ???
val examples: SCollection[Example] = ???
val seqExamples: SCollection[Seq[Example]] = ???
val sequenceExamples: SCollection[SequenceExample] = ???
recordBytes.saveAsTfRecordFile("gs://output-record-path")
examples.saveAsTfRecordFile("gs://output-example-path")
seqExamples.saveAsTfRecordFile("gs://output-seq-example-path")
sequenceExamples.saveAsTfRecordFile("gs://output-sequence-example-path")
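Elements usually need to be converted to Example before they can be written. The sketch below shows one way to do this with the protobuf builders from org.tensorflow.proto.example. The Track case class, its fields, the feature names, and the toExample helper are all hypothetical and exist only for illustration.

```scala
import com.google.protobuf.ByteString
import org.tensorflow.proto.example.{BytesList, Example, Feature, Features, FloatList, Int64List}

// Hypothetical record type used only for this illustration.
case class Track(id: String, plays: Long, score: Float)

// Hypothetical helper: map each field to a single-valued feature of the matching list type.
def toExample(t: Track): Example = {
  val id = Feature.newBuilder()
    .setBytesList(BytesList.newBuilder().addValue(ByteString.copyFromUtf8(t.id)))
    .build()
  val plays = Feature.newBuilder()
    .setInt64List(Int64List.newBuilder().addValue(t.plays))
    .build()
  val score = Feature.newBuilder()
    .setFloatList(FloatList.newBuilder().addValue(t.score))
    .build()

  val features = Features.newBuilder()
    .putFeature("id", id)
    .putFeature("plays", plays)
    .putFeature("score", score)

  Example.newBuilder().setFeatures(features).build()
}
```

Given an SCollection[Track] named tracks (also hypothetical), the result of tracks.map(toExample) is an SCollection[Example] and can be written with saveAsTfRecordFile as shown above.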
Prediction/inference
Scio supports performing inference on a saved TensorFlow model.
For an SCollection of an arbitrary user type, predictions can be made against the raw model via predict or using the model’s SignatureDefs with predictWithSigDef:
import com.spotify.scio.values.SCollection
import com.spotify.scio.tensorflow._
import com.spotify.zoltar.tf.TensorFlowModel
import org.tensorflow._
import org.tensorflow.proto.example.Example
case class A()
case class B()
def toTensors(a: A): Map[String, Tensor] = ???
def fromTensors(a: A, tensors: Map[String, Tensor]): B = ???
val elements: SCollection[A] = ???
val options: TensorFlowModel.Options = ???
val fetchOpts: Seq[String] = ???
val result: SCollection[B] = elements.predict[B]("gs://model-path", fetchOpts, options)(toTensors)(fromTensors)
val b: SCollection[B] = elements.predictWithSigDef[B]("gs://model-path", options)(toTensors)(fromTensors _)
For an SCollection of some subclass of Example, a prediction can be made via predictTfExamples:
import com.spotify.scio.values.SCollection
import com.spotify.scio.tensorflow._
import com.spotify.zoltar.tf.TensorFlowModel
import org.tensorflow._
import org.tensorflow.proto.example.Example
val exampleElements: SCollection[Example] = ???
val options: TensorFlowModel.Options = ???
def toExample(in: Example, tensors: Map[String, Tensor]): Example = ???
val c: SCollection[Example] = exampleElements.predictTfExamples[Example]("gs://model-path", options) {
  case (a, tensors) => toExample(a, tensors)
}