shapes

package shapes, version v0.2.0

Warning

This package is not in the latest version of its module.

Published: Dec 4, 2025 License: Apache-2.0 Imports: 12 Imported by: 1

Overview

Package shapes defines Shape and DType and associated tools.

Shape represents the shape (rank, dimensions, and DType) of either a Tensor or the expected shape of a node in a computation Graph. DType indicates the type of a Tensor's unit element.

Shape and DType are used both by the concrete tensor values (see tensor package) and when working on the computation graph (see graph package).

Go float16 support (float16 is commonly used on Nvidia GPUs) uses the github.com/x448/float16 implementation, and bfloat16 uses a simple implementation in github.com/gomlx/gopjrt/dtypes/bfloat16.

## Glossary

  • Rank: number of axes (dimensions) of a Tensor.
  • Axis: the index of a dimension of a multidimensional Tensor. It is sometimes used interchangeably with Dimension, but here we refer to a dimension index as an "axis" (plural "axes") and to its size as its "dimension".
  • Dimension: the size of a multidimensional Tensor along one of its axes. See the example below.
  • DType: the data type of the unit element in a tensor. It is an enumeration defined in github.com/gomlx/gopjrt/dtypes.
  • Scalar: a shape with no axes (or dimensions), holding a single value of the associated DType.

Example: The multidimensional array `[][]int32{{0, 1, 2}, {3, 4, 5}}`, if converted to a Tensor, would have shape `(Int32)[2 3]`. We say it has rank 2 (so 2 axes), axis 0 has dimension 2, and axis 1 has dimension 3. This shape could be created with `shapes.Make(dtypes.Int32, 2, 3)`.
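To make the glossary terms concrete, here is a minimal, self-contained sketch. It does not use the shapes package: `toyShape` is a hypothetical stand-in for Shape that only mirrors its `DType` and `Dimensions` fields, with rank and size computed as the glossary describes.

```go
package main

import "fmt"

// toyShape is a hypothetical stand-in for shapes.Shape, used only to
// illustrate the glossary; it is not the package's implementation.
type toyShape struct {
	DType      string // e.g. "Int32"
	Dimensions []int  // one dimension per axis
}

// Rank is the number of axes.
func (s toyShape) Rank() int { return len(s.Dimensions) }

// Size is the product of all dimensions; it is 1 for a scalar.
func (s toyShape) Size() int {
	size := 1
	for _, d := range s.Dimensions {
		size *= d
	}
	return size
}

func main() {
	// Shape of [][]int32{{0, 1, 2}, {3, 4, 5}}.
	s := toyShape{DType: "Int32", Dimensions: []int{2, 3}}
	fmt.Println("rank:", s.Rank(), "size:", s.Size()) // rank: 2 size: 6
	for axis, dim := range s.Dimensions {
		fmt.Printf("axis %d has dimension %d\n", axis, dim)
	}

	// A scalar has no axes: rank 0, but still one element.
	scalar := toyShape{DType: "Int32"}
	fmt.Println("rank:", scalar.Rank(), "size:", scalar.Size()) // rank: 0 size: 1
}
```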

## Asserts

When coding ML models, one delicate part is keeping track of the shapes of graph nodes: unfortunately, there is no compile-time checking of shapes, so validation only happens at runtime. To help with this, and also to serve as code documentation, this package provides two variations of _assert_ functionality. For example:

AssertRank and AssertDims check that the rank and dimensions of the given object (that has a `Shape` method) match, otherwise it panics. The `-1` means the dimension is unchecked (it can be anything).

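```go
// Assumes gomlx's graph package is dot-imported (for Node), along with its
// context, layers and shapes packages.
func modelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
	_ = spec // Not needed here, we know the dataset.
	x := inputs[0]
	shapes.AssertRank(x, 2) // Panics unless x is [batchSize, features].
	batchSize := x.Shape().Dimensions[0]
	logits := layers.Dense(ctx, x, /* useBias= */ true, /* outputDim= */ 1)
	shapes.AssertDims(logits, batchSize, -1) // -1: last axis is unchecked.
	return []*Node{logits}
}
```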

If you don't want to panic, but instead want the error reported through the `graph.Graph`, use the `Node.AssertDims()` method instead, e.g. `logits.AssertDims(batchSize, -1)`.

Constants

const UncheckedAxis = int(-1)

UncheckedAxis can be used in CheckDims or AssertDims functions for an axis whose dimension doesn't matter.
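The matching rule behind `-1` / UncheckedAxis can be sketched as follows. `checkDims` here is a hypothetical re-implementation over a plain `[]int`, written only to show the rule the documentation describes (the rank must match exactly, and each axis must match unless it is unchecked); it is not the package's code.

```go
package main

import "fmt"

// uncheckedAxis plays the role of shapes.UncheckedAxis.
const uncheckedAxis = -1

// checkDims requires the rank to match exactly and each dimension to match,
// except on axes where the expected value is uncheckedAxis.
func checkDims(actual []int, expected ...int) error {
	if len(actual) != len(expected) {
		return fmt.Errorf("rank mismatch: got %d, want %d", len(actual), len(expected))
	}
	for axis, want := range expected {
		if want == uncheckedAxis {
			continue // any dimension is accepted on this axis
		}
		if actual[axis] != want {
			return fmt.Errorf("axis %d: got dimension %d, want %d", axis, actual[axis], want)
		}
	}
	return nil
}

func main() {
	fmt.Println(checkDims([]int{32, 10}, 32, uncheckedAxis)) // <nil>
	fmt.Println(checkDims([]int{32, 10}, 16, uncheckedAxis)) // axis 0: got dimension 32, want 16
	fmt.Println(checkDims([]int{32, 10}, 32))                // rank mismatch: got 2, want 1
}
```

An AssertDims-style wrapper would simply panic with the returned error instead of returning it.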

Functions

func Assert

func Assert(shaped HasShape, dtype dtypes.DType, dimensions ...int)

Assert checks that the shape has the given dtype, dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It panics if it doesn't match.

func AssertDims

func AssertDims(shaped HasShape, dimensions ...int)

AssertDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It panics if it doesn't match.

See usage example in package shapes documentation.

func AssertRank

func AssertRank(shaped HasShape, rank int)

AssertRank checks that the shape has the given rank.

It panics if it doesn't match.

See usage example in package shapes documentation.

func AssertScalar

func AssertScalar(shaped HasShape)

AssertScalar checks that the shape is a scalar.

It panics if it doesn't match.

See usage example in package shapes documentation.

func CastAsDType

func CastAsDType(value any, dtype dtypes.DType) any

CastAsDType casts a numeric value to the Go type corresponding to the given DType. If the value is a slice, it is converted to a newly allocated slice of the given DType.

It doesn't work for complex numbers.
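For slices, the behavior described above (a newly allocated slice whose elements are converted one by one) can be sketched with Go generics. `castSlice` and the `number` constraint are hypothetical; the real CastAsDType selects the target type from a DType value at runtime rather than from a type parameter.

```go
package main

import "fmt"

// number is a hypothetical constraint covering Go's real (non-complex) numeric types.
type number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~float32 | ~float64
}

// castSlice allocates a new []To and converts each element of values into it.
// The input slice is left untouched.
func castSlice[To, From number](values []From) []To {
	out := make([]To, len(values))
	for i, v := range values {
		out[i] = To(v)
	}
	return out
}

func main() {
	floats := castSlice[float32]([]int{1, 2, 3})
	fmt.Printf("%T %v\n", floats, floats) // []float32 [1 2 3]
}
```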

func CheckDims

func CheckDims(shaped HasShape, dimensions ...int) error

CheckDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It returns an error if the rank is different or if any of the dimensions don't match.

func CheckRank

func CheckRank(shaped HasShape, rank int) error

CheckRank checks that the shape has the given rank.

It returns an error if the rank is different.

func CheckScalar

func CheckScalar(shaped HasShape) error

CheckScalar checks that the shape is a scalar.

It returns an error if the shape is not a scalar.

func ConvertTo

func ConvertTo[T dtypes.NumberNotComplex](value any) T

ConvertTo converts any scalar (typically returned by `tensor.Local.Value()`) of the supported dtypes to `T`. It returns 0 if value is not a scalar or not a supported number (e.g., bool). It doesn't work if T (the output type) is a complex number. If value is a complex number, it is converted by taking its real part and discarding the imaginary part.
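The conversion rules above can be sketched with a type switch. `convertTo` is a hypothetical, reduced version that handles only a few input types; it is not the package's implementation, and its truncation of floats to integers is simply Go's own conversion behavior.

```go
package main

import "fmt"

// convertTo sketches ConvertTo's documented behavior for a few input types:
// real numbers are converted to T, complex inputs keep only their real part,
// and anything else (e.g. bool or a slice) yields 0.
func convertTo[T int | int32 | int64 | float32 | float64](value any) T {
	switch v := value.(type) {
	case int:
		return T(v)
	case int32:
		return T(v)
	case int64:
		return T(v)
	case float32:
		return T(v)
	case float64:
		return T(v) // float-to-int conversion truncates toward zero
	case complex64:
		return T(real(v)) // imaginary part is discarded
	case complex128:
		return T(real(v))
	default:
		return 0
	}
}

func main() {
	fmt.Println(convertTo[float64](int32(7)))        // 7
	fmt.Println(convertTo[int](3.9))                 // 3
	fmt.Println(convertTo[float32](complex(2.5, 1))) // 2.5
	fmt.Println(convertTo[int](true))                // 0
}
```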

func UnsafeSliceForDType

func UnsafeSliceForDType(dtype dtypes.DType, unsafePtr unsafe.Pointer, len int) (any, error)

UnsafeSliceForDType creates a slice of the Go type corresponding to dtype, backed by the memory at unsafePtr, and returns it as any. It uses unsafe.Slice. Set `len` to the number of `DType` elements (not the number of bytes).

Types

type HasShape

type HasShape interface {
	Shape() Shape
}

HasShape is an interface for objects that have an associated Shape. `tensor.Tensor` (concrete tensor), `graph.Node` (tensor representation in a computation graph), `context.Variable`, and Shape itself all implement the interface.

type Shape

type Shape struct {
	DType       dtypes.DType
	Dimensions  []int
	TupleShapes []Shape // Shapes of the tuple, if this is a tuple.
}

Shape represents the shape of either a Tensor or the expected shape of the value from a computation node.

Use Make to create a new shape. See example in package shapes documentation.

func ConcatenateDimensions

func ConcatenateDimensions(s1, s2 Shape) (shape Shape)

ConcatenateDimensions concatenates the dimensions of two shapes. The resulting rank is the sum of both ranks. Both shapes must have the same dtype. If either of them is a scalar, the resulting shape is a copy of the other. It doesn't work for tuples.

func FromAnyValue

func FromAnyValue(v any) (shape Shape, err error)

FromAnyValue attempts to convert a Go "any" value to its expected shape. Accepted values are plain-old-data (POD) types (ints, floats, complex) and slices (possibly nested several levels deep) of POD types.

It returns the expected shape, or an error if the value can't be converted.

Example:

shape, err := shapes.FromAnyValue([][]float64{{0, 0}}) // shape is (Float64)[1 2]

func GobDeserialize

func GobDeserialize(decoder *gob.Decoder) (s Shape, err error)

GobDeserialize deserializes a Shape. It returns the new Shape or an error.

func Invalid

func Invalid() Shape

Invalid returns an invalid shape.

Invalid().Ok() == false.

func Make

func Make(dtype dtypes.DType, dimensions ...int) Shape

Make returns a Shape structure filled with the values given. See MakeTuple for tuple shapes.

func MakeTuple

func MakeTuple(elements []Shape) Shape

MakeTuple returns a shape representing a tuple of elements with the given shapes.

func Scalar

func Scalar[T dtypes.Number]() Shape

Scalar returns a scalar Shape for the given type.

func (Shape) Assert

func (s Shape) Assert(dtype dtypes.DType, dimensions ...int)

Assert checks that the shape has the given dtype, dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It panics if it doesn't match.

func (Shape) AssertDims

func (s Shape) AssertDims(dimensions ...int)

AssertDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It panics if it doesn't match.

See usage example in package shapes documentation.

func (Shape) AssertRank

func (s Shape) AssertRank(rank int)

AssertRank checks that the shape has the given rank.

It panics if it doesn't match.

See usage example in package shapes documentation.

func (Shape) AssertScalar

func (s Shape) AssertScalar()

AssertScalar checks that the shape is a scalar.

It panics if it doesn't match.

See usage example in package shapes documentation.

func (Shape) Check

func (s Shape) Check(dtype dtypes.DType, dimensions ...int) error

Check checks that the shape has the given dtype, dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It returns an error if the dtype or rank is different or if any of the dimensions don't match.

func (Shape) CheckDims

func (s Shape) CheckDims(dimensions ...int) error

CheckDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It returns an error if the rank is different or if any of the dimensions don't match.

func (Shape) CheckRank

func (s Shape) CheckRank(rank int) error

CheckRank checks that the shape has the given rank.

It returns an error if the rank is different.

func (Shape) CheckScalar

func (s Shape) CheckScalar() error

CheckScalar checks that the shape is a scalar.

It returns an error if the shape is not a scalar.

func (Shape) Clone

func (s Shape) Clone() (s2 Shape)

Clone returns a new deep copy of the shape.

func (Shape) Dim

func (s Shape) Dim(axis int) int

Dim returns the dimension of the given axis. The axis can be negative, in which case it counts from the end -- so axis=-1 refers to the last axis. Like slice indexing, it panics for an out-of-bounds axis.
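The negative-axis rule amounts to adding the rank to a negative axis. `dim` is a hypothetical sketch over a plain `[]int`; out-of-range axes panic through ordinary slice indexing, matching the behavior described above.

```go
package main

import "fmt"

// dim returns dimensions[axis], where a negative axis counts from the end
// (-1 is the last axis). Out-of-range axes panic like any slice index.
func dim(dimensions []int, axis int) int {
	if axis < 0 {
		axis += len(dimensions)
	}
	return dimensions[axis]
}

func main() {
	dims := []int{2, 3, 4}
	fmt.Println(dim(dims, 0), dim(dims, -1), dim(dims, -3)) // 2 4 2
}
```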

func (Shape) Equal

func (s Shape) Equal(s2 Shape) bool

Equal compares two shapes for equality: dtype and dimensions are compared.

func (Shape) EqualDimensions

func (s Shape) EqualDimensions(s2 Shape) bool

EqualDimensions compares two shapes for equality of dimensions. Dtypes can be different.

func (Shape) GobSerialize

func (s Shape) GobSerialize(encoder *gob.Encoder) (err error)

GobSerialize serializes the shape in binary (gob) format.

func (Shape) IsScalar

func (s Shape) IsScalar() bool

IsScalar returns whether the shape represents a scalar, that is, whether it has no dimensions (rank == 0).

func (Shape) IsTuple

func (s Shape) IsTuple() bool

IsTuple returns whether the shape represents a tuple.

func (Shape) IsZeroSize

func (s Shape) IsZeroSize() bool

IsZeroSize returns whether any of the dimensions is zero, in which case it's an empty shape, with no data attached to it.

Notice scalars are not zero in size -- they have size one, but rank zero.

func (Shape) Memory

func (s Shape) Memory() uintptr

Memory returns the memory used to store an array of the given shape, that is, its size in bytes. Careful: so far all types seem to use the same sizes in Go and on device, but this is not guaranteed for future types.

func (Shape) Ok

func (s Shape) Ok() bool

Ok returns whether this is a valid Shape. A "zero" shape, that is, one created with just Shape{}, is invalid.

func (Shape) Rank

func (s Shape) Rank() int

Rank returns the rank of the shape, that is, the number of dimensions (axes).

func (Shape) Shape

func (s Shape) Shape() Shape

Shape returns a shallow copy of itself. It implements the HasShape interface.

func (Shape) Size

func (s Shape) Size() (size int)

Size returns the number of elements (not bytes) for this shape. It's the product of all dimensions.

For the number of bytes used to store this shape, see Shape.Memory.
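Size and Memory are related through the element size of the DType: bytes = (product of dimensions) × (bytes per element). The sketch below is hypothetical: `bytesPerElement` covers only a few dtypes, keyed by name, and uses Go's in-memory sizes, which the Memory documentation notes are assumed to match the sizes on device.

```go
package main

import (
	"fmt"
	"unsafe"
)

// bytesPerElement is a hypothetical table of element sizes, keyed by dtype name.
var bytesPerElement = map[string]uintptr{
	"Int8":    unsafe.Sizeof(int8(0)),
	"Int32":   unsafe.Sizeof(int32(0)),
	"Float32": unsafe.Sizeof(float32(0)),
	"Float64": unsafe.Sizeof(float64(0)),
}

// size is the number of elements: the product of all dimensions
// (1 for a scalar, 0 if any dimension is zero).
func size(dims ...int) int {
	n := 1
	for _, d := range dims {
		n *= d
	}
	return n
}

// memory is the number of bytes needed to store the elements.
func memory(dtype string, dims ...int) uintptr {
	return uintptr(size(dims...)) * bytesPerElement[dtype]
}

func main() {
	fmt.Println(size(2, 3), memory("Float32", 2, 3)) // 6 24
	fmt.Println(size(), memory("Float64"))           // 1 8
	fmt.Println(size(0, 5), memory("Int32", 0, 5))   // 0 0
}
```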

func (Shape) String

func (s Shape) String() string

String implements fmt.Stringer and pretty-prints the shape.

func (Shape) ToStableHLO

func (s Shape) ToStableHLO() string

ToStableHLO returns the StableHLO representation of the shape's type.

func (Shape) TupleSize

func (s Shape) TupleSize() int

TupleSize returns the number of elements in the tuple, if it is a tuple.

func (Shape) WriteStableHLO

func (s Shape) WriteStableHLO(writer io.Writer) error

WriteStableHLO writes the StableHLO representation of the shape's type to the given writer.
