Commit 76aa810 · Parent: d1a29c6

test: fix OPT_STEP_ADAMW for test-backend-ops (ggml/974)

Files changed:
- ggml/include/ggml.h  +1 -0
- ggml/src/ggml.c      +6 -4
ggml/include/ggml.h
@@ -2052,6 +2052,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_opt_step_adamw(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
+            struct ggml_tensor  * grad,
             float                 alpha,
             float                 beta1,
             float                 beta2,
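With this change the public entry point takes the gradient as an explicit tensor argument instead of reading it from a->grad inside the op. A minimal, hedged usage sketch (assumes an initialized ggml_context * ctx; the tensor size and AdamW hyperparameter values are illustrative, not taken from this commit):

    // Parameter and gradient must have the same shape (asserted in ggml_opt_step_adamw).
    struct ggml_tensor * w    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 10);
    struct ggml_tensor * grad = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 10);
    ggml_set_param(ctx, w);  // the op asserts GGML_TENSOR_FLAG_PARAM on w

    // ggml_opt_step_adamw(ctx, a, grad, alpha, beta1, beta2, eps, wd)
    struct ggml_tensor * step = ggml_opt_step_adamw(ctx, w, grad, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);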
ggml/src/ggml.c
@@ -7818,12 +7818,14 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
 struct ggml_tensor * ggml_opt_step_adamw(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        struct ggml_tensor  * grad,
         float                 alpha,
         float                 beta1,
         float                 beta2,
         float                 eps,
         float                 wd) {
     GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
+    GGML_ASSERT(ggml_are_same_shape(a, grad));
     GGML_ASSERT(alpha >  0.0f);
     GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
     GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
@@ -7842,9 +7844,9 @@ struct ggml_tensor * ggml_opt_step_adamw(
 
     result->op     = GGML_OP_OPT_STEP_ADAMW;
     result->src[0] = a;
-    result->src[1] = a->grad;
-    result->src[2] = ggml_dup_tensor(ctx, a->grad);
-    result->src[3] = ggml_dup_tensor(ctx, a->grad);
+    result->src[1] = grad;
+    result->src[2] = ggml_dup_tensor(ctx, grad);
+    result->src[3] = ggml_dup_tensor(ctx, grad);
 
     return result;
 }
@@ -18769,7 +18771,7 @@ void ggml_build_opt_adamw(
 
         if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
+            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
             ggml_build_forward_expand(gb, opt_step);
         }
     }
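The call site in ggml_build_opt_adamw now forwards node->grad explicitly. The practical upshot for test-backend-ops is that an OPT_STEP_ADAMW node can be built from a free-standing gradient tensor, without first constructing a backward graph. A sketch of such a test-style graph build, under the same assumptions as above (shapes, sizes, and hyperparameters are illustrative):

    #include "ggml.h"

    // Build a standalone OPT_STEP_ADAMW node, test-backend-ops style (sketch).
    static struct ggml_cgraph * build_adamw_graph(struct ggml_context * ctx) {
        struct ggml_tensor * w    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32, 32);
        struct ggml_tensor * grad = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32, 32); // same shape as w
        ggml_set_param(ctx, w); // required: GGML_TENSOR_FLAG_PARAM is asserted on w

        struct ggml_tensor * step = ggml_opt_step_adamw(ctx, w, grad, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, step);
        return gf;
    }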