Using TensorFlow's string_split: when a method is unclear, reading its source code is the fastest way to understand it

First, find where the function is defined in the TensorFlow source.

Here is the source code:

@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True):  # pylint: disable=invalid-name
  """Split elements of `source` based on `delimiter` into a `SparseTensor`.

  Let N be the size of source (typically N will be the batch size). Split each
  element of `source` based on `delimiter` and return a `SparseTensor`
  containing the split tokens. Empty tokens are ignored.

  If `delimiter` is an empty string, each element of the `source` is split
  into individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
  treated as a set of delimiters with each considered a potential split point.

  For example:
  N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output
  will be

  st.indices = [0, 0;
                0, 1;
                1, 0;
                1, 1;
                1, 2]
  st.shape = [2, 3]
  st.values = ['hello', 'world', 'a', 'b', 'c']

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    delimiter: `0-D` string `Tensor`, the delimiter character, the string should
      be length 0 or 1.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    The first column of the indices corresponds to the row in `source` and the
    second column corresponds to the index of the split component in this row.
  """
  delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string)
  source = ops.convert_to_tensor(source, dtype=dtypes.string)

  indices, values, shape = gen_string_ops.string_split(
      source, delimiter=delimiter, skip_empty=skip_empty)
  indices.set_shape([None, 2])
  values.set_shape([None])
  shape.set_shape([2])
  return sparse_tensor.SparseTensor(indices, values, shape)
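To make the docstring's rules concrete, here is a minimal pure-Python sketch that mimics the documented behavior without calling TensorFlow. `string_split_sketch` is a hypothetical helper, and its logic is only my reading of the docstring: every byte of `delimiter` is treated as a separate split point (the docstring says this even though the Args section asks for length 0 or 1), an empty delimiter splits into single bytes, `skip_empty` drops empty tokens, and the dense shape is `[N, longest row]`. It is not TensorFlow's actual C++ kernel.

```python
from typing import List, Tuple


def string_split_sketch(
    source: List[str], delimiter: str = " ", skip_empty: bool = True
) -> Tuple[List[List[int]], List[bytes], List[int]]:
    """Hypothetical pure-Python model of string_split's (indices, values, shape)."""
    indices: List[List[int]] = []
    values: List[bytes] = []
    max_len = 0
    # Each byte of the delimiter is its own potential split point.
    delims = set(delimiter.encode("utf-8"))
    for row, text in enumerate(source):
        data = text.encode("utf-8")
        if not delims:
            # Empty delimiter: one token per byte (multibyte UTF-8 gets split too).
            tokens = [data[i:i + 1] for i in range(len(data))]
        else:
            tokens, current = [], bytearray()
            for b in data:
                if b in delims:
                    tokens.append(bytes(current))
                    current = bytearray()
                else:
                    current.append(b)
            tokens.append(bytes(current))
        if skip_empty:
            tokens = [t for t in tokens if t]
        for col, tok in enumerate(tokens):
            indices.append([row, col])  # [row in source, position in that row]
            values.append(tok)
        max_len = max(max_len, len(tokens))
    # Dense shape: number of input strings x length of the longest split row.
    return indices, values, [len(source), max_len]


idx, vals, shape = string_split_sketch(["hello world", "a b c"])
print(idx)    # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
print(vals)   # [b'hello', b'world', b'a', b'b', b'c']
print(shape)  # [2, 3]
```

For the docstring's two-string example, the sketch reproduces the same indices, values and shape `[2, 3]`. Values are kept as `bytes` because TensorFlow also hands string values back as bytes (the `b'a'` in the output shown further down).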

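The arguments need little explanation; the main thing to understand is the structure of the return value:

The op returns a rank-2 SparseTensor. In its indices, the first column is the row of `source` that a token came from, and the second column is the token's position within that row after splitting. For example, 'hello world' gives indices [[0,0],[0,1]], while running it on 'a b c' by itself gives the following (try it yourself):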

SparseTensorValue(indices=array([[0, 0],
       [0, 1],
       [0, 2]], dtype=int64), values=array([b'a', b'b', b'c'], dtype=object), dense_shape=array([1, 3], dtype=int64))
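Reading the result back is just a matter of using the first index column to pick the row and the second to pick the position. The sketch below does that in plain Python, with two hypothetical helpers: `sparse_to_rows` regroups the tokens per input string, and `sparse_to_dense` pads every row to `dense_shape[1]` with an empty string. The padding is roughly what TensorFlow's own sparse-to-dense conversion (e.g. `tf.sparse.to_dense` with a string default value) produces, but this code does not call TensorFlow.

```python
from typing import List


def sparse_to_rows(indices: List[List[int]], values: List[bytes],
                   shape: List[int]) -> List[List[bytes]]:
    """Group tokens back into one list per source string, ordered by column."""
    rows: List[List[bytes]] = [[] for _ in range(shape[0])]
    # Sort by (row, col) so the result does not depend on the order of indices.
    for (row, _col), value in sorted(zip((tuple(i) for i in indices), values)):
        rows[row].append(value)
    return rows


def sparse_to_dense(indices: List[List[int]], values: List[bytes],
                    shape: List[int], default: bytes = b"") -> List[List[bytes]]:
    """Scatter values into a shape[0] x shape[1] grid, padding empty cells."""
    dense = [[default] * shape[1] for _ in range(shape[0])]
    for (row, col), value in zip(indices, values):
        dense[row][col] = value
    return dense


# The docstring example: N = 2, 'hello world' and 'a b c'.
indices = [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
values = [b"hello", b"world", b"a", b"b", b"c"]
shape = [2, 3]
print(sparse_to_rows(indices, values, shape))   # [[b'hello', b'world'], [b'a', b'b', b'c']]
print(sparse_to_dense(indices, values, shape))  # [[b'hello', b'world', b''], [b'a', b'b', b'c']]
```

Because 'hello world' only has two tokens while the widest row has three, its dense row ends with one padding cell; this is why `dense_shape` is `[2, 3]` even though only five values are stored.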