Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

repo-sync-2024-11-01T13:48:20+0800 #899

Merged
merged 1 commit into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions libspu/dialect/pphlo/IR/fold.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ OpFoldResult ReverseOp::fold(FoldAdaptor) {
dims, [&](int64_t dim) { return shapedType.getDimSize(dim) == 1; })) {
return input;
}

// reverse(reverse(x, dims), dims) = x
if (auto prev = input.getDefiningOp<ReverseOp>()) {
if (prev.getDimensions() == dims) {
return prev.getOperand();
}
}

return {};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ std::vector<NdArrayRef> reduce(ReduceOp op,
ring_mul_(rs[idx], t);
}
} else {
SPU_ENFORCE("not supported reduction op");
SPU_THROW("not supported reduction op");
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions libspu/mpc/semi2k/prime_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ NdArrayRef MulPriv(KernelEvalContext* ctx, const NdArrayRef& x) {

// P0 sends (x+a) to P1 ; P1 sends (y+b) to P0
comm->sendAsync(comm->nextRank(), ring_add(a_or_b, x), "(x + a) or (y + b)");
xa_or_yb = comm->recv(comm->prevRank(), x.eltype(), "(x + a) or (y + b)");
xa_or_yb = comm->recv(comm->prevRank(), x.eltype(), "(x + a) or (y + b)")
.reshape(x.shape());
// note that our rings are commutative.
if (comm->getRank() == 0) {
ring_add_(c, ring_mul(std::move(xa_or_yb), x));
Expand Down Expand Up @@ -198,4 +199,4 @@ NdArrayRef ConvMP(KernelEvalContext* ctx, const NdArrayRef& h,
return x;
}

} // namespace spu::mpc::semi2k
} // namespace spu::mpc::semi2k
5 changes: 5 additions & 0 deletions pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"executionEnvironments": [
{"root": "."}
]
}
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
grpcio>=1.42.0,!=1.48.0
numpy>=1.22.0
numpy>=1.22.0, <2 # FIXME: for SF compatibility
protobuf>=4, <5
cloudpickle>=2.0.0
multiprocess>=0.70.12.2
Expand Down
8 changes: 2 additions & 6 deletions sml/linear_model/emulations/quantile_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import jax.numpy as jnp
Expand Down Expand Up @@ -48,15 +49,10 @@ def proc(X, y):
def generate_data():
from jax import random

# 设置随机种子
key = random.PRNGKey(42)
# 生成 X 数据
key, subkey = random.split(key)
X = random.normal(subkey, (100, 2))
# 生成 y 数据
y = (
5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1
) # 高相关性,带有小噪声
y = 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1
return X, y

try:
Expand Down
2 changes: 1 addition & 1 deletion sml/linear_model/tests/quantile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def generate_data():
# run
# Larger max_iter can give higher accuracy, but it will take more time to run
proc = proc_wrapper(
quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=200
quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=20
)
result, coef, intercept = spsim.sim_jax(sim, proc)(X, y)
rmse_encrpted = jnp.sqrt(jnp.mean((y - result) ** 2))
Expand Down
3 changes: 0 additions & 3 deletions sml/linear_model/utils/_linprog_simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def _pivot_col(T, tol=1e-5):

all_masked = jnp.all(mask)

# 定义根据最小值选择列的函数
ma = jnp.where(mask, jnp.inf, T[-1, :-1])
min_col = jnp.argmin(ma)

Expand All @@ -44,12 +43,10 @@ def _pivot_row(T, pivcol, phase, tol=1e-5, max_val=1e10):

q = jnp.where(ma >= max_val, jnp.inf, mb / ma)

# 选择最小比值的行
min_rows = jnp.nanargmin(q)
all_masked = jnp.all(mask)

row = min_rows
# 处理全被掩盖的情况
row = jnp.where(all_masked, 0, row)

return ~all_masked, row
Expand Down
1 change: 0 additions & 1 deletion spu/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ py_binary(
srcs = ["jnp_debug.py"],
deps = [
"//spu:api",
"//spu/intrinsic:all_intrinsics",
"//spu/utils:simulation",
],
)
Expand Down
7 changes: 3 additions & 4 deletions spu/tests/jnp_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import jax.numpy as jnp
import numpy as np

import spu.intrinsic as si
import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as ppsim

Expand All @@ -31,9 +30,9 @@
copts.disable_div_sqrt_rewrite = True

x = np.random.randn(3, 4)
y = np.random.randn(5, 6)
fn = lambda x, y: si.example_binary(x, y)
# fn = lambda x, y: jnp.matmul(x, y)
y = np.random.randn(4, 5)
fn = lambda x, y: jnp.matmul(x, y)

spu_fn = ppsim.sim_jax(sim, fn, copts=copts)
z = spu_fn(x, y)

Expand Down
Loading