Adding Batched Tensor Dot + Simplifying API #309

Merged: JeanKossaifi merged 12 commits into tensorly:main from JeanKossaifi:main on Aug 26, 2021

JeanKossaifi (Member) commented on Aug 23, 2021

Batched tensor-dot

This PR adds a new batched_tensor_dot function. It extends the signature of the standard tensordot function by adding a batched_modes parameter. It also fixes #250.

The signature is

batched_tensor_dot(tensor1, tensor2, contraction_modes, batched_modes)
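As a rough illustration of what the two mode arguments mean, the sketch below computes only the shape of the result, following the mode ordering used by the einsum version in this PR: the non-contracted modes of tensor1 stay in place (batch modes included), followed by the free modes of tensor2. The helper name `batched_tensordot_shape` and its explicit `modes1`/`modes2` arguments are hypothetical, and non-negative mode indices are assumed.

```python
def batched_tensordot_shape(shape1, shape2, modes1, modes2,
                            batch_modes1=(), batch_modes2=()):
    """Result shape of a batched tensordot (hypothetical helper, not TensorLy API).

    Contracted modes disappear; batch modes are shared by both tensors
    and appear once in the output.
    """
    modes1, modes2 = list(modes1), list(modes2)
    batch_modes1, batch_modes2 = list(batch_modes1), list(batch_modes2)

    # Paired modes (contracted or batched) must have matching sizes
    for m1, m2 in zip(modes1 + batch_modes1, modes2 + batch_modes2):
        if shape1[m1] != shape2[m2]:
            raise ValueError(
                f"Size mismatch: mode {m1} of tensor1 has size {shape1[m1]}, "
                f"mode {m2} of tensor2 has size {shape2[m2]}"
            )

    # Everything in tensor1 except the contracted modes is kept, in order
    kept1 = [s for i, s in enumerate(shape1) if i not in modes1]
    # Only the modes of tensor2 that are neither contracted nor batched remain
    free2 = [s for i, s in enumerate(shape2)
             if i not in modes2 and i not in batch_modes2]
    return tuple(kept1 + free2)


# Batched matrix product: (8, 3, 5) x (8, 5, 4), batch on mode 0
print(batched_tensordot_shape((8, 3, 5), (8, 5, 4), [2], [1], [0], [0]))  # (8, 3, 4)
```

With empty batch modes this reduces to the usual tensordot shape rule.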

API simplification

Using the new batched-tensordot, we can simplify the overall API, so I removed the existing outer, contract, tensor_dot and batched_tensor_dot. These are all encompassed by the new function.

I generalized the existing outer function and added batched_outer; both now support lists of tensors of arbitrary shapes. I kept inner as a convenience for end users, though I think it could also be removed.
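To make the generalized outer product concrete, here is a minimal pure-Python sketch operating on nested lists. The helper names `_scale` and `_outer2` are illustrative only, and this is not TensorLy's implementation; treating the first mode as the batch mode in `batched_outer` is an assumption.

```python
from functools import reduce
from numbers import Number


def _scale(scalar, tensor):
    """Multiply every entry of a nested-list tensor by a scalar."""
    if isinstance(tensor, Number):
        return scalar * tensor
    return [_scale(scalar, sub) for sub in tensor]


def _outer2(a, b):
    """Outer product of two nested-list tensors; result shape is shape(a) + shape(b)."""
    if isinstance(a, Number):
        return _scale(a, b)
    return [_outer2(sub, b) for sub in a]


def outer(tensors):
    """Outer product of a list of tensors of arbitrary shapes."""
    return reduce(_outer2, tensors)


def batched_outer(tensors):
    """Outer product per sample, assuming the first mode of every tensor is the batch."""
    # zip(*tensors) groups the i-th sample of every tensor together
    return [outer(samples) for samples in zip(*tensors)]


print(outer([[1, 2], [3, 4, 5]]))                          # [[3, 4, 5], [6, 8, 10]]
print(batched_outer([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))  # [[[5, 6], [10, 12]], [[21, 24], [28, 32]]]
```

In the second call each input has shape (2, 2) and the result has shape (2, 2, 2): the batch mode is shared rather than multiplied out.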

Implementation

I've written a few versions and run some comparisons. For the einsum tenalg backend there's an einsum version, and I've also tried a few versions using matmul or even just broadcasting + sum.
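The einsum version listed below works by giving every mode a letter and then reusing tensor1's letters for the paired modes of tensor2. The following standalone sketch isolates just that string-building step so the resulting equations can be inspected without any tensor library. The function name `build_einsum_equation` is hypothetical; it takes already-validated, non-negative mode lists.

```python
def build_einsum_equation(ndim1, ndim2, modes1, modes2,
                          batch_modes1=(), batch_modes2=()):
    """Build the einsum equation for a batched tensordot (illustrative helper)."""
    modes1, modes2 = list(modes1), list(modes2)
    batch_modes1, batch_modes2 = list(batch_modes1), list(batch_modes2)

    start = ord('a')
    # One distinct letter per mode of each tensor
    labels1 = [chr(start + i) for i in range(ndim1)]
    labels2 = [chr(start + ndim1 + i) for i in range(ndim2)]

    # Contracted and batched modes share a letter across the two tensors
    for m1, m2 in zip(modes1 + batch_modes1, modes2 + batch_modes2):
        labels2[m2] = labels1[m1]

    # Output: tensor1 minus contracted modes, then the free modes of tensor2
    out = [label for i, label in enumerate(labels1) if i not in modes1]
    out += [label for i, label in enumerate(labels2)
            if i not in modes2 + batch_modes2]

    return f"{''.join(labels1)},{''.join(labels2)}->{''.join(out)}"


print(build_einsum_equation(2, 2, [1], [0]))            # ab,bd->ad
print(build_einsum_equation(3, 3, [2], [1], [0], [0]))  # abc,acf->abf
```

The first call is a plain matrix product; in the second, the shared letter `a` appears in the output, which is what makes mode 0 a batch mode rather than a contracted one.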

def batched_tensor_dot_einsum(tensor1, tensor2, modes, batched_modes=()):
    """Tensor contraction between two tensors on specified modes
    
    Parameters
    ----------
    tensor1 : tl.tensor
    tensor2 : tl.tensor
    modes : int list or int
        modes on which to contract tensor1 and tensor2
    batched_modes : int or tuple[int]
        modes to treat as batch dimensions: matched between tensor1 and tensor2 and kept in the output

    Returns
    -------
    contraction : tensor1 contracted with tensor2 on the specified modes
    """
    modes1, modes2 = _validate_contraction_modes(tensor1.shape, tensor2.shape, modes)
    batch_modes1, batch_modes2 = _validate_contraction_modes(tensor1.shape, tensor2.shape, batched_modes, batched_modes=True)

    start = ord('a')
    order_t1 = tl.ndim(tensor1)
    all_modes1 = [chr(start+i) for i in range(order_t1)]
    all_modes2 = [chr(start+i+order_t1) for i in range(tl.ndim(tensor2))]

    for m1, m2 in zip(modes1+batch_modes1, modes2+batch_modes2):
        all_modes2[m2] = all_modes1[m1]
    
    remaining_modes1 = [j for i, j in enumerate(all_modes1) if i not in modes1]
    remaining_modes2 = [j for i, j in enumerate(all_modes2) if i not in modes2+batch_modes2]
    remaining_modes = remaining_modes1 + remaining_modes2
    to_str = lambda x : ''.join(x)
    equation = f'{to_str(all_modes1)},{to_str(all_modes2)}->{to_str(remaining_modes)}'
    
    return tl.einsum(equation, tensor1, tensor2)


def batched_tensor_dot_matmul1(tensor1, tensor2, modes, batched_modes=()):
    modes1, modes2 = _validate_contraction_modes(tensor1.shape, tensor2.shape, modes)
    batch_modes1, batch_modes2 = _validate_contraction_modes(tensor1.shape, tensor2.shape, batched_modes, batched_modes=True)
    
    contraction_shape = [s for (i, s) in enumerate(tl.shape(tensor1)) if i in modes1]
    contraction_dim = prod(contraction_shape)
    batch_shape = [s for (i, s) in enumerate(tl.shape(tensor1)) if i in batch_modes1]
    
    # Prepare to reorganize the modes afterwards by moving the batch modes back to their place
    # (while omitting the modes contracted over)
    final_modes = []
    n_batches = len(batch_modes1)
    batch_counter = 0
    free_counter = 0
    for i in range(tl.ndim(tensor1)):
        if i in modes1:
            continue
        elif i in batch_modes1:
            final_modes.append(batch_counter)
            batch_counter += 1
        else:
            final_modes.append(free_counter+n_batches)
            free_counter += 1

    # We will reorganize tensor1 to (batch_modes, new_modes1, contraction_modes)
    new_modes1 = [i for i in range(tensor1.ndim) if i not in batch_modes1+modes1]
    new_shape1 = [tl.shape(tensor1)[i] for i in new_modes1]
    tensor1 = tl.transpose(tensor1, batch_modes1 + new_modes1 + modes1)
    tensor1 = tl.reshape(tensor1, (*batch_shape, -1, contraction_dim))
    
    # Tensor2 will be (batch_modes, contraction_modes, new_modes2)
    new_modes2 = [i for i in range(tensor2.ndim) if i not in batch_modes2+modes2]
    new_shape2 = [tl.shape(tensor2)[i] for i in new_modes2]
    tensor2 = tl.transpose(tensor2, batch_modes2+modes2+new_modes2)
    tensor2 = tl.reshape(tensor2, (*batch_shape, contraction_dim, -1))
    
    res = tl.matmul(tensor1, tensor2)
    res = tl.reshape(res, (*batch_shape, *new_shape1, *new_shape2))

    final_modes += [i for i in range(res.ndim) if i not in final_modes]
    
    if final_modes:
        res = tl.transpose(res, final_modes)

    return res

def batched_tensor_dot_matmul2(tensor1, tensor2, modes, batched_modes=()):
    modes1, modes2 = _validate_contraction_modes(tensor1.shape, tensor2.shape, modes)
    batch_modes1, batch_modes2 = _validate_contraction_modes(tensor1.shape, tensor2.shape, batched_modes, batched_modes=True)

    contraction_shape = [s for (i, s) in enumerate(tl.shape(tensor1)) if i in modes1]
    contraction_dim = prod(contraction_shape)
    batch_shape = [s for (i, s) in enumerate(tl.shape(tensor1)) if i in batch_modes1]
    
    n_free = tl.ndim(tensor1) - len(modes1) - len(batch_modes1)
    # We will reorganize tensor1 by just moving the contraction modes to the end
    modes_begin1 = []
    modes_end1 = []
    shape_begin2 = []
    last_mode_is_batched = False
    last_mode = tl.ndim(tensor1) - 1
    for i, s in enumerate(tl.shape(tensor1)):
        if i in batch_modes1:
            modes_begin1.append(i)
            shape_begin2.append(s)
            last_mode_is_batched = True
        elif i in modes1:
            modes_end1.append(i)
        else:
            modes_begin1.append(i)
            if i != last_mode or n_free:
                shape_begin2.append(1)
            last_mode_is_batched = False
    tensor1 = tl.transpose(tensor1, modes_begin1+modes_end1)

    n_modes_1 = tl.ndim(tensor1) - len(modes1)
    shape = list(tl.shape(tensor1))[:n_modes_1]
    if last_mode_is_batched:
        shape += [1]
    elif shape_begin2:
        shape_begin2.pop(-1)

    tensor1 = tl.reshape(tensor1, (*shape, contraction_dim))
    
    # these are neither batch-size nor contraction modes: put them last
    new_modes2 = [i for i in range(tensor2.ndim) if i not in batch_modes2+modes2]
    new_shape2 = [tl.shape(tensor2)[i] for i in new_modes2]
    tensor2 = tl.transpose(tensor2, batch_modes2+modes2+new_modes2)

    squeeze_last = not new_modes2
    tensor2 = tl.reshape(tensor2, (*shape_begin2, contraction_dim, -1))
    
    res = tl.matmul(tensor1, tensor2)
    
    out_shape = list(tl.shape(res))
    if squeeze_last:
        out_shape = out_shape[:-1]
    else:
        out_shape = out_shape[:-1] + new_shape2
    if last_mode_is_batched:
        out_shape.pop(n_modes_1)
    
    return tl.reshape(res, out_shape)

def batched_tensor_dot_sum(tensor1, tensor2, modes, batched_modes=()):
    modes1, modes2 = _validate_contraction_modes(tensor1.shape, tensor2.shape, modes)
    batch_modes1, batch_modes2 = _validate_contraction_modes(tensor1.shape, tensor2.shape, batched_modes, batched_modes=True)
    
    # Pad tensor1 with trailing singleton dimensions, one per free mode of tensor2.
    # tensor2 is then permuted so its batch and contraction modes line up with tensor1's, with its free modes last.
    shape1 = tensor1.shape + (1, )*(tensor2.ndim - len(modes2+batch_modes2))
    
    shape2 = []
    permute_modes2 = []
    batch_counter = 0
    contraction_counter = 0
    for i in range(tensor1.ndim):
        if i in batch_modes1:
            index = batch_modes2[batch_counter]
            permute_modes2.append(index)
            batch_counter += 1
            shape2.append(tensor2.shape[index])
        elif i in modes1:
            index = modes2[contraction_counter]
            permute_modes2.append(index)
            contraction_counter += 1
            shape2.append(tensor2.shape[index])
        else:
            shape2.append(1)
    permute_modes2 += [i for i in range(tl.ndim(tensor2)) if i not in permute_modes2]
    new_shape2 = [s for (i, s) in enumerate(tl.shape(tensor2)) if i not in batch_modes2+modes2]
    shape2 += new_shape2

    tensor1 = tl.reshape(tensor1, shape1)
    tensor2 = tl.transpose(tensor2, permute_modes2)
    tensor2 = tl.reshape(tensor2, shape2)
    res = tensor1 * tensor2

    for mode in sorted(modes1, reverse=True):
        res = tl.sum(res, axis=mode)

    return res

Note that the last implementation is mostly a proof of concept for tests and is very memory inefficient.
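A back-of-the-envelope sketch of why the broadcasting + sum version is memory hungry: before summing, the elementwise product materializes a tensor whose shape is all of tensor1's modes followed by tensor2's free modes, so it is larger than the result by a factor equal to the product of the contracted dimensions. The helper `broadcast_sum_sizes` is hypothetical and only counts elements.

```python
from math import prod


def broadcast_sum_sizes(shape1, shape2, modes1, modes2, batch_modes2=()):
    """Element counts of the broadcast intermediate and of the final result."""
    # Free modes of tensor2: neither contracted nor batched
    free2 = [s for i, s in enumerate(shape2)
             if i not in modes2 and i not in batch_modes2]
    # tensor1 * tensor2 broadcasts to shape1 + free2 before any summation
    intermediate = prod(shape1) * prod(free2)
    # Summing over the contracted modes of tensor1 gives the result
    kept1 = [s for i, s in enumerate(shape1) if i not in modes1]
    result = prod(kept1) * prod(free2)
    return intermediate, result


# Batched matrix product (8, 64, 128) x (8, 128, 32), contracting the size-128 mode
intermediate, result = broadcast_sum_sizes(
    (8, 64, 128), (8, 128, 32), modes1=[2], modes2=[1], batch_modes2=[0]
)
print(intermediate, result, intermediate // result)  # 2097152 16384 128
```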

Performance

I ran some timings on a GPU with the pytorch backend, in several configurations:

Block TT contraction

This use-case happens when manipulating Block TT decompositions.

[Figure: tensor-train]

Other configurations

It seems that the einsum and matmul1 versions are the best ones overall.

Some more timings:
[Figures labelled: Brrr,Brrr,-1,2; Brrr,Brrr,1,2,3; Brrr,Brrr,1,2; rBrr,rrBr,(2,-1),(1,-1),1,2; rBrr,rrBr,(2,-1),(1,0),1,2]

codecov (Bot) commented on Aug 23, 2021

Codecov Report

Merging #309 (5ef28c4) into main (da8c57f) will increase coverage by 4.67%.
The diff coverage is 91.37%.

@@            Coverage Diff             @@
##             main     #309      +/-   ##
==========================================
+ Coverage   85.07%   89.74%   +4.67%     
==========================================
  Files          95       97       +2     
  Lines        5125     5217      +92     
==========================================
+ Hits         4360     4682     +322     
+ Misses        765      535     -230     
| Impacted Files | Coverage Δ |
|---|---|
| tensorly/tenalg/einsum_tenalg/__init__.py | 100.00% <ø> (ø) |
| tensorly/utils/__init__.py | 100.00% <ø> (ø) |
| ...ensorly/tenalg/einsum_tenalg/_batched_tensordot.py | 17.64% <17.64%> (ø) |
| tensorly/tenalg/core_tenalg/moments.py | 42.85% <50.00%> (ø) |
| tensorly/utils/_prod.py | 71.42% <71.42%> (ø) |
| tensorly/tenalg/tenalg_utils.py | 94.73% <94.73%> (ø) |
| tensorly/tenalg/core_tenalg/outer_product.py | 96.42% <96.29%> (+5.51%) ⬆️ |
| tensorly/tenalg/__init__.py | 77.77% <100.00%> (-0.35%) ⬇️ |
| tensorly/tenalg/core_tenalg/__init__.py | 100.00% <100.00%> (ø) |
| tensorly/tenalg/core_tenalg/_batched_tensordot.py | 100.00% <100.00%> (ø) |
| ... and 22 more | |

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Comment thread tensorly/tenalg/core_tenalg/_batched_tensor_dot.py

Parameters
----------
modes : int or tuple[int] or (modes1, modes2)

Member

How do you distinguish between the second and third case?

a_shape  = (5, 4, 5)
b_shape  = (8, 4, 5)
_validate_contraction_modes(a_shape, b_shape, (1,2))

currently returns an error. I think we should remove the second case from the docstring. This will also keep it consistent with tl.tensordot.

Member Author

I was hesitant about this very use case. It is convenient to be able to pass in a single tuple, but the case where it has length 2 is indeed ambiguous, so I left the error. It might indeed be best to keep it consistent with tensordot.

from math import prod


def batched_tensor_dot(tensor1, tensor2, modes, batched_modes=()):

Member

I think we should rename it to batched_tensordot, both to be consistent with tensordot and to differentiate it from the removed tl.tenalg.batched_tensor_dot, which was an outer product.

Member Author

I'm conflicted: I have implemented other versions for factorized tensors, e.g. tensor_dot_tucker or tucker_dot_tucker. It would be more consistent to use x_dot_y everywhere. I don't think differentiating it from the previous one is necessary: we can document the change, and it wasn't used anywhere in the library anyway.

However, being consistent with NumPy might be reason enough. If we go for batched_tensordot, we could actually consider just calling it tensordot, with an additional batched_modes parameter.

JeanKossaifi (Member, Author)

So the two outstanding questions would be:

Naming (batched_tensor_dot or batched_tensordot)

batched_tensor_dot

I chose this name for consistency with other functions I am writing for factorized tensors, e.g. tensor_dot_tucker or tucker_dot_tucker. It might be more consistent to use x_dot_y everywhere. I don't mind the name clash with the previous function we had in TensorLy.

batched_tensordot

Consistency with NumPy might be reason enough to use this name. If we go for batched_tensordot, we could actually consider just calling it tensordot, with an additional batched_modes parameter.

Case where modes is a single tuple or list

This isn't allowed by NumPy but I allowed it.
If there is a single tuple, it is assumed the modes are the same for both tensors.

There is an ambiguity when the tuple is of length 2, e.g.

tensordot(tensor1, tensor2, [a, b])

Does the above mean we are contracting tensor1's mode a with tensor2's mode b, or contracting tensor1 and tensor2 along the same modes a and b? Currently this case raises an error.
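The behaviour described above can be sketched as a small normalization helper. This is a hypothetical stand-in written only to illustrate the rules discussed in this thread; it is not the `_validate_contraction_modes` function from the PR, and it ignores shape checks and negative indices.

```python
def normalize_modes(modes):
    """Turn a `modes` argument into (modes1, modes2) lists (illustrative only)."""
    # A single int: contract the same mode of both tensors
    if isinstance(modes, int):
        return [modes], [modes]

    modes = list(modes)
    if len(modes) == 2:
        first, second = modes
        if isinstance(first, int) and isinstance(second, int):
            # (a, b) could mean "mode a with mode b" or "modes a and b of both"
            raise ValueError(
                f"Ambiguous modes {tuple(modes)}: pass ([a], [b]) to pair mode a "
                f"with mode b, or ([a, b], [a, b]) to contract both modes"
            )
        # Explicit (modes1, modes2) pair, as accepted by numpy.tensordot
        as_list = lambda m: [m] if isinstance(m, int) else list(m)
        return as_list(first), as_list(second)

    # Any other flat sequence: the same modes are used for both tensors
    return modes, list(modes)


print(normalize_modes(1))           # ([1], [1])
print(normalize_modes((0, 1, 2)))   # ([0, 1, 2], [0, 1, 2])
print(normalize_modes(([2], [1])))  # ([2], [1])
```

Dropping the "same modes for both tensors" shorthand, as suggested in the review, would remove the error branch entirely and leave only the int and explicit-pair cases.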

JeanKossaifi (Member, Author)

Renamed the function to tensordot. Merging for now; we can modify the name and API as needed.

@JeanKossaifi JeanKossaifi merged commit c6074c7 into tensorly:main Aug 26, 2021

Development

Successfully merging this pull request may close these issues.

tensordot vs tensor_dot vs contract vs inner