get_dtype
- torchtune.utils.precision.get_dtype(dtype: Optional[str] = None, device: Optional[torch.device] = None) -> torch.dtype
Get the torch.dtype corresponding to the given precision string.
- NOTE: If bf16 precision is requested with a CUDA device (including when device is None, since a CUDA device is then assumed), we verify whether the device actually supports
bf16 kernels. If it does not, the dtype returned is torch.float32.
- Parameters:
dtype (Optional[str]) – The precision dtype. Default: None, in which case torch.float32 is used.
device (Optional[torch.device]) – Device in use for training. Only CUDA and CPU devices are supported. If a CUDA device is passed in, additional checking is done to ensure that the device supports the requested precision. Default: None, in which case a CUDA device is assumed.
- Raises:
ValueError – If the requested precision isn’t supported by the precision utils.
- Returns:
The corresponding torch.dtype.
- Return type:
torch.dtype
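The dtype resolution described above can be sketched in plain PyTorch as follows. This is a minimal, hypothetical re-implementation (`get_dtype_sketch`), not torchtune's source: the set of accepted precision strings (`"fp16"`, `"bf16"`, `"fp32"`, `"fp64"`) is an assumption for illustration, and the bf16 capability check is approximated with `torch.cuda.is_available()` and `torch.cuda.is_bf16_supported()`.

```python
from typing import Optional

import torch

# Hypothetical mapping from precision strings to dtypes. The accepted
# strings are an assumption for illustration, not torchtune's documented list.
_PRECISION_STR_TO_DTYPE = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
    "fp64": torch.float64,
}


def _cuda_supports_bf16() -> bool:
    # Guard with is_available() so the check is safe on CPU-only machines.
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


def get_dtype_sketch(
    dtype: Optional[str] = None, device: Optional[torch.device] = None
) -> torch.dtype:
    """Resolve a precision string to a torch.dtype (illustrative sketch)."""
    if dtype is None:
        # No precision requested: default to full precision.
        return torch.float32
    if dtype not in _PRECISION_STR_TO_DTYPE:
        raise ValueError(
            f"Unsupported precision {dtype!r}; expected one of "
            f"{sorted(_PRECISION_STR_TO_DTYPE)}"
        )
    torch_dtype = _PRECISION_STR_TO_DTYPE[dtype]
    # device=None is treated as CUDA, matching the documented default.
    is_cuda = device is None or device.type == "cuda"
    if torch_dtype == torch.bfloat16 and is_cuda and not _cuda_supports_bf16():
        # bf16 requested on a CUDA device without bf16 kernels: fall back to fp32.
        return torch.float32
    return torch_dtype


# Usage:
print(get_dtype_sketch())                             # torch.float32
print(get_dtype_sketch("bf16", torch.device("cpu")))  # torch.bfloat16
```

Because `device=None` is treated as CUDA in this sketch, calling `get_dtype_sketch("bf16")` on a CPU-only machine returns `torch.float32`; passing `torch.device("cpu")` explicitly skips the CUDA check and returns `torch.bfloat16`.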