# Build configuration for the "activation" kernel package.
#
# Layout:
#   [general]                 package-level settings
#   [torch]                   sources for the Torch binding layer
#   [kernel.activation]       kernel sources built for the ROCm backend
#   [kernel.activation_cuda]  the same kernel sources built for the CUDA backend
#
# NOTE(review): the table layout looks like a kernel-builder style build file;
# confirm the exact schema against the tool that consumes this file.

[general]
name = "activation"
# NOTE(review): presumably false because the package ships compiled kernels
# rather than being backend-independent; confirm against the build tool docs.
universal = false

# Torch extension binding sources (C++ source plus its header).
[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h",
]

# ROCm build of the activation kernels.
# The source list is identical to [kernel.activation_cuda]; keep both in sync
# when adding or removing files.
[kernel.activation]
backend = "rocm"
# ROCm GPU architecture targets for this build.
rocm-archs = ["gfx90a", "gfx942"]
src = [
  "activation/poly_norm.cu",
  "activation/fused_mul_poly_norm.cu",
  "activation/rms_norm.cu",
  "activation/fused_add_rms_norm.cu",
  "activation/cuda_compat.h",
  "activation/dispatch_utils.h",
  "activation/assert_utils.h",
  "activation/atomic_utils.h",
]
depends = ["torch"]

# CUDA build of the activation kernels.
# The source list is identical to [kernel.activation]; keep both in sync
# when adding or removing files.
[kernel.activation_cuda]
backend = "cuda"
src = [
  "activation/poly_norm.cu",
  "activation/fused_mul_poly_norm.cu",
  "activation/rms_norm.cu",
  "activation/fused_add_rms_norm.cu",
  "activation/cuda_compat.h",
  "activation/dispatch_utils.h",
  "activation/assert_utils.h",
  "activation/atomic_utils.h",
]
depends = ["torch"]