Adding ZUNA to Braindecode #1020

Open
jonathanhuml wants to merge 6 commits into braindecode:master from jonathanhuml:zuna

SUMMARY

We propose to add ZUNA, a masked diffusion autoencoder trained to perform masked channel infilling and superresolution for arbitrary electrode numbers and positions in EEG signals.

While the original encoder-decoder model has 380M parameters, this port is a feature extractor that exposes only the latents and does not perform reconstruction. The implementation also includes basic support for channel masking and dropped-channel inference, either by channel index or by channel name, to preserve some of the practical behavior expected from a model trained on masked channel infilling.

Discarding the decoder brings the total model size to about 170M parameters. We did this to keep the file as lightweight and readable as possible, and we are happy to add the decoder back later: let us know what Braindecode would benefit from most! We are currently training a new version and will likely submit another PR soon, so we can integrate this into the next version if desired.

We have tried to stay as close as possible to Braindecode's requirements and the base.py structure. Two design decisions are worth highlighting for feedback from the Braindecode team:

- ZUNA currently depends on PyTorch flex_attention, which is only available in PyTorch >=2.5, while Braindecode currently supports PyTorch >=2.0. To keep ZUNA from breaking imports for users on older PyTorch versions, this PR uses a soft import pattern similar to the Hugging Face integration. As a result, Braindecode can still be imported normally, but instantiating or running ZUNA requires a compatible PyTorch version. ZUNA could in principle support other attention variants at the cost of less efficient GPU usage, but we have implemented only the flex variant for now.
- ZUNA is montage-agnostic but requires 3D channel positions during the forward pass. These positions are not part of the model weights, so this PR lets users provide them dynamically rather than binding them at model initialization. When chs_info contains the necessary position metadata, we also support extracting the positions from there.
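
As a rough illustration of the soft import pattern described above, here is a minimal sketch. It is not the code in this PR: the helper `_soft_import_flex_attention` and the class `FlexOnlyModel` are hypothetical names, and only the module path `torch.nn.attention.flex_attention` (PyTorch >=2.5) and the `HAS_FLEX` flag name are taken from PyTorch and from this PR's tests respectively.

```python
import importlib


def _soft_import_flex_attention():
    """Return torch's flex_attention callable, or None when it is unavailable.

    Hypothetical helper: covers both "torch is not installed" and
    "torch is older than 2.5" without raising at import time.
    """
    try:
        module = importlib.import_module("torch.nn.attention.flex_attention")
    except ImportError:
        return None
    return getattr(module, "flex_attention", None)


_flex_attention = _soft_import_flex_attention()
HAS_FLEX = _flex_attention is not None


class FlexOnlyModel:
    """Hypothetical stand-in for a model that needs flex_attention."""

    def __init__(self):
        # Fail at instantiation, not when the package is imported.
        if not HAS_FLEX:
            raise ImportError(
                "This model requires torch.nn.attention.flex_attention "
                "(PyTorch >= 2.5)."
            )
        self.attention = _flex_attention
```

With this layout, importing the module always succeeds, and the version requirement only surfaces when a user tries to build the model.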

jonathanhuml and others added 6 commits May 13, 2026 12:18
…ntions

ZUNA was CUDA-only because flex_attention is unstable for autograd on
CPU; this commit moves the whole API onto SDPA and trims the inference
port to a maintainable size.

Model (braindecode/models/zuna.py, 861 -> 400 lines):
- Drop flex_attention path, BlockMask, _create_document_mask, packed
  multi-document layout, and the use_flex / attention_impl plumbing.
  SDPA per-batch is mathematically equivalent (each sample is its own
  document) and supports CPU + fp16 + autograd.
- Replace custom _RMSNorm with torch.nn.RMSNorm.
- Drop _ZUNAEncoderArgs dataclass, _build_zuna_encoder_args validation,
  InitStdFactor enum, and the six zero-arg config helpers in favour of
  module constants plus one patchable _encoder_config().
- Drop unused dataclass knobs: init_base_std, init_std_factor,
  encoder_hidden_dim, n_kv_heads, ffn_dim_multiplier, multiple_of,
  dropout_type / dropout_vec, encoder_latent_downsample_factor,
  tok_idx_type.
- Drop _apply_channel_mask + dropped_channels API (caller-side concern).
- Drop _discretize_channel_positions string extremes_type; inline the
  constant ZUNA_POS_HALF_RANGE = 0.12.
- Drop _repeat_kv / GQA plumbing (never used: n_kv_heads always equals
  n_heads in the published config).
- Cache positions from chs_info at construction; resolve once.
- Replace 4-lookup + cat RoPE gather with one indexed read + flatten.
- Bucket positions in fp32 so fp16 inference doesn't shift bucket edges.
- Distinguish 'no coords' vs 'names without montage' errors.
- Split forward shape validation so ndim / n_chans / n_times mismatches
  get targeted messages.
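
The split shape validation in the last item could look roughly like the sketch below. It is an illustration only: `validate_input_shape` is a hypothetical name, the input is assumed to follow Braindecode's usual (batch, n_chans, n_times) layout, and the value n_times=1280 in the usage lines is inferred from the rejected lengths (1279, 1281) in the test list rather than stated anywhere in this PR.

```python
def validate_input_shape(shape, n_chans, n_times):
    """Check a (batch, n_chans, n_times) shape, one targeted error per mismatch.

    Hypothetical helper; takes a plain shape tuple so it runs without torch.
    """
    if len(shape) != 3:
        raise ValueError(
            f"ndim mismatch: expected 3 dims (batch, n_chans, n_times), "
            f"got {len(shape)} with shape {tuple(shape)}."
        )
    _, got_chans, got_times = shape
    if got_chans != n_chans:
        raise ValueError(
            f"n_chans mismatch: model was built for {n_chans} channels, "
            f"got {got_chans}."
        )
    if got_times != n_times:
        raise ValueError(
            f"n_times mismatch: model expects {n_times} samples, "
            f"got {got_times}."
        )


# Usage: a matching shape passes silently (n_times=1280 is an assumption).
validate_input_shape((8, 22, 1280), n_chans=22, n_times=1280)

# An off-by-one length is rejected with the n_times-specific message.
try:
    validate_input_shape((8, 22, 1279), n_chans=22, n_times=1280)
except ValueError as err:
    print(err)
```

Checking the three conditions separately means a user who passes a 2D array sees a different message from one who passes the wrong window length.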

Param count matches Zyphra/ZUNA published checkpoint exactly
(172,069,668 params for n_chans=22, n_outputs=4).

Tests (test/unit_tests/models/test_models.py, 290 -> 151 lines):
- Drop _zuna_forward inference_mode wrapper - no longer needed.
- Drop @pytest.mark.skipif(not zuna.HAS_FLEX) on every test.
- Drop channel_mask / dropped_channels tests, redundant now that the API is gone.
- Drop _zuna_published_config snapshot test (superseded by inline
  constants).
- Parametrize n_times rejection over (1279, 1281).
- Add test_zuna_requires_montage_when_names_only.
- Drop CUDA-only skips in test_integration.py for ZUNA.