zea.agent.masks

Mask generation utilities.

These masks are used as a measurement operator for focused scan-line subsampling.

Functions

indices_to_k_hot(indices, n_possible_actions)

Convert a list of indices to a k-hot encoded vector.

initial_equispaced_lines(n_actions, ...[, ...])

Generate an initial equispaced k-hot line mask.

k_hot_to_indices(selected_lines, n_actions)

Convert k-hot encoded lines to indices of selected actions.

lines_to_im_size(lines, img_size)

Convert k-hot-encoded line vectors to image size.

make_line_mask(line_indices, image_shape[, ...])

Creates a mask with vertical lines at the specified indices along the second (width) axis.

next_equispaced_lines(previous_lines[, shift])

Roll the previous equispaced mask of shape (..., n_possible_actions) to the right by shift positions (1 by default).

random_uniform_lines(n_actions, ...[, seed, ...])

Generate masks with randomly selected lines.

zea.agent.masks.indices_to_k_hot(indices, n_possible_actions, dtype='bool')

Convert a list of indices to a k-hot encoded vector.

A k-hot encoded vector keeps a fixed shape of (n_possible_actions,) no matter how many actions are selected, which makes it suitable for tracing when the number of actions can change. This is the default representation for actions in zea.

Parameters:
  • indices (List[int]) – List of indices to set to 1.

  • n_possible_actions (int) – Total number of possible actions.

  • dtype (str, optional) – Data type of the mask. Defaults to _DEFAULT_DTYPE (“bool”).

Returns:

k-hot-encoded vector of shape (n_possible_actions,).

Return type:

Tensor
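The mapping can be illustrated with a minimal NumPy sketch. The helper indices_to_k_hot_sketch is hypothetical and not part of zea, which returns backend tensors rather than NumPy arrays. Note that the output shape depends only on n_possible_actions, not on how many indices are passed.

```python
import numpy as np


def indices_to_k_hot_sketch(indices, n_possible_actions, dtype=bool):
    """Hypothetical NumPy stand-in for indices_to_k_hot (not zea's code)."""
    # Fixed-size output: the shape never depends on len(indices).
    k_hot = np.zeros(n_possible_actions, dtype=dtype)
    # Set every listed position to 1 (True for a boolean mask).
    k_hot[np.asarray(indices, dtype=int)] = 1
    return k_hot


mask = indices_to_k_hot_sketch([0, 3], n_possible_actions=6)
print(mask.astype(int))  # [1 0 0 1 0 0]
```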

zea.agent.masks.initial_equispaced_lines(n_actions, n_possible_actions, dtype='bool', assert_equal_spacing=True)

Generate an initial equispaced k-hot line mask.

For example, if n_actions=2 and n_possible_actions=6, then initial_mask=[1, 0, 0, 1, 0, 0].

Parameters:
  • n_actions (int) – Number of actions to be selected.

  • n_possible_actions (int) – Number of possible actions.

  • dtype (str, optional) – Data type of the mask. Defaults to _DEFAULT_DTYPE (“bool”).

  • assert_equal_spacing (bool, optional) – If True, asserts that n_possible_actions is divisible by n_actions, so that every line has exactly the same spacing. If False, the spacing between lines may vary slightly. Defaults to True.

Returns:

k-hot-encoded line vector of shape (n_possible_actions,). It needs to be converted to image size.

Return type:

Tensor
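A minimal NumPy sketch of one way to build such a mask, reproducing the example above. The helper name is hypothetical, and the placement rule used when the spacing is uneven (index i placed at floor(i * n_possible_actions / n_actions)) is an assumption; zea may distribute the leftover positions differently.

```python
import numpy as np


def initial_equispaced_lines_sketch(
    n_actions, n_possible_actions, dtype=bool, assert_equal_spacing=True
):
    """Hypothetical NumPy stand-in for initial_equispaced_lines (not zea's code)."""
    if assert_equal_spacing:
        assert n_possible_actions % n_actions == 0, (
            "n_possible_actions must be divisible by n_actions"
        )
    # Spread n_actions lines over the available positions. With exact
    # divisibility this gives a constant spacing of n_possible_actions // n_actions;
    # otherwise neighbouring gaps differ by at most one position (assumed rule).
    indices = (np.arange(n_actions) * n_possible_actions) // n_actions
    mask = np.zeros(n_possible_actions, dtype=dtype)
    mask[indices] = 1
    return mask


print(initial_equispaced_lines_sketch(2, 6).astype(int))  # [1 0 0 1 0 0]
print(initial_equispaced_lines_sketch(3, 8, assert_equal_spacing=False).astype(int))
# [1 0 1 0 0 1 0 0]
```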

zea.agent.masks.k_hot_to_indices(selected_lines, n_actions, fill_value=-1)

Convert k-hot encoded lines to indices of selected actions.

Parameters:
  • selected_lines (Tensor) – k-hot encoded lines of shape (batch_size, n_possible_actions).

  • n_actions (int) – Number of lines selected.

  • fill_value (int, optional) – Value to fill in case there are not enough selected actions. Defaults to -1.

Returns:

Indices of selected actions of shape (batch_size, n_actions).

If there are fewer than n_actions selected, the remaining indices will be filled with fill_value.

Return type:

Tensor
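The conversion, including the padding with fill_value, can be sketched in NumPy as below. The helper name is hypothetical, and returning the indices in ascending order is an assumption about zea's ordering.

```python
import numpy as np


def k_hot_to_indices_sketch(selected_lines, n_actions, fill_value=-1):
    """Hypothetical NumPy stand-in for k_hot_to_indices (not zea's code)."""
    selected_lines = np.asarray(selected_lines, dtype=bool)
    n_possible_actions = selected_lines.shape[-1]
    positions = np.arange(n_possible_actions)
    # Selected positions keep their own index as sort key; unselected ones get
    # n_possible_actions so that they sort after every selected position.
    keys = np.where(selected_lines, positions, n_possible_actions)
    first_keys = np.sort(keys, axis=-1)[..., :n_actions]
    # A key equal to n_possible_actions marks a missing selection: pad it.
    return np.where(first_keys < n_possible_actions, first_keys, fill_value)


lines = [[1, 0, 0, 1, 0, 0],
         [0, 1, 0, 0, 0, 0]]
print(k_hot_to_indices_sketch(lines, n_actions=2).tolist())  # [[0, 3], [1, -1]]
```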

zea.agent.masks.lines_to_im_size(lines, img_size)

Convert k-hot-encoded line vectors to image size.

Parameters:
  • lines (Tensor) – k-hot-encoded line vectors of shape (n_masks, n_possible_actions).

  • img_size (tuple) – Target image size as (height, width).

Returns:

Masks of shape (n_masks, height, width), where (height, width) = img_size.

Return type:

Tensor
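A minimal NumPy sketch of the conversion, under the assumption (not stated above) that the image width is a multiple of n_possible_actions and that each action covers width // n_possible_actions adjacent columns spanning the full image height. The helper name is hypothetical.

```python
import numpy as np


def lines_to_im_size_sketch(lines, img_size):
    """Hypothetical NumPy stand-in for lines_to_im_size (not zea's code)."""
    height, width = img_size
    lines = np.asarray(lines)
    n_masks, n_possible_actions = lines.shape
    assert width % n_possible_actions == 0, (
        "assumed: width is a multiple of n_possible_actions"
    )
    # Widen each action to width // n_possible_actions adjacent columns.
    columns = np.repeat(lines, width // n_possible_actions, axis=-1)  # (n_masks, width)
    # Copy the column pattern onto every image row.
    return np.repeat(columns[:, None, :], height, axis=1)  # (n_masks, height, width)


masks = lines_to_im_size_sketch(np.array([[1, 0, 1]]), img_size=(2, 6))
print(masks.shape)        # (1, 2, 6)
print(masks[0].tolist())  # [[1, 1, 0, 0, 1, 1], [1, 1, 0, 0, 1, 1]]
```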

zea.agent.masks.make_line_mask(line_indices, image_shape, line_width=1, dtype='bool')

Creates a mask with vertical lines at the specified indices along the second (width) axis.

Parameters:
  • line_indices (List[int]) – A list of indices where the lines should be drawn.

  • image_shape (List[int]) – The shape of the image as [height, width, channels].

  • line_width (int, optional) – The width of each line. Defaults to 1.

  • dtype (str, optional) – The data type of the mask. Defaults to “bool”.

Returns:

A tensor of the same shape as image_shape with lines drawn at the specified indices.

Return type:

mask (Tensor)
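A hypothetical NumPy sketch of such a mask. How zea positions a line wider than one column is not documented here; this sketch assumes each line starts at its index and extends line_width columns to the right, clipped at the image border.

```python
import numpy as np


def make_line_mask_sketch(line_indices, image_shape, line_width=1, dtype=bool):
    """Hypothetical NumPy stand-in for make_line_mask (not zea's code)."""
    height, width, channels = image_shape
    columns = np.zeros(width, dtype=bool)
    for index in line_indices:
        # Assumption: a line covers line_width columns starting at index;
        # slicing past the right edge is silently clipped by NumPy.
        columns[index:index + line_width] = True
    # Broadcast the column pattern over all rows and channels, then copy
    # into the requested dtype (astype returns a new, writable array).
    mask = np.broadcast_to(columns[None, :, None], (height, width, channels))
    return mask.astype(dtype)


mask = make_line_mask_sketch([1, 4], image_shape=(2, 6, 1), line_width=2)
print(mask.shape)                           # (2, 6, 1)
print(mask[0, :, 0].astype(int).tolist())  # [0, 1, 1, 0, 1, 1]
```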

zea.agent.masks.next_equispaced_lines(previous_lines, shift=1)

Roll the previous equispaced mask of shape (…, n_possible_actions) to the right along its last axis by shift positions (1 by default).

zea.agent.masks.random_uniform_lines(n_actions, n_possible_actions, n_masks, seed=None, dtype='bool')

Generate masks with randomly selected lines.

Each mask is guaranteed to contain exactly n_actions selected lines.

Parameters:
  • n_actions (int) – Number of actions to be selected.

  • n_possible_actions (int) – Number of possible actions.

  • n_masks (int) – Number of masks to generate.

  • seed (int | SeedGenerator | jax.random.key, optional) – Seed for random number generation. Defaults to None.

  • dtype (str, optional) – Data type of the mask. Defaults to “bool”.

Returns:

k-hot-encoded line vectors of shape (n_masks, n_possible_actions).

Needs to be converted to image size.

Return type:

Tensor
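One way to guarantee exactly n_actions distinct lines per mask is to rank i.i.d. uniform scores and keep the first n_actions positions of that ranking, which selects a uniformly random subset, as in this NumPy sketch. The helper name is hypothetical, and seeding through np.random.default_rng replaces the SeedGenerator / jax.random.key options accepted by zea.

```python
import numpy as np


def random_uniform_lines_sketch(
    n_actions, n_possible_actions, n_masks, seed=None, dtype=bool
):
    """Hypothetical NumPy stand-in for random_uniform_lines (not zea's code)."""
    rng = np.random.default_rng(seed)
    # Sorting i.i.d. uniform scores gives a random permutation per mask;
    # its first n_actions entries are n_actions distinct line indices.
    scores = rng.random((n_masks, n_possible_actions))
    chosen = np.argsort(scores, axis=-1)[:, :n_actions]
    masks = np.zeros((n_masks, n_possible_actions), dtype=dtype)
    np.put_along_axis(masks, chosen, 1, axis=-1)
    return masks


masks = random_uniform_lines_sketch(3, 10, n_masks=4, seed=0)
print(masks.shape)         # (4, 10)
print(masks.sum(axis=-1))  # [3 3 3 3]
```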