# Any examples of `algorithm.reduction.argmax` usage?

frank cave

I am writing a fast `where` function, and from its description `argwhere` looks like a good candidate, but I can't figure out how to use it. For example, what is `OutputChainPtr`?
```mojo
@always_inline
fn where(self, val: Int8, vec: DTypePointer[DType.int8]):
    @parameter
    fn w[nelts: Int](i: Int):
        var res = (self.values.simd_load[nelts](i) == val)
        # algorithm.reduction.argmax
    algorithm.vectorize[nelts, w](ar_size * ar_size)
```
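The comparison inside `w` produces a per-lane boolean mask, and an argwhere-style result still needs a step that turns those masks into indices. A plain-Python sketch of that logic (not Mojo API; `where_chunked` and the default chunk width of 8 are hypothetical), where each slice of `nelts` elements stands in for one SIMD load:

```python
def where_chunked(values, val, nelts=8):
    """Return every index i with values[i] == val, scanning nelts elements at a time."""
    indices = []
    n = len(values)
    for start in range(0, n, nelts):
        chunk = values[start:start + nelts]  # the last chunk may be shorter
        # Lane-wise comparison, analogous to (simd_load[nelts](i) == val).
        mask = [x == val for x in chunk]
        # Fast path: skip chunks in which no lane matched.
        if not any(mask):
            continue
        # Compact the matching lanes into global indices.
        for lane, hit in enumerate(mask):
            if hit:
                indices.append(start + lane)
    return indices
```

For example, `where_chunked([1, 2, 1, 3, 1], 1, nelts=2)` returns `[0, 2, 4]`. The mask-to-index compaction is the part a single `argmax` call does not provide.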

frank cave

@rich sundial Hi, any help here? I have spent a week trying to find a way to use SIMD for tasks like `argwhere`.

frank cave

Hi @rich sundial, no news so far?

rich sundial

Hi @frank cave, it's very low-level at the moment:

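```mojo
from memory.buffer import NDBuffer
from runtime.llcl import Runtime, OwningOutputChainPtr
from algorithm.reduction import argmax
from utils.list import DimList

fn main():
    alias size = 42

    # 1-D input buffer of `size` int32 values, plus a 1-element output
    # buffer that receives the index of the maximum.
    let vector = NDBuffer[1, DimList(size), DType.int32].stack_allocation()
    let output = NDBuffer[1, DimList(1), DType.index].stack_allocation()

    for i in range(size):
        vector[i] = i

    with Runtime() as runtime:
        # The output chain tracks completion of the reduction.
        let out_chain = OwningOutputChainPtr(runtime)
        argmax(
            # Rebind the statically shaped buffers to dynamically shaped ones.
            rebind[NDBuffer[1, DimList.create_unknown[1](), DType.int32]](vector),
            0,  # axis to reduce along
            rebind[NDBuffer[1, DimList.create_unknown[1](), DType.index]](output),
            out_chain.borrow(),
        )
        out_chain.wait()  # block until the result has been written
        print("argmax:", output[0])
```
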
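Here the reduction runs along axis 0 of a 1-D buffer, which is why `output` holds a single element. A plain-Python sketch of the general shape rule (hypothetical helper `argmax_axis`, not Mojo API): reducing a 2-D array along one axis yields one index per slice along the other axis. Ties resolve to the first occurrence because that is how Python's `max` behaves; the sketch makes no claim about how Mojo's `argmax` breaks ties.

```python
def argmax_axis(rows, axis):
    """argmax of a 2-D list of lists along `axis` (0 or 1): one index per slice."""
    if axis == 1:
        # One result per row: position of the largest element within that row.
        return [max(range(len(r)), key=r.__getitem__) for r in rows]
    # axis == 0: one result per column, the row holding that column's maximum.
    ncols = len(rows[0])
    return [max(range(len(rows)), key=lambda i: rows[i][c]) for c in range(ncols)]
```

For `[[1, 5, 2], [7, 0, 7]]`, reducing along axis 1 gives `[1, 0]` (one index per row) and along axis 0 gives `[1, 0, 1]` (one index per column).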
frank cave

Thanks, @rich sundial. So does this function return only a single element? The documentation says it "Finds the indices of the maximum element along the specified axis." I was expecting it to return multiple indices when the maximum element is repeated, which is why I'm interested in this function.
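Since an argmax along an axis produces one index per slice, collecting every position of a repeated maximum generally takes a different approach. One common approach, sketched in plain Python (hypothetical helper `argmax_all`, not a Mojo API), is two passes: reduce to the maximum value, then scan for all elements equal to it. In Mojo, the second pass is where a vectorized comparison like the one in the `where` function could apply.

```python
def argmax_all(values):
    """Return every index at which the maximum of `values` occurs."""
    if not values:
        return []
    # Pass 1: a max reduction yields the value, not its positions.
    m = max(values)
    # Pass 2: an argwhere-style scan collects all positions equal to it.
    return [i for i, x in enumerate(values) if x == m]
```

For example, `argmax_all([3, 7, 1, 7])` returns `[1, 3]`, whereas a single-index argmax would report only one of those positions.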