zea.agent.gumbel
Gumbel-Softmax trick implemented with the multi-backend keras.ops.
Functions
hard_straight_through – Applies the hard straight-through estimator to the given k-hot encoded tensor.
Classes
SubsetOperator – Applies the Gumbel-Softmax trick for continuous top-k selection.
- class zea.agent.gumbel.SubsetOperator(k, tau=1.0, hard=False, n_value_dims=1)
Bases: object
SubsetOperator applies the Gumbel-Softmax trick for continuous top-k selection.
- Parameters:
k (int) – The number of elements to select.
tau (float, optional) – The temperature parameter for Gumbel-Softmax. Defaults to 1.0.
hard (bool, optional) – Whether to use straight-through Gumbel-Softmax. Defaults to False.
n_value_dims (int, optional) – The number of value dimensions in the input tensor. Defaults to 1.
- Sources:
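The continuous top-k selection that SubsetOperator performs can be illustrated with a minimal NumPy sketch of one common relaxation (the iterative-softmax scheme of Xie and Ermon, 2019): perturb the scores with Gumbel(0, 1) noise, then take k successive tempered softmaxes, softly masking out the mass already selected after each step. This is an assumption about the general technique, not zea's implementation; the name `relaxed_topk_sketch`, the `rng` and `eps` arguments, and selecting over the last axis only (the n_value_dims=1 case) are illustrative choices.

```python
import numpy as np


def _softmax(x, axis=-1):
    # numerically stable softmax along `axis`
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


def relaxed_topk_sketch(scores, k, tau=1.0, rng=None, eps=1e-20):
    """Hypothetical sketch: relaxed (soft) k-hot selection over the last axis."""
    rng = np.random.default_rng() if rng is None else rng
    scores = np.asarray(scores, dtype=float)
    # Gumbel(0, 1) noise: g = -log(-log(u)) with u ~ Uniform(eps, 1)
    u = rng.uniform(low=eps, high=1.0, size=scores.shape)
    logits = scores - np.log(-np.log(u))
    khot = np.zeros_like(logits)
    onehot = np.zeros_like(logits)
    for _ in range(k):
        # down-weight entries already (softly) selected in earlier steps
        logits = logits + np.log(np.maximum(1.0 - onehot, eps))
        onehot = _softmax(logits / tau)
        khot = khot + onehot
    # each step adds a distribution that sums to 1, so khot sums to k
    return khot
```

Lowering tau pushes the output toward a hard k-hot vector; with hard=True, the documented class additionally uses a straight-through estimator, so the forward value is an exact k-hot mask while gradients still flow through the soft relaxation.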
- zea.agent.gumbel.hard_straight_through(khot_orig, k, n_value_dims=1)
Applies the hard straight-through estimator to the given k-hot encoded tensor.
- Parameters:
khot_orig (Tensor) – The original k-hot encoded tensor.
k (int) – The number of top elements to select.
n_value_dims (int, optional) – The number of value dimensions in the input tensor. Defaults to 1. For example, use n_value_dims=2 for a 2D image mask.
- Returns:
The tensor after applying the hard straight-through estimator, with the same shape as khot_orig.
- Return type:
Tensor
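A hard straight-through estimator can be sketched in NumPy, assuming the usual formulation: the forward value is a hard k-hot mask of the k largest entries (taken jointly over the last n_value_dims axes), and gradients are routed through the soft input via `stop_gradient(hard_mask - soft) + soft`. NumPy has no autodiff, so the `stop_gradient` argument below defaults to the identity and the sketch only reproduces forward values; in an autodiff backend one would pass something like `keras.ops.stop_gradient` and build the mask with backend ops instead of NumPy. The names `topk_mask_sketch` and `straight_through_sketch` are hypothetical; they mirror the documented signature but are not zea's code.

```python
import numpy as np


def topk_mask_sketch(x, k, n_value_dims=1):
    """Hypothetical helper: {0, 1} mask of the k largest entries of `x`,
    chosen jointly over its last `n_value_dims` axes."""
    x = np.asarray(x)
    split = x.ndim - n_value_dims
    # flatten the value axes so top-k runs over a single axis
    flat = x.reshape(x.shape[:split] + (-1,))
    # stable sort so ties resolve deterministically (lowest index first)
    topk_idx = np.argsort(-flat, axis=-1, kind="stable")[..., :k]
    mask = np.zeros_like(flat, dtype=float)
    np.put_along_axis(mask, topk_idx, 1.0, axis=-1)
    return mask.reshape(x.shape)


def straight_through_sketch(khot_orig, k, n_value_dims=1, stop_gradient=lambda t: t):
    """Hypothetical sketch of the straight-through estimator.

    The forward value equals the hard top-k mask. With a real stop_gradient
    (autodiff backend), the gradient w.r.t. khot_orig passes through unchanged.
    """
    soft = np.asarray(khot_orig, dtype=float)
    hard = topk_mask_sketch(soft, k, n_value_dims)
    return stop_gradient(hard - soft) + soft
```

For a batch of 2D masks with shape (batch, height, width), n_value_dims=2 selects the k largest pixels per image rather than per row.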