| | #include "torch_binding.h" |
| |
|
| | #include <torch/library.h> |
| |
|
| | #include "registration.h" |
| |
|
| | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { |
| | |
| | ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, " |
| | "float eps) -> ()"); |
| | ops.impl("poly_norm", torch::kCUDA, &poly_norm); |
| |
|
| | ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! " |
| | "bias_grad, Tensor output_grad, Tensor input, Tensor weight, float " |
| | "eps) -> ()"); |
| | ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward); |
| |
|
| | |
| | ops.def( |
| | "rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()"); |
| | ops.impl("rms_norm", torch::kCUDA, &rms_norm); |
| |
|
| | ops.def("rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor " |
| | "output_grad, Tensor input, Tensor weight, float eps) -> ()"); |
| | ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward); |
| |
|
| | |
| | ops.def("fused_mul_poly_norm(Tensor! out, Tensor input, Tensor mul, Tensor " |
| | "weight, Tensor bias, " |
| | "float eps) -> ()"); |
| | ops.impl("fused_mul_poly_norm", torch::kCUDA, &fused_mul_poly_norm); |
| |
|
| | ops.def("fused_mul_poly_norm_backward(Tensor! input_grad, Tensor! mul_grad, " |
| | "Tensor! weight_grad, Tensor! " |
| | "bias_grad, Tensor output_grad, Tensor input, Tensor mul, Tensor " |
| | "weight, Tensor " |
| | "bias, float eps) -> ()"); |
| | ops.impl("fused_mul_poly_norm_backward", torch::kCUDA, |
| | &fused_mul_poly_norm_backward); |
| |
|
| | |
| | ops.def( |
| | "fused_add_rms_norm(Tensor! out, Tensor! add_out, Tensor input, Tensor " |
| | "residual, Tensor " |
| | "weight, float eps) -> ()"); |
| | ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); |
| |
|
| | ops.def( |
| | "fused_add_rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, " |
| | "Tensor " |
| | "output_grad, Tensor add_output_grad, Tensor input, Tensor weight, float " |
| | "eps) -> ()"); |
| | ops.impl("fused_add_rms_norm_backward", torch::kCUDA, |
| | &fused_add_rms_norm_backward); |
| | } |
| |
|
// Expands the REGISTER_EXTENSION macro for TORCH_EXTENSION_NAME. The macro is
// not defined in this file; presumably it comes from registration.h.
// NOTE(review): it likely emits the Python module entry point so the compiled
// library can be imported. Confirm against registration.h.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
| |
|