Adding ZUNA to Braindecode #1020
Open
jonathanhuml wants to merge 6 commits into
…ntions ZUNA worked CUDA-only because flex_attention is unstable for autograd on CPU; this collapses the whole API onto SDPA and trims the inference port to a maintainable size.

Model (braindecode/models/zuna.py, 861 -> 400 lines):
- Drop flex_attention path, BlockMask, _create_document_mask, packed multi-document layout, and the use_flex / attention_impl plumbing. SDPA per-batch is mathematically equivalent (each sample is its own document) and supports CPU + fp16 + autograd.
- Replace custom _RMSNorm with torch.nn.RMSNorm.
- Drop _ZUNAEncoderArgs dataclass, _build_zuna_encoder_args validation, InitStdFactor enum, and the six zero-arg config helpers in favour of module constants plus one patchable _encoder_config().
- Drop unused dataclass knobs: init_base_std, init_std_factor, encoder_hidden_dim, n_kv_heads, ffn_dim_multiplier, multiple_of, dropout_type / dropout_vec, encoder_latent_downsample_factor, tok_idx_type.
- Drop _apply_channel_mask + dropped_channels API (caller-side concern).
- Drop _discretize_channel_positions string extremes_type; inline the constant ZUNA_POS_HALF_RANGE = 0.12.
- Drop _repeat_kv / GQA plumbing (never used: n_kv_heads always equals n_heads in the published config).
- Cache positions from chs_info at construction; resolve once.
- Replace 4-lookup + cat RoPE gather with one indexed read + flatten.
- Bucket positions in fp32 so fp16 inference doesn't shift bucket edges.
- Distinguish 'no coords' vs 'names without montage' errors.
- Split forward shape validation so ndim / n_chans / n_times mismatches get targeted messages.

Param count matches Zyphra/ZUNA published checkpoint exactly (172,069,668 params for n_chans=22, n_outputs=4).

Tests (test/unit_tests/models/test_models.py, 290 -> 151 lines):
- Drop _zuna_forward inference_mode wrapper - no longer needed.
- Drop @pytest.mark.skipif(not zuna.HAS_FLEX) on every test.
- Drop redundant channel_mask / dropped_channels tests with the API.
- Drop _zuna_published_config snapshot test (superseded by inline constants).
- Parametrize n_times rejection over (1279, 1281).
- Add test_zuna_requires_montage_when_names_only.
- Drop CUDA-only skips in test_integration.py for ZUNA.
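The "split forward shape validation" item in the commit message above can be illustrated with a small, framework-free sketch. This is not the code in braindecode/models/zuna.py: the helper name `check_input_shape` is hypothetical, the `(batch, n_chans, n_times)` layout is assumed from Braindecode's usual convention, and `n_times=1280` is only inferred from the `(1279, 1281)` rejection test.

```python
def check_input_shape(shape, n_chans, n_times):
    """Raise a targeted ValueError for the first mismatch found in `shape`.

    Hypothetical helper: one check per failure mode, so the message names
    exactly which of ndim / n_chans / n_times is wrong.
    """
    # 1) Dimensionality: assumed layout is (batch, n_chans, n_times).
    if len(shape) != 3:
        raise ValueError(
            f"Expected 3 dims (batch, n_chans, n_times), got ndim={len(shape)}."
        )
    _, got_chans, got_times = shape
    # 2) Channel count must match what the model was built with.
    if got_chans != n_chans:
        raise ValueError(f"Expected n_chans={n_chans}, got {got_chans}.")
    # 3) Window length must match exactly (no silent cropping or padding).
    if got_times != n_times:
        raise ValueError(f"Expected n_times={n_times}, got {got_times}.")


# Usage: n_chans=22 comes from the commit message; n_times=1280 is an assumption.
check_input_shape((8, 22, 1280), n_chans=22, n_times=1280)  # passes silently
for bad_shape in [(22, 1280), (8, 21, 1280), (8, 22, 1279), (8, 22, 1281)]:
    try:
        check_input_shape(bad_shape, n_chans=22, n_times=1280)
    except ValueError as err:
        print(bad_shape, "->", err)
```

Checking the three conditions separately, instead of comparing the whole shape tuple at once, is what lets each error message name the offending dimension.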
SUMMARY
We propose to add ZUNA, a masked diffusion autoencoder trained to perform masked channel infilling and superresolution for arbitrary electrode numbers and positions in EEG signals.
While the original encoder-decoder model has 380M parameters, this port is a feature extractor that only exposes the latents and does not perform reconstruction. The implementation also includes basic support for channel masking and dropped-channel inference, either by channel index or by channel name. This is intended to preserve some of the practical behavior expected from a model trained around masked channel infilling.
By discarding the decoder, the total model size is about 170M parameters. We did this to keep the file as lightweight and readable as possible. We are happy to add the decoder back in later; let us know what Braindecode would benefit from most. We are currently training a new version and will likely submit another PR soon, so we can integrate this into the next version if desired.
We have tried to stay as close as possible to the requirements and base.py structure in Braindecode. Two design decisions are worth highlighting for the Braindecode team's feedback:
- ZUNA currently depends on PyTorch flex_attention. This is only available in PyTorch>=2.5, while Braindecode currently supports PyTorch>=2.0. To avoid making ZUNA break imports for users on older PyTorch versions, this PR uses a soft import pattern similar to the Hugging Face integration. As a result, Braindecode can still be imported normally, but instantiating or running ZUNA requires a compatible PyTorch version. ZUNA could theoretically support other attention variants at the cost of less efficient GPU usage, but we have implemented the flex-only variant for now.
- ZUNA is montage-agnostic but requires 3D channel positions during the forward pass. These positions are not inherent to the model weights, so this PR allows users to provide them dynamically rather than binding them at model initialization. When chs_info contains the necessary position metadata, we also support extracting the positions from there.
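The two design decisions can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the code in this PR: `soft_import`, `require_flex`, and `resolve_positions` are hypothetical names, and the sketch assumes each `chs_info` entry is an MNE-style dict whose `"loc"` field begins with the x, y, z coordinates, with NaN marking a missing coordinate.

```python
import importlib
import math


def soft_import(module_name, attr):
    """Return (obj, True) if `module_name.attr` is importable, else (None, False)."""
    try:
        module = importlib.import_module(module_name)
        return getattr(module, attr), True
    except (ImportError, AttributeError):
        return None, False


# Importing this file never fails; the flag only records availability.
flex_attention, HAS_FLEX = soft_import(
    "torch.nn.attention.flex_attention", "flex_attention"
)


def require_flex():
    """Hypothetical guard, meant to be called when the model is instantiated."""
    if not HAS_FLEX:
        raise ImportError(
            "ZUNA needs torch.nn.attention.flex_attention (PyTorch>=2.5)."
        )


def resolve_positions(ch_pos=None, chs_info=None):
    """Pick 3D positions: an explicit argument wins, otherwise read chs_info.

    Assumption: chs_info entries look like {"ch_name": ..., "loc": [x, y, z, ...]}.
    """
    if ch_pos is not None:
        return [tuple(float(v) for v in p) for p in ch_pos]
    if chs_info is None:
        raise ValueError("Pass ch_pos or a chs_info that carries channel locations.")
    positions = []
    for ch in chs_info:
        loc = ch.get("loc")
        if loc is None or len(loc) < 3:
            raise ValueError(f"Channel {ch.get('ch_name')!r} has no 3D location.")
        xyz = tuple(float(v) for v in loc[:3])
        if any(math.isnan(v) for v in xyz):
            raise ValueError(f"Channel {ch.get('ch_name')!r} has NaN coordinates.")
        positions.append(xyz)
    return positions


# Usage: positions passed at call time take precedence over chs_info.
info = [{"ch_name": "Cz", "loc": [0.0, 0.0, 0.1, 0.0, 0.0, 0.0]}]
print(resolve_positions(chs_info=info))
print(resolve_positions(ch_pos=[[1, 2, 3]], chs_info=info))
```

Deferring the availability check to `require_flex()` keeps a plain package import working on older PyTorch versions, and letting explicit positions override `chs_info` matches the idea that positions are supplied dynamically rather than bound to the weights.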