{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-10-25T01:24:08.933415Z", "iopub.status.busy": "2024-10-25T01:24:08.933170Z", "iopub.status.idle": "2024-10-25T01:24:08.936980Z", "shell.execute_reply": "2024-10-25T01:24:08.936372Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "drGgRRpWf2Qm" }, "source": [ "# Working with sparse tensors" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View on GitHub\n", " \n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "UIiXFIS4fj1m" }, "source": [ "When working with tensors that contain a lot of zero values, it is important to store them in a space- and time-efficient manner. Sparse tensors enable efficient storage and processing of tensors that contain a lot of zero values. Sparse tensors are used extensively in encoding schemes like [TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf) as part of data pre-processing in NLP applications and for pre-processing images with a lot of dark pixels in computer vision applications." ] }, { "cell_type": "markdown", "metadata": { "id": "A8XXQW3ENU5m" }, "source": [ "## Sparse tensors in TensorFlow\n", "\n", "TensorFlow represents sparse tensors through the `tf.sparse.SparseTensor` object. Currently, sparse tensors in TensorFlow are encoded using the coordinate list (COO) format. This encoding format is optimized for hyper-sparse matrices such as embeddings.\n", "\n", "The COO encoding for sparse tensors is comprised of:\n", "\n", " * `values`: A 1D tensor with shape `[N]` containing all nonzero values.\n", " * `indices`: A 2D tensor with shape `[N, rank]`, containing the indices of the nonzero values.\n", " * `dense_shape`: A 1D tensor with shape `[rank]`, specifying the shape of the tensor.\n", "\n", "A ***nonzero*** value in the context of a `tf.sparse.SparseTensor` is a value that's not explicitly encoded. It is possible to explicitly include zero values in the `values` of a COO sparse matrix, but these \"explicit zeros\" are generally not included when referring to nonzero values in a sparse tensor.\n", "\n", "Note: `tf.sparse.SparseTensor` does not require that indices/values be in any particular order, but several ops assume that they're in row-major order. Use `tf.sparse.reorder` to create a copy of the sparse tensor that is sorted in the canonical row-major order. " ] }, { "cell_type": "markdown", "metadata": { "id": "6Aq7ruwlyz79" }, "source": [ "## Creating a `tf.sparse.SparseTensor`\n", "\n", "Construct sparse tensors by directly specifying their `values`, `indices`, and `dense_shape`." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:08.940115Z", "iopub.status.busy": "2024-10-25T01:24:08.939878Z", "iopub.status.idle": "2024-10-25T01:24:11.308524Z", "shell.execute_reply": "2024-10-25T01:24:11.307659Z" }, "id": "SI2Mv3tihcmY" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-10-25 01:24:09.202320: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1729819449.223893 16549 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1729819449.230517 16549 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import tensorflow as tf" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:11.311898Z", "iopub.status.busy": "2024-10-25T01:24:11.311494Z", "iopub.status.idle": "2024-10-25T01:24:11.938402Z", "shell.execute_reply": "2024-10-25T01:24:11.937643Z" }, "id": "vqQKGva4zSCs" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "W0000 00:00:1729819451.911465 16549 gpu_device.cc:2344] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n", "Skipping registering GPU devices...\n" ] } ], "source": [ "st1 = tf.sparse.SparseTensor(indices=[[0, 3], [2, 4]],\n", " values=[10, 20],\n", " dense_shape=[3, 10])" ] }, { "cell_type": "markdown", "metadata": { "id": "l9eJeh31fWyr" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "-M3fMTFL0hXa" }, "source": [ "When you use the `print()` function to print a sparse tensor, it shows the contents of the three component tensors:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:11.941341Z", "iopub.status.busy": "2024-10-25T01:24:11.941085Z", "iopub.status.idle": "2024-10-25T01:24:11.945093Z", "shell.execute_reply": "2024-10-25T01:24:11.944491Z" }, "id": "3oHWtmsBMLAI" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SparseTensor(indices=tf.Tensor(\n", "[[0 3]\n", " [2 4]], shape=(2, 2), dtype=int64), values=tf.Tensor([10 20], shape=(2,), dtype=int32), dense_shape=tf.Tensor([ 3 10], shape=(2,), dtype=int64))\n" ] } ], "source": [ "print(st1)" ] }, { "cell_type": "markdown", "metadata": { "id": "qqePKJG6MNWk" }, "source": [ "It is easier to understand the contents of a sparse tensor if the nonzero `values` are aligned with their corresponding `indices`. Define a helper function to pretty-print sparse tensors such that each nonzero value is shown on its own line." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:11.947693Z", "iopub.status.busy": "2024-10-25T01:24:11.947446Z", "iopub.status.idle": "2024-10-25T01:24:11.951454Z", "shell.execute_reply": "2024-10-25T01:24:11.950829Z" }, "id": "R_xFYuOo1ZE_" }, "outputs": [], "source": [ "def pprint_sparse_tensor(st):\n", " s = \"\"" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:11.953878Z", "iopub.status.busy": "2024-10-25T01:24:11.953666Z", "iopub.status.idle": "2024-10-25T01:24:11.959992Z", "shell.execute_reply": "2024-10-25T01:24:11.959395Z" }, "id": "be4Dyiqt0fEH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "print(pprint_sparse_tensor(st1))" ] }, { "cell_type": "markdown", "metadata": { "id": "3FBt8qk_zmz5" }, "source": [ "You can also construct sparse tensors from dense tensors by using `tf.sparse.from_dense`, and convert them back to dense tensors by using `tf.sparse.to_dense`." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:11.962469Z", "iopub.status.busy": "2024-10-25T01:24:11.962237Z", "iopub.status.idle": "2024-10-25T01:24:11.972002Z", "shell.execute_reply": "2024-10-25T01:24:11.971444Z" }, "id": "cYwuCuNMf0Fu" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "st2 = tf.sparse.from_dense([[1, 0, 0, 8], [0, 0, 0, 0], [0, 0, 3, 0]])\n", "print(pprint_sparse_tensor(st2))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:11.974298Z", "iopub.status.busy": "2024-10-25T01:24:11.974083Z", "iopub.status.idle": "2024-10-25T01:24:11.979060Z", "shell.execute_reply": "2024-10-25T01:24:11.978525Z" }, "id": "eFVPrwNPzyZw" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[1 0 0 8]\n", " [0 0 0 0]\n", " [0 0 3 0]], shape=(3, 4), dtype=int32)\n" ] } ], "source": [ "st3 = tf.sparse.to_dense(st2)\n", "print(st3)" ] }, { "cell_type": "markdown", "metadata": { "id": "GeuvyL_Z0Mwh" }, "source": [ "## Manipulating sparse tensors\n", "\n", "Use the utilities in the `tf.sparse` package to manipulate sparse tensors. Ops like `tf.math.add` that you can use for arithmetic manipulation of dense tensors do not work with sparse tensors." ] }, { "cell_type": "markdown", "metadata": { "id": "LMYW4U4Qavvd" }, "source": [ "Add sparse tensors of the same shape by using `tf.sparse.add`. " ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:11.981529Z", "iopub.status.busy": "2024-10-25T01:24:11.981321Z", "iopub.status.idle": "2024-10-25T01:24:11.989158Z", "shell.execute_reply": "2024-10-25T01:24:11.988494Z" }, "id": "vJwuSQIjayiN" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "st_a = tf.sparse.SparseTensor(indices=[[0, 2], [3, 4]],\n", " values=[31, 2], \n", " dense_shape=[4, 10])\n", "\n", "st_b = tf.sparse.SparseTensor(indices=[[0, 2], [3, 0]],\n", " values=[56, 38],\n", " dense_shape=[4, 10])\n", "\n", "st_sum = tf.sparse.add(st_a, st_b)\n", "\n", "print(pprint_sparse_tensor(st_sum))" ] }, { "cell_type": "markdown", "metadata": { "id": "ls8_aQvnqZMj" }, "source": [ "Use `tf.sparse.sparse_dense_matmul` to multiply sparse tensors with dense matrices." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:11.991508Z", "iopub.status.busy": "2024-10-25T01:24:11.991284Z", "iopub.status.idle": "2024-10-25T01:24:11.996643Z", "shell.execute_reply": "2024-10-25T01:24:11.996055Z" }, "id": "S0tWRLiE04uL" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 78]\n", " [162]], shape=(2, 1), dtype=int32)\n" ] } ], "source": [ "st_c = tf.sparse.SparseTensor(indices=([0, 1], [1, 0], [1, 1]),\n", " values=[13, 15, 17],\n", " dense_shape=(2,2))\n", "\n", "mb = tf.constant([[4], [6]])\n", "product = tf.sparse.sparse_dense_matmul(st_c, mb)\n", "\n", "print(product)" ] }, { "cell_type": "markdown", "metadata": { "id": "9hxClYvfceZA" }, "source": [ "Put sparse tensors together by using `tf.sparse.concat` and take them apart by using `tf.sparse.slice`.\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:11.999425Z", "iopub.status.busy": "2024-10-25T01:24:11.999201Z", "iopub.status.idle": "2024-10-25T01:24:12.008345Z", "shell.execute_reply": "2024-10-25T01:24:12.007691Z" }, "id": "cp4NEW_5yLEY" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]\n", " [0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\n", " [0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0]\n", " [0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0]\n", " [0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0]\n", " [0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0]\n", " [0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\n", " [0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]], shape=(8, 17), dtype=int32)\n" ] } ], "source": [ "sparse_pattern_A = tf.sparse.SparseTensor(indices = [[2,4], [3,3], [3,4], [4,3], [4,4], [5,4]],\n", " values = [1,1,1,1,1,1],\n", " dense_shape = [8,5])\n", "sparse_pattern_B = tf.sparse.SparseTensor(indices = [[0,2], [1,1], [1,3], [2,0], [2,4], [2,5], [3,5], \n", " [4,5], [5,0], [5,4], [5,5], [6,1], [6,3], [7,2]],\n", " values = [1,1,1,1,1,1,1,1,1,1,1,1,1,1],\n", " dense_shape = [8,6])\n", "sparse_pattern_C = tf.sparse.SparseTensor(indices = [[3,0], [4,0]],\n", " values = [1,1],\n", " dense_shape = [8,6])\n", "\n", "sparse_patterns_list = [sparse_pattern_A, sparse_pattern_B, sparse_pattern_C]\n", "sparse_pattern = tf.sparse.concat(axis=1, sp_inputs=sparse_patterns_list)\n", "print(tf.sparse.to_dense(sparse_pattern))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:12.010516Z", "iopub.status.busy": "2024-10-25T01:24:12.010303Z", "iopub.status.idle": "2024-10-25T01:24:12.017574Z", "shell.execute_reply": "2024-10-25T01:24:12.016921Z" }, "id": "XmE87XVPWPmc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[0 0 0 0 0]\n", " [0 0 0 0 0]\n", " [0 0 0 0 1]\n", " [0 0 0 1 1]\n", " [0 0 0 1 1]\n", " [0 0 0 0 1]\n", " [0 0 0 0 0]\n", " [0 0 0 0 0]], shape=(8, 5), dtype=int32)\n", "tf.Tensor(\n", "[[0]\n", " [0]\n", " [1]\n", " [1]\n", " [1]\n", " [1]\n", " [0]\n", " [0]], shape=(8, 1), dtype=int32)\n", "tf.Tensor([], shape=(8, 0), dtype=int32)\n" ] } ], "source": [ "sparse_slice_A = tf.sparse.slice(sparse_pattern_A, start = [0,0], size = [8,5])\n", "sparse_slice_B = tf.sparse.slice(sparse_pattern_B, start = [0,5], size = [8,6])\n", "sparse_slice_C = tf.sparse.slice(sparse_pattern_C, start = [0,10], size = [8,6])\n", "print(tf.sparse.to_dense(sparse_slice_A))\n", "print(tf.sparse.to_dense(sparse_slice_B))\n", 
"print(tf.sparse.to_dense(sparse_slice_C))" ] }, { "cell_type": "markdown", "metadata": { "id": "37SOx7wB1eSX" }, "source": [ "If you're using TensorFlow 2.4 or above, use `tf.sparse.map_values` for elementwise operations on nonzero values in sparse tensors." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:12.019863Z", "iopub.status.busy": "2024-10-25T01:24:12.019653Z", "iopub.status.idle": "2024-10-25T01:24:12.024509Z", "shell.execute_reply": "2024-10-25T01:24:12.023935Z" }, "id": "daZaPkkA1d09" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 6 0 0 13]\n", " [ 0 0 0 0]\n", " [ 0 0 8 0]], shape=(3, 4), dtype=int32)\n" ] } ], "source": [ "st2_plus_5 = tf.sparse.map_values(tf.add, st2, 5)\n", "print(tf.sparse.to_dense(st2_plus_5))" ] }, { "cell_type": "markdown", "metadata": { "id": "3zkRcxeo2Elw" }, "source": [ "Note that only the nonzero values were modified – the zero values stay zero.\n", "\n", "Equivalently, you can follow the design pattern below for earlier versions of TensorFlow:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:12.026727Z", "iopub.status.busy": "2024-10-25T01:24:12.026520Z", "iopub.status.idle": "2024-10-25T01:24:12.031106Z", "shell.execute_reply": "2024-10-25T01:24:12.030548Z" }, "id": "bFSNOOqC0ySb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 6 0 0 13]\n", " [ 0 0 0 0]\n", " [ 0 0 8 0]], shape=(3, 4), dtype=int32)\n" ] } ], "source": [ "st2_plus_5 = tf.sparse.SparseTensor(\n", " st2.indices,\n", " st2.values + 5,\n", " st2.dense_shape)\n", "print(tf.sparse.to_dense(st2_plus_5))" ] }, { "cell_type": "markdown", "metadata": { "id": "GFhO2ZZ53ga1" }, "source": [ "## Using `tf.sparse.SparseTensor` with other TensorFlow APIs\n", "\n", "Sparse tensors work transparently with these TensorFlow APIs:\n", "\n", "* `tf.keras`\n", "* `tf.data`\n", "* `tf.Train.Example` protobuf\n", "* `tf.function`\n", "* `tf.while_loop`\n", "* `tf.cond`\n", "* `tf.identity`\n", "* `tf.cast`\n", "* `tf.print`\n", "* `tf.saved_model`\n", "* `tf.io.serialize_sparse`\n", "* `tf.io.serialize_many_sparse`\n", "* `tf.io.deserialize_many_sparse`\n", "* `tf.math.abs`\n", "* `tf.math.negative`\n", "* `tf.math.sign`\n", "* `tf.math.square`\n", "* `tf.math.sqrt`\n", "* `tf.math.erf`\n", "* `tf.math.tanh`\n", "* `tf.math.bessel_i0e`\n", "* `tf.math.bessel_i1e`\n", "\n", "Examples are shown below for a few of the above APIs." ] }, { "cell_type": "markdown", "metadata": { "id": "6uNUl7EgSYGC" }, "source": [ "### `tf.keras`\n", "\n", "A subset of the `tf.keras` API supports sparse tensors without expensive casting or conversion ops. The Keras API lets you pass sparse tensors as inputs to a Keras model. Set `sparse=True` when calling `tf.keras.Input` or `tf.keras.layers.InputLayer`. You can pass sparse tensors between Keras layers, and also have Keras models return them as outputs. If you use sparse tensors in `tf.keras.layers.Dense` layers in your model, they will output dense tensors.\n", "\n", "The example below shows you how to pass a sparse tensor as an input to a Keras model if you use only layers that support sparse inputs." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:12.033790Z", "iopub.status.busy": "2024-10-25T01:24:12.033584Z", "iopub.status.idle": "2024-10-25T01:24:12.293461Z", "shell.execute_reply": "2024-10-25T01:24:12.292725Z" }, "id": "E8za5DK8vfo7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 86ms/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 87ms/step\n" ] }, { "data": { "text/plain": [ "array([[ 1.8707037e-02, 7.7025330e-01, 2.2425324e-01, -1.9139588e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [-6.2435389e-02, -4.5783034e-01, 1.2970567e-03, -1.8046319e-01],\n", " [-8.0019468e-01, 9.0452707e-01, 2.1884918e-02, -1.3622781e+00]],\n", " dtype=float32)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.keras.Input(shape=(4,), sparse=True)\n", "y = tf.keras.layers.Dense(4)(x)\n", "model = tf.keras.Model(x, y)\n", "\n", "sparse_data = tf.sparse.SparseTensor(\n", " indices = [(0,0),(0,1),(0,2),\n", " (4,3),(5,0),(5,1)],\n", " values = [1,1,1,1,1,1],\n", " dense_shape = (6,4)\n", ")\n", "\n", "model(sparse_data)\n", "\n", "model.predict(sparse_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZtVYmr7dt0-x" }, "source": [ "### `tf.data`\n", "\n", "The `tf.data` API enables you to build complex input pipelines from simple, reusable pieces. Its core data structure is `tf.data.Dataset`, which represents a sequence of elements in which each element consists of one or more components.\n", "\n", "#### Building datasets with sparse tensors\n", "\n", "Build datasets from sparse tensors using the same methods that are used to build them from `tf.Tensor`s or NumPy arrays, such as `tf.data.Dataset.from_tensor_slices`. This op preserves the sparsity (or sparse nature) of the data." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:12.296374Z", "iopub.status.busy": "2024-10-25T01:24:12.296118Z", "iopub.status.idle": "2024-10-25T01:24:12.312776Z", "shell.execute_reply": "2024-10-25T01:24:12.312147Z" }, "id": "3y9tiwuZ5oTD" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n", "\n", "\n" ] } ], "source": [ "dataset = tf.data.Dataset.from_tensor_slices(sparse_data)\n", "for element in dataset: \n", " print(pprint_sparse_tensor(element))" ] }, { "cell_type": "markdown", "metadata": { "id": "hFaY5Org59qk" }, "source": [ "#### Batching and unbatching datasets with sparse tensors\n", "\n", "You can batch (combine consecutive elements into a single element) and unbatch datasets with sparse tensors using the `Dataset.batch` and `Dataset.unbatch` methods respectively." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:12.315625Z", "iopub.status.busy": "2024-10-25T01:24:12.315400Z", "iopub.status.idle": "2024-10-25T01:24:12.331072Z", "shell.execute_reply": "2024-10-25T01:24:12.330475Z" }, "id": "WkKE0VY66Ii2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n" ] } ], "source": [ "batched_dataset = dataset.batch(2)\n", "for element in batched_dataset:\n", " print (pprint_sparse_tensor(element))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:12.333379Z", "iopub.status.busy": "2024-10-25T01:24:12.333126Z", "iopub.status.idle": "2024-10-25T01:24:12.365564Z", "shell.execute_reply": "2024-10-25T01:24:12.364765Z" }, "id": "ikZzPxl56bx1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n", "\n", "\n" ] } ], "source": [ "unbatched_dataset = batched_dataset.unbatch()\n", "for element in unbatched_dataset:\n", " print (pprint_sparse_tensor(element))" ] }, { "cell_type": "markdown", "metadata": { "id": "6ywfpD_EIMd3" }, "source": [ "You can also use `tf.data.experimental.dense_to_sparse_batch` to batch dataset elements of varying shapes into sparse tensors. " ] }, { "cell_type": "markdown", "metadata": { "id": "oB8QKh7p6ltl" }, "source": [ "#### Transforming Datasets with sparse tensors\n", "\n", "Transform and create sparse tensors in Datasets using `Dataset.map`." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:12.368582Z", "iopub.status.busy": "2024-10-25T01:24:12.367960Z", "iopub.status.idle": "2024-10-25T01:24:12.405999Z", "shell.execute_reply": "2024-10-25T01:24:12.405333Z" }, "id": "E5lhicwef7Ah" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n", "\n", "\n" ] } ], "source": [ "transform_dataset = dataset.map(lambda x: x*2)\n", "for i in transform_dataset:\n", " print(pprint_sparse_tensor(i))" ] }, { "cell_type": "markdown", "metadata": { "id": "DBfQvIVutp65" }, "source": [ "### tf.train.Example\n", "\n", "`tf.train.Example` is a standard protobuf encoding for TensorFlow data. When using sparse tensors with `tf.train.Example`, you can:\n", "\n", "* Read variable-length data into a `tf.sparse.SparseTensor` using `tf.io.VarLenFeature`. However, you should consider using `tf.io.RaggedFeature` instead.\n", "\n", "* Read arbitrary sparse data into a `tf.sparse.SparseTensor` using `tf.io.SparseFeature`, which uses three separate feature keys to store the `indices`, `values`, and `dense_shape`." ] }, { "cell_type": "markdown", "metadata": { "id": "Pir2Xt3nSe-4" }, "source": [ "### `tf.function`\n", "\n", "The `tf.function` decorator precomputes TensorFlow graphs for Python functions, which can substantially improve the performance of your TensorFlow code. Sparse tensors work transparently with both `tf.function` and [concrete functions](https://www.tensorflow.org/guide/function#obtaining_concrete_functions)." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:12.408603Z", "iopub.status.busy": "2024-10-25T01:24:12.408350Z", "iopub.status.idle": "2024-10-25T01:24:12.448348Z", "shell.execute_reply": "2024-10-25T01:24:12.447707Z" }, "id": "6jXDueTOSeYO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[225 0 0]\n", " [ 0 0 0]\n", " [ 0 0 625]], shape=(3, 3), dtype=int32)\n" ] } ], "source": [ "@tf.function\n", "def f(x,y):\n", " return tf.sparse.sparse_dense_matmul(x,y)\n", "\n", "a = tf.sparse.SparseTensor(indices=[[0, 3], [2, 4]],\n", " values=[15, 25],\n", " dense_shape=[3, 10])\n", "\n", "b = tf.sparse.to_dense(tf.sparse.transpose(a))\n", "\n", "c = f(a,b)\n", "\n", "print(c)" ] }, { "cell_type": "markdown", "metadata": { "id": "YPe5uC_X7XjZ" }, "source": [ "## Distinguishing missing values from zero values\n", "\n", "Most ops on `tf.sparse.SparseTensor`s treat missing values and explicit zero values identically. This is by design — a `tf.sparse.SparseTensor` is supposed to act just like a dense tensor.\n", "\n", "However, there are a few cases where it can be useful to distinguish zero values from missing values. In particular, this allows for one way to encode missing/unknown data in your training data. For example, consider a use case where you have a tensor of scores (that can have any floating point value from -Inf to +Inf), with some missing scores. You can encode this tensor using a sparse tensor where the explicit zeros are known zero scores but the implicit zero values actually represent missing data and not zero. \n", "\n", "Note: This is generally not the intended usage of `tf.sparse.SparseTensor`s; and you might want to also consider other techniques for encoding this such as for example using a separate mask tensor that identifies the locations of known/unknown values. However, exercise caution while using this approach, since most sparse operations will treat explicit and implicit zero values identically." ] }, { "cell_type": "markdown", "metadata": { "id": "tZ17F9e3ZJDS" }, "source": [ "Note that some ops like `tf.sparse.reduce_max` do not treat missing values as if they were zero. For example, when you run the code block below, the expected output is `0`. However, because of this exception, the output is `-3`." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:12.451103Z", "iopub.status.busy": "2024-10-25T01:24:12.450854Z", "iopub.status.idle": "2024-10-25T01:24:12.456247Z", "shell.execute_reply": "2024-10-25T01:24:12.455693Z" }, "id": "kcNBVVtBZav_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(-3, shape=(), dtype=int32)\n" ] } ], "source": [ "print(tf.sparse.reduce_max(tf.sparse.from_dense([-5, 0, -3])))" ] }, { "cell_type": "markdown", "metadata": { "id": "zhzWLW-bMfI5" }, "source": [ "In contrast, when you apply `tf.math.reduce_max` to a dense tensor, the output is 0 as expected." 
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-10-25T01:24:12.458811Z", "iopub.status.busy": "2024-10-25T01:24:12.458580Z", "iopub.status.idle": "2024-10-25T01:24:12.462627Z", "shell.execute_reply": "2024-10-25T01:24:12.462048Z" }, "id": "7Xy-g3VDNK9d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(0, shape=(), dtype=int32)\n" ] } ], "source": [ "print(tf.math.reduce_max([-5, 0, -3]))" ] }, { "cell_type": "markdown", "metadata": { "id": "uK3U8l0kNL37" }, "source": [ "## Further reading and resources\n", "\n", "* Refer to the [tensor guide](https://www.tensorflow.org/guide/tensor) to learn about tensors.\n", "* Read the [ragged tensor guide](https://www.tensorflow.org/guide/ragged_tensor) to learn how to work with ragged tensors, a type of tensor that lets you work with non-uniform data.\n", "* Check out this object detection model in the [TensorFlow Model Garden](https://github.com/tensorflow/models) that uses sparse tensors in a [`tf.Example` data decoder](https://github.com/tensorflow/models/blob/9139a7b90112562aec1d7e328593681bd410e1e7/research/object_detection/data_decoders/tf_example_decoder.py).\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "sparse_tensor.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.20" } }, "nbformat": 4, "nbformat_minor": 0 }