If you’re working with multidimensional tensors (e.g. in NumPy or PyTorch), a helpful pattern is to use tuple unpacking to get the sizes of the various dimensions, like this: batch, chan, w, h = x.shape. Sometimes you already know some of these dimensions and want to assert that they have the expected values. Here is a convenient way to do that. Define the following class and a single instance of it:
class _MustBe:
    """Class for asserting that a dimension must have a certain value.
    The class itself is private; one should import a particular object,
    "must_be", in order to use the functionality. Example code:
    `batch, chan, must_be[32], must_be[32] = image.shape` """
    def __setitem__(self, key, value):
        # `must_be[key] = value` lands here: key is the expected size,
        # value is the actual dimension taken from the shape.
        assert key == value, "must_be[%d] does not match dimension %d" % (key, value)
must_be = _MustBe()
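Why this works: in an unpacking assignment, a target like obj[key] does not create a new variable. Python assigns to it by calling obj.__setitem__(key, item), one target at a time from left to right, while plain names are bound as usual. The sketch below uses a throwaway recording class (_Recorder is hypothetical, purely for illustration) and a tuple standing in for x.shape to make those calls visible:

```python
class _Recorder:
    """Hypothetical helper: logs every item assignment it receives."""

    def __init__(self):
        self.calls = []

    def __setitem__(self, key, value):
        # Called once for each subscript target in the unpacking.
        self.calls.append((key, value))


rec = _Recorder()
shape = (8, 3, 32, 32)  # stands in for x.shape

# batch and chan are bound normally; rec["w"] and rec["h"] go through __setitem__.
batch, chan, rec["w"], rec["h"] = shape

print(batch, chan)  # 8 3
print(rec.calls)    # [('w', 32), ('h', 32)]
```

_MustBe uses the same hook, but instead of storing the pair it compares the key (the expected size) with the value (the actual size).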
This hack overrides index assignment (__setitem__) and replaces it with an assertion, so nothing is actually stored. To use, import must_be from the file where you defined it. Now you can do stuff like this:
batch, must_be[3] = v.shape
must_be[batch], l, n = A.shape
must_be[batch], must_be[n], m = B.shape
...
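Here is a self-contained version of the examples above that runs without NumPy or PyTorch. It is a sketch that assumes plain tuples in place of tensor shapes (NumPy's ndarray.shape is a tuple of ints, and PyTorch's torch.Size subclasses tuple, so unpacking behaves the same way); v_shape, A_shape and B_shape are made-up values for illustration:

```python
class _MustBe:
    """Subscript assignment becomes an equality check on a dimension."""

    def __setitem__(self, key, value):
        # key: expected size, value: actual size from the shape tuple.
        assert key == value, "must_be[%d] does not match dimension %d" % (key, value)


must_be = _MustBe()

# Hypothetical shapes standing in for v.shape, A.shape and B.shape.
v_shape = (16, 3)
A_shape = (16, 5, 7)
B_shape = (16, 7, 2)

batch, must_be[3] = v_shape
must_be[batch], l, n = A_shape
must_be[batch], must_be[n], m = B_shape
print(batch, l, n, m)  # 16 5 7 2

# A mismatched dimension fails on the unpacking line itself.
message = None
try:
    must_be[batch], must_be[n], m = (16, 6, 2)
except AssertionError as err:
    message = str(err)
print(message)  # must_be[7] does not match dimension 6
```

One caveat: assert statements are skipped when Python runs with the -O flag, so these checks silently disappear in that mode. If they must always run, __setitem__ could raise a ValueError instead of using assert.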
Linkpost for: https://pbement.com/posts/must_be.html