zea.agent.gumbel

Gumbel-Softmax trick, implemented with the multi-backend keras.ops API.
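As a rough illustration of the trick itself, here is a minimal sketch in NumPy (used instead of keras.ops so it runs without a Keras backend). The helper gumbel_softmax is hypothetical and not part of zea. It perturbs the logits with standard Gumbel noise and applies a softmax with temperature tau, so a lower tau pushes the relaxed sample toward a one-hot vector.

```python
import numpy as np


def gumbel_softmax(logits, tau=1.0, rng=None):
    """Relaxed one-hot sample from Categorical(softmax(logits)).

    Hypothetical illustration only; not the zea implementation.
    """
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=float)
    # Standard Gumbel(0, 1) noise via the inverse CDF: g = -log(-log(u)).
    u = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=logits.shape)
    g = -np.log(-np.log(u))
    # Perturb the logits, then apply a temperature-scaled softmax.
    y = (logits + g) / tau
    y = y - y.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
logits = np.log(np.array([0.1, 0.6, 0.3]))
sample = gumbel_softmax(logits, tau=0.5, rng=rng)  # non-negative, sums to 1
```

Because the noise enters additively and the softmax is smooth, gradients with respect to the logits are well defined, which is what makes the relaxation useful for training.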

Functions

hard_straight_through(khot_orig, k[, ...])

Applies the hard straight-through estimator to the given k-hot encoded tensor.

Classes

SubsetOperator(k[, tau, hard, n_value_dims])

SubsetOperator applies the Gumbel-Softmax trick for continuous top-k selection.

class zea.agent.gumbel.SubsetOperator(k, tau=1.0, hard=False, n_value_dims=1)

Bases: object

SubsetOperator applies the Gumbel-Softmax trick for continuous top-k selection.

Parameters:
  • k (int) – The number of elements to select.

  • tau (float, optional) – The temperature parameter for Gumbel-Softmax. Defaults to 1.0.

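  • hard (bool, optional) – Whether to use straight-through Gumbel-Softmax. Defaults to False.

  • n_value_dims (int, optional) – The number of value dimensions in the input tensor. Defaults to 1.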
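The description above does not spell out how the continuous top-k is formed. One common construction, used in the relaxed subset sampling of Xie and Ermon (2019), perturbs the logits with Gumbel noise and takes k successive softmaxes, each time down-weighting the elements already (softly) selected; their sum is a soft k-hot vector. The sketch below illustrates that scheme in NumPy with a hypothetical relaxed_topk helper, under the assumption that SubsetOperator follows the same idea; the zea implementation may differ in its details.

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def relaxed_topk(logits, k, tau=1.0, rng=None):
    """Soft k-hot vector from k successive Gumbel-perturbed softmaxes.

    Hypothetical sketch of continuous top-k selection; not zea's code.
    """
    rng = np.random.default_rng() if rng is None else rng
    eps = np.finfo(float).tiny
    logits = np.asarray(logits, dtype=float)
    # Perturb the logits with standard Gumbel(0, 1) noise.
    u = rng.uniform(low=eps, high=1.0, size=logits.shape)
    keys = logits - np.log(-np.log(u))
    khot = np.zeros_like(keys)
    onehot = np.zeros_like(keys)
    for _ in range(k):
        # Suppress elements picked in the previous round (log(1 - p) -> very negative).
        keys = keys + np.log(np.maximum(1.0 - onehot, eps))
        onehot = softmax(keys / tau)
        khot = khot + onehot
    return khot


rng = np.random.default_rng(0)
khot = relaxed_topk(np.zeros((2, 5)), k=2, tau=0.5, rng=rng)  # each row sums to k
```

Each round contributes a distribution that sums to 1, so every row of the result sums to k; as tau decreases the rows approach an exact k-hot mask.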

Sources:

gumbel_sample(shape)

Draws samples of the given shape from the standard Gumbel(0, 1) distribution.
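Standard Gumbel(0, 1) noise is commonly drawn by inverse-transform sampling: if u is uniform on (0, 1), then -log(-log(u)) follows Gumbel(0, 1). A minimal NumPy sketch follows; the standalone function standard_gumbel is illustrative only and is not the method documented here. NumPy's Generator.gumbel draws the same distribution directly.

```python
import numpy as np


def standard_gumbel(shape, rng=None):
    """Draw Gumbel(0, 1) samples by inverse-transform sampling (illustrative)."""
    rng = np.random.default_rng() if rng is None else rng
    # Keep u strictly inside (0, 1) so both logarithms stay finite.
    u = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=shape)
    return -np.log(-np.log(u))


samples = standard_gumbel((100_000,), rng=np.random.default_rng(0))
# The mean of Gumbel(0, 1) is the Euler-Mascheroni constant (about 0.5772),
# so the sample mean should land close to that value.
print(samples.mean())
```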

zea.agent.gumbel.hard_straight_through(khot_orig, k, n_value_dims=1)

Applies the hard straight-through estimator to the given k-hot encoded tensor.

Parameters:
  • khot_orig (Tensor) – The original k-hot encoded tensor.

  • k (int) – The number of top elements to select.

  • n_value_dims (int, optional) – The number of value dimensions in the input tensor. Defaults to 1. For example, use n_value_dims=2 for a 2D image mask.

Returns:

The tensor after applying the hard straight-through estimator, with the same shape as khot_orig.

Return type:

Tensor
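The forward pass of a hard straight-through estimator can be sketched as follows, assuming the top-k is taken jointly over the last n_value_dims axes (consistent with the 2D image mask example above, but not confirmed by this page). The hypothetical hard_topk_mask helper builds the exact k-hot mask in NumPy. The straight-through step, which returns this mask in the forward pass while letting gradients flow through the soft input, needs an autodiff framework and is shown only as a comment.

```python
import numpy as np


def hard_topk_mask(khot, k, n_value_dims=1):
    """Exact k-hot mask with the same shape as `khot` (hypothetical sketch).

    The k largest entries are selected jointly over the last `n_value_dims`
    axes; all leading axes are treated as batch axes.
    """
    khot = np.asarray(khot, dtype=float)
    batch_shape = khot.shape[: khot.ndim - n_value_dims]
    flat = khot.reshape(*batch_shape, -1)  # merge the value axes into one
    topk_idx = np.argsort(flat, axis=-1)[..., -k:]  # indices of the k largest
    hard = np.zeros(flat.shape, dtype=float)
    np.put_along_axis(hard, topk_idx, 1.0, axis=-1)
    # In an autodiff framework the straight-through output would then be
    # hard + khot - stop_gradient(khot): its value equals `hard`, but its
    # gradient is that of the soft `khot` (e.g. keras.ops.stop_gradient).
    return hard.reshape(khot.shape)


khot = np.array([[[0.9, 0.1], [0.4, 0.6]]])  # a batch of one soft 2x2 mask
mask_2d = hard_topk_mask(khot, k=2, n_value_dims=2)  # top-2 over the whole 2x2 mask
mask_rows = hard_topk_mask(khot, k=1, n_value_dims=1)  # top-1 within each row
```

Flattening the value axes first means that, with n_value_dims=2, the k largest entries are chosen across the whole mask rather than independently per row.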