LinearCrossEntropyOptions
- class torch.nn.LinearCrossEntropyOptions(allow_retain_graph=False, batch_chunk_size=None, chunking_method='auto', acc_policy='auto', acc_dtype=None)
Configuration for the chunked implementation of linear_cross_entropy().
The chunked implementation processes the batch dimension in pieces, so the full (num_batches, num_classes) logits tensor is never materialized – useful when num_classes is much larger than in_features (e.g. LLM vocabulary heads). Pass options=None to use the reference path; pass an instance of this class to opt in.
Zero-argument LinearCrossEntropyOptions() leaves acc_policy and chunking_method set to "auto", resolved at call time from _AUTO_DEFAULTS (per-(device, dtype) picks measured on A100 / x86 CPU); unlisted pairs fall back to ("compact", "aspect_ratio:2").
Supports a subset of linear_cross_entropy(); unsupported configurations fall through to the reference path with a warning.
Chunking is a win when num_batches >= in_features and num_classes > in_features; below that, the reference path is cheaper.
- acc_dtype: dtype | None
Dtype for internal accumulation.
None resolves at call time to torch.float32 under acc_policy="auto" with fp16/bf16 input on hardware with mixed-precision mm (CUDA SM 7.0+ for fp16, SM 8.0+ for bf16, and CPU); otherwise to the input dtype. Mixed precision currently requires fp16/bf16 input with acc_dtype=torch.float32.
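The resolution rule for acc_dtype=None can be restated as a small decision function. This is only a sketch of the rule as worded above, not the library's code: resolve_acc_dtype, its string-valued dtypes, and the sm tuple are hypothetical names chosen for illustration, and treating an explicit acc_dtype as "used as given" is an assumption.

```python
def resolve_acc_dtype(acc_dtype, acc_policy, input_dtype, device, sm=None):
    """Hypothetical restatement of the documented rule; not the library's code.

    Dtypes are plain strings ("float16", "bfloat16", "float32").
    `sm` is the CUDA compute capability as a (major, minor) tuple.
    """
    if acc_dtype is not None:
        return acc_dtype  # assumption: an explicit setting is used as given
    if acc_policy != "auto" or input_dtype not in ("float16", "bfloat16"):
        return input_dtype
    if device == "cpu":
        return "float32"
    if device == "cuda" and sm is not None:
        # SM 7.0+ for fp16, SM 8.0+ for bf16
        min_sm = (7, 0) if input_dtype == "float16" else (8, 0)
        if sm >= min_sm:
            return "float32"
    return input_dtype
```

For example, under this sketch a bf16 input on an SM 7.5 GPU keeps bf16 accumulation, while the same input on SM 8.0 resolves to float32.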
- acc_policy: Literal['accurate', 'balanced', 'compact', 'auto']
Precision/memory trade-off for the chunked path. Controls which intermediates are kept in acc_dtype vs. the input dtype, and whether the per-chunk weight-gradient scratch buffer is materialized.
  - "auto" (default) – per-(device, dtype) pick from _AUTO_DEFAULTS; unlisted pairs fall back to "compact". The fallback assumes a CUDA-like backend with hardware-native low-precision matmul; pass "accurate" explicitly on backends that emulate fp16/bf16 GEMMs via fp32 upcast.
  - "accurate" – broadest use of acc_dtype; noticeably better input-grad accuracy when chunk size is large relative to num_classes. Highest peak memory and slowest of the chunked policies on CUDA. Only chunked policy whose weight-grad matmul runs in fp32 on CPU (other policies hit CPU’s emulated low-precision path, ~20-50x slower).
  - "balanced" – acc_dtype only where needed for gradient correctness; keeps a (num_classes, in_features) acc_dtype scratch for cross-chunk weight-grad accumulation. Same precision as "accurate" in bf16, slightly looser in fp16, faster than "accurate" in both.
  - "compact" – like "balanced" but drops the weight-grad scratch and accumulates per-chunk directly via addmm_ (cuBLAS uses an fp32 internal accumulator, so bulk precision matches "balanced"). Saves num_classes * in_features * sizeof(acc_dtype) – typically several hundred MB for an LLM head. On non-CUDA backends in mixed precision, falls back to "balanced".
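To make the "compact" saving concrete, the formula num_classes * in_features * sizeof(acc_dtype) can be evaluated directly. The sketch below is plain arithmetic; weight_grad_scratch_bytes is a hypothetical helper, and the head shape (32,000 classes, 4,096 features, fp32 accumulation) is an illustrative assumption rather than a measured configuration.

```python
def weight_grad_scratch_bytes(num_classes, in_features, acc_itemsize=4):
    """Size of a (num_classes, in_features) scratch buffer in acc_dtype.

    acc_itemsize=4 corresponds to float32 accumulation.
    """
    return num_classes * in_features * acc_itemsize


# Illustrative LLM-style head: 32,000 classes, 4,096 features, fp32 scratch.
scratch = weight_grad_scratch_bytes(32_000, 4_096)
print(f"{scratch / 2**20:.0f} MiB")  # prints "500 MiB"
```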
Policy effects ("balanced" vs "accurate") are visible only when acc_dtype differs from the input dtype; "compact" saves memory in both regimes.
- allow_retain_graph: bool
Allow retain_graph=True on backward.
When False (default), backward consumes pre-computed gradient buffers in place; a second .backward() raises RuntimeError.
When True, the buffers are preserved at the cost of one extra gradient-sized allocation per call.
Higher-order autograd (gradgrad, forward-mode AD) is unsupported.
Under torch.compile() this field is auto-promoted to True because the default-mode second-backward guard relies on a ctx mutation Dynamo doesn’t preserve; the wrapper warns on the promotion.
- batch_chunk_size: int | None
Batch rows per chunk. The op loops over ceil(num_batches / batch_chunk_size) chunks; smaller values cut peak memory but launch more kernels. Default None means a single chunk. Cannot be combined with chunking_method – if both are set and disagree, ValueError is raised.
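The trade-off between chunk count and peak logits memory follows directly from ceil(num_batches / batch_chunk_size). The following is a back-of-the-envelope sketch, not the library's accounting: chunk_plan is a hypothetical helper, and it counts only the per-chunk (batch_chunk_size, num_classes) logits buffer at a 2-byte (fp16/bf16) element size.

```python
import math


def chunk_plan(num_batches, num_classes, batch_chunk_size, itemsize=2):
    """Return (number of chunks, bytes of one chunk's logits buffer).

    Hypothetical helper: ignores every allocation except the logits buffer.
    """
    num_chunks = math.ceil(num_batches / batch_chunk_size)
    rows = min(batch_chunk_size, num_batches)
    return num_chunks, rows * num_classes * itemsize


# 8,192 batch rows, 32,000 classes, 2-byte elements.
full = chunk_plan(8_192, 32_000, 8_192)     # single chunk
chunked = chunk_plan(8_192, 32_000, 1_024)  # 1,024 rows per chunk
print(full)     # (1, 524288000)
print(chunked)  # (8, 65536000)
```

In this example, an eightfold smaller chunk cuts the logits buffer eightfold while running eight loop iterations instead of one.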
- chunking_method: str | None
Heuristic for picking batch_chunk_size.
  - "auto" (default) – resolves to a per-(device, dtype) pick from _AUTO_DEFAULTS at call time; falls back to "aspect_ratio:2" for unlisted pairs.
  - "aspect_ratio" – sizes each chunk so its (batch_chunk_size, num_classes) logits buffer matches the (num_batches, in_features) input in memory: next_pow2(ceil(num_batches / ceil(num_classes / in_features))). Best when num_classes >> in_features (LLM vocab heads).
  - "aspect_ratio:N" (N >= 1) – same, divided by N. ~N times less peak memory at the cost of N times more chunks.
  - None – disables the heuristic; uses batch_chunk_size.
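The "aspect_ratio" formula above can be worked through in a few lines of plain Python. This is a sketch under stated assumptions, not the library's implementation: next_pow2 is taken to mean the smallest power of two greater than or equal to its argument, and the "divided by N" step is assumed to be an integer division of the result, floored at 1. Both function names are hypothetical.

```python
import math


def next_pow2(n):
    """Smallest power of two >= n, for n >= 1 (assumed meaning of next_pow2)."""
    return 1 << (n - 1).bit_length()


def aspect_ratio_chunk_size(num_batches, num_classes, in_features, n=1):
    """Chunk size per the documented formula; division by n is an assumption."""
    ratio = math.ceil(num_classes / in_features)
    base = next_pow2(math.ceil(num_batches / ratio))
    return max(1, base // n)


# 8,192 batch rows, 32,000 classes, 4,096 features.
print(aspect_ratio_chunk_size(8_192, 32_000, 4_096))       # 1024
print(aspect_ratio_chunk_size(8_192, 32_000, 4_096, n=2))  # 512
```

With these numbers the class-to-feature ratio rounds up to 8, so "aspect_ratio" yields 8 chunks of 1,024 rows and "aspect_ratio:2" yields 16 chunks of 512 rows.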