Source code for greatx.utils.bunchdict

from collections import OrderedDict

from tabulate import tabulate


class BunchDict(OrderedDict):
    """Container object for datasets.

    Dictionary-like object that exposes its keys as attributes and
    remembers insertion order.

    Examples
    --------
    >>> b = BunchDict(a=1, b=2)
    >>> b
    ╒═════════╤═══════════╕
    │ Names   │   Objects │
    ╞═════════╪═══════════╡
    │ a       │         1 │
    ├─────────┼───────────┤
    │ b       │         2 │
    ╘═════════╧═══════════╛
    >>> b['b']
    2
    >>> b.b
    2
    >>> b.a = 3
    >>> b['a']
    3
    >>> b.c = 6
    >>> b['c']
    6

    >>> # Converting objects in BunchDict to `torch.Tensor` if possible.
    >>> b = BunchDict(a=[1,2,3])
    >>> b.to_tensor()
    ╒═════════╤═══════════════════════════════╕
    │ Names   │ Objects                       │
    ╞═════════╪═══════════════════════════════╡
    │ a       │ Tensor, shape=torch.Size([3]) │
    │         │ tensor([1, 2, 3])             │
    ╘═════════╧═══════════════════════════════╛
    >>> b.a
    tensor([1, 2, 3])
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __setattr__(self, key, value):
        # Attribute assignment is stored as a dictionary item, so
        # ``b.x = 1`` and ``b['x'] = 1`` are interchangeable.
        self[key] = value

    def __dir__(self):
        # Advertise the stored keys as the object's attributes.
        return self.keys()

    def __getattr__(self, key):
        # Only reached when regular attribute lookup fails; fall back to
        # the stored items and report a miss as AttributeError so that
        # ``hasattr`` / ``getattr(..., default)`` behave as expected.
        try:
            return self[key]
        except KeyError:
            # The KeyError is an implementation detail; do not chain it.
            raise AttributeError(key) from None

    def to_tensor(self, device: str = 'cpu', dtype=None) -> "BunchDict":
        """Convert objects in BunchDict to :class:`torch.Tensor`

        Values that cannot be converted (e.g. ``None``, strings or
        lists of strings) are left unchanged. The conversion happens
        in place.

        Parameters
        ----------
        device : str, optional
            device of the converted tensors, by default 'cpu'
        dtype : torch.dtype, optional
            data types of the converted tensors, by default None

        Returns
        -------
        the converted BunchDict (this same instance)

        Raises
        ------
        TypeError
            if ``dtype`` is neither ``None`` nor a :class:`torch.dtype`
        """
        import torch

        # Validate ``dtype`` up front: per-item TypeErrors are swallowed
        # below, so a bad ``dtype`` would otherwise silently convert nothing.
        if dtype is not None and not isinstance(dtype, torch.dtype):
            raise TypeError(
                f"dtype must be a torch.dtype or None, "
                f"got {type(dtype).__name__}")

        device = torch.device(device)
        # Iterate over a snapshot since values are reassigned in the loop.
        for k, v in list(self.items()):
            try:
                self[k] = torch.as_tensor(v, dtype=dtype, device=device)
            except (RuntimeError, TypeError, ValueError):
                # torch reports unconvertible data with any of these
                # exception types; keep such values as they are.
                continue
        return self

    def __repr__(self) -> str:
        # Render the items as a two-column table, one row per key.
        table_headers = ["Names", "Objects"]
        items = tuple(map(prettify, self.items()))
        table = tabulate(items, headers=table_headers, tablefmt="fancy_grid")
        return table

    __str__ = __repr__
def prettify(item):
    """Turn a ``(key, value)`` pair into a ``(key, summary)`` table row.

    The summary is chosen as follows:

    * ``None`` becomes the string ``'None'``;
    * objects with a ``shape`` attribute are described by their class
      name and either their scalar value (empty shape and an ``item``
      method) or their shape followed by their string form;
    * strings are kept as text;
    * other objects supporting ``len`` are described by type name and
      length;
    * anything else is returned unchanged.
    """
    name, obj = item

    if obj is None:
        return name, 'None'

    if hasattr(obj, "shape"):
        is_scalar = len(obj.shape) == 0 and hasattr(obj, "item")
        if is_scalar:
            return name, f"{obj.__class__.__name__}, {obj.item()}"
        return name, f"{obj.__class__.__name__}, shape={obj.shape}\n{obj}"

    if isinstance(obj, str):
        return name, f"{obj}"

    try:
        size = len(obj)
    except TypeError:
        # Unsized objects (numbers, etc.) are shown as they are.
        return name, obj
    return name, f"{type(obj).__name__}, len={size}"